1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <test_progs.h>
4 #include "cgroup_helpers.h"
5 #include "network_helpers.h"
6 
/*
 * Return the port stored in a socket address, converted to host byte order.
 * AF_INET selects the sockaddr_in layout; any other family is read as
 * sockaddr_in6.
 */
static __u16 addr_port(int family, const struct sockaddr_storage *addr)
{
	if (family == AF_INET)
		return ntohs(((const struct sockaddr_in *)addr)->sin_port);

	return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port);
}

/*
 * Check the local and peer ports of a socket.
 *
 * family:         AF_INET, or anything else for the sockaddr_in6 layout
 * fd:             socket to inspect
 * expected_local: port (host byte order) getsockname() must report
 * expected_peer:  port (host byte order) getpeername() must report
 *
 * Returns 0 when both ports match, -1 if either lookup fails or a port
 * differs from the expected value.
 */
static int verify_ports(int family, int fd,
			__u16 expected_local, __u16 expected_peer)
{
	struct sockaddr_storage addr;
	socklen_t len;
	__u16 port;

	len = sizeof(addr);
	if (getsockname(fd, (struct sockaddr *)&addr, &len)) {
		log_err("Failed to get local addr");
		return -1;
	}

	port = addr_port(family, &addr);
	if (port != expected_local) {
		log_err("Unexpected local port %d, expected %d", port,
			expected_local);
		return -1;
	}

	/*
	 * len is a value-result argument: the call above replaced it with the
	 * size of the address it stored, so restore the full buffer capacity
	 * before the next lookup.
	 */
	len = sizeof(addr);
	if (getpeername(fd, (struct sockaddr *)&addr, &len)) {
		log_err("Failed to get peer addr");
		return -1;
	}

	port = addr_port(family, &addr);
	if (port != expected_peer) {
		log_err("Unexpected peer port %d, expected %d", port,
			expected_peer);
		return -1;
	}

	return 0;
}
48 
run_test(int cgroup_fd,int server_fd,int family,int type)49 static int run_test(int cgroup_fd, int server_fd, int family, int type)
50 {
51 	bool v4 = family == AF_INET;
52 	__u16 expected_local_port = v4 ? 22222 : 22223;
53 	__u16 expected_peer_port = 60000;
54 	struct bpf_prog_load_attr attr = {
55 		.file = v4 ? "./connect_force_port4.o" :
56 			     "./connect_force_port6.o",
57 	};
58 	struct bpf_program *prog;
59 	struct bpf_object *obj;
60 	int xlate_fd, fd, err;
61 	__u32 duration = 0;
62 
63 	err = bpf_prog_load_xattr(&attr, &obj, &xlate_fd);
64 	if (err) {
65 		log_err("Failed to load BPF object");
66 		return -1;
67 	}
68 
69 	prog = bpf_object__find_program_by_title(obj, v4 ?
70 						 "cgroup/connect4" :
71 						 "cgroup/connect6");
72 	if (CHECK(!prog, "find_prog", "connect prog not found\n")) {
73 		err = -EIO;
74 		goto close_bpf_object;
75 	}
76 
77 	err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd, v4 ?
78 			      BPF_CGROUP_INET4_CONNECT :
79 			      BPF_CGROUP_INET6_CONNECT, 0);
80 	if (err) {
81 		log_err("Failed to attach BPF program");
82 		goto close_bpf_object;
83 	}
84 
85 	prog = bpf_object__find_program_by_title(obj, v4 ?
86 						 "cgroup/getpeername4" :
87 						 "cgroup/getpeername6");
88 	if (CHECK(!prog, "find_prog", "getpeername prog not found\n")) {
89 		err = -EIO;
90 		goto close_bpf_object;
91 	}
92 
93 	err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd, v4 ?
94 			      BPF_CGROUP_INET4_GETPEERNAME :
95 			      BPF_CGROUP_INET6_GETPEERNAME, 0);
96 	if (err) {
97 		log_err("Failed to attach BPF program");
98 		goto close_bpf_object;
99 	}
100 
101 	prog = bpf_object__find_program_by_title(obj, v4 ?
102 						 "cgroup/getsockname4" :
103 						 "cgroup/getsockname6");
104 	if (CHECK(!prog, "find_prog", "getsockname prog not found\n")) {
105 		err = -EIO;
106 		goto close_bpf_object;
107 	}
108 
109 	err = bpf_prog_attach(bpf_program__fd(prog), cgroup_fd, v4 ?
110 			      BPF_CGROUP_INET4_GETSOCKNAME :
111 			      BPF_CGROUP_INET6_GETSOCKNAME, 0);
112 	if (err) {
113 		log_err("Failed to attach BPF program");
114 		goto close_bpf_object;
115 	}
116 
117 	fd = connect_to_fd(server_fd, 0);
118 	if (fd < 0) {
119 		err = -1;
120 		goto close_bpf_object;
121 	}
122 
123 	err = verify_ports(family, fd, expected_local_port,
124 			   expected_peer_port);
125 	close(fd);
126 
127 close_bpf_object:
128 	bpf_object__close(obj);
129 	return err;
130 }
131 
/*
 * Join the "/connect_force_port" cgroup and run run_test() against four
 * servers in turn: AF_INET and AF_INET6 stream sockets, then AF_INET and
 * AF_INET6 datagram sockets.
 *
 * A server that fails to start skips the remaining cases; a run_test()
 * failure is recorded through CHECK_FAIL() and the sequence continues.
 */
void test_connect_force_port(void)
{
	int cg_fd, srv_fd;

	cg_fd = test__join_cgroup("/connect_force_port");
	if (CHECK_FAIL(cg_fd < 0))
		return;

	/* AF_INET, SOCK_STREAM, server port 60123 */
	srv_fd = start_server(AF_INET, SOCK_STREAM, NULL, 60123, 0);
	if (CHECK_FAIL(srv_fd < 0))
		goto out_close_cgroup;
	CHECK_FAIL(run_test(cg_fd, srv_fd, AF_INET, SOCK_STREAM));
	close(srv_fd);

	/* AF_INET6, SOCK_STREAM, server port 60124 */
	srv_fd = start_server(AF_INET6, SOCK_STREAM, NULL, 60124, 0);
	if (CHECK_FAIL(srv_fd < 0))
		goto out_close_cgroup;
	CHECK_FAIL(run_test(cg_fd, srv_fd, AF_INET6, SOCK_STREAM));
	close(srv_fd);

	/* AF_INET, SOCK_DGRAM, server port 60123 */
	srv_fd = start_server(AF_INET, SOCK_DGRAM, NULL, 60123, 0);
	if (CHECK_FAIL(srv_fd < 0))
		goto out_close_cgroup;
	CHECK_FAIL(run_test(cg_fd, srv_fd, AF_INET, SOCK_DGRAM));
	close(srv_fd);

	/* AF_INET6, SOCK_DGRAM, server port 60124 */
	srv_fd = start_server(AF_INET6, SOCK_DGRAM, NULL, 60124, 0);
	if (CHECK_FAIL(srv_fd < 0))
		goto out_close_cgroup;
	CHECK_FAIL(run_test(cg_fd, srv_fd, AF_INET6, SOCK_DGRAM));
	close(srv_fd);

out_close_cgroup:
	close(cg_fd);
}
167