1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 // Copyright (c) 2019 Cloudflare
4
5 #include <limits.h>
6 #include <string.h>
7 #include <stdlib.h>
8 #include <unistd.h>
9
10 #include <arpa/inet.h>
11 #include <netinet/in.h>
12 #include <sys/types.h>
13 #include <sys/socket.h>
14
15 #include <bpf/bpf.h>
16 #include <bpf/libbpf.h>
17
18 #include "cgroup_helpers.h"
19
/*
 * Create a listening TCP socket bound to @addr (@len bytes long).
 *
 * For AF_INET6 addresses the IPV6_V6ONLY option is set explicitly:
 * cleared (dual-stack) when @dual is true, set otherwise. Other address
 * families leave the option untouched.
 *
 * Returns the listening fd, or -1 on failure. Every failure is logged and
 * a partially configured socket is closed before returning.
 */
static int start_server(const struct sockaddr *addr, socklen_t len, bool dual)
{
	int v6only = dual ? 0 : 1;
	int sock;

	sock = socket(addr->sa_family, SOCK_STREAM, 0);
	if (sock == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (addr->sa_family == AF_INET6 &&
	    setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &v6only,
		       sizeof(v6only)) == -1) {
		log_err("Failed to set the dual-stack mode");
		goto fail;
	}

	if (bind(sock, addr, len) == -1) {
		log_err("Failed to bind server socket");
		goto fail;
	}

	/* 128 is the listen() backlog. */
	if (listen(sock, 128) == -1) {
		log_err("Failed to listen on server socket");
		goto fail;
	}

	return sock;

fail:
	/* Log first (above), then close, so log_err still sees the errno. */
	close(sock);
	return -1;
}
57
/*
 * Open a TCP socket of @addr's family and connect it to @addr (@len bytes).
 *
 * Returns the connected fd, or -1 on failure. Failures are logged and the
 * socket is closed, so no descriptor leaks.
 */
static int connect_to_server(const struct sockaddr *addr, socklen_t len)
{
	int sock = socket(addr->sa_family, SOCK_STREAM, 0);

	if (sock == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(sock, addr, len) != -1)
		return sock;

	log_err("Fail to connect to server");
	close(sock);
	return -1;
}
81
get_map_fd_by_prog_id(int prog_id,bool * xdp)82 static int get_map_fd_by_prog_id(int prog_id, bool *xdp)
83 {
84 struct bpf_prog_info info = {};
85 __u32 info_len = sizeof(info);
86 __u32 map_ids[1];
87 int prog_fd = -1;
88 int map_fd = -1;
89
90 prog_fd = bpf_prog_get_fd_by_id(prog_id);
91 if (prog_fd < 0) {
92 log_err("Failed to get fd by prog id %d", prog_id);
93 goto err;
94 }
95
96 info.nr_map_ids = 1;
97 info.map_ids = (__u64)(unsigned long)map_ids;
98
99 if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) {
100 log_err("Failed to get info by prog fd %d", prog_fd);
101 goto err;
102 }
103
104 if (!info.nr_map_ids) {
105 log_err("No maps found for prog fd %d", prog_fd);
106 goto err;
107 }
108
109 *xdp = info.type == BPF_PROG_TYPE_XDP;
110
111 map_fd = bpf_map_get_fd_by_id(map_ids[0]);
112 if (map_fd < 0)
113 log_err("Failed to get fd by map id %d", map_ids[0]);
114 err:
115 if (prog_fd >= 0)
116 close(prog_fd);
117 return map_fd;
118 }
119
run_test(int server_fd,int results_fd,bool xdp,const struct sockaddr * addr,socklen_t len)120 static int run_test(int server_fd, int results_fd, bool xdp,
121 const struct sockaddr *addr, socklen_t len)
122 {
123 int client = -1, srv_client = -1;
124 int ret = 0;
125 __u32 key = 0;
126 __u32 key_gen = 1;
127 __u32 key_mss = 2;
128 __u32 value = 0;
129 __u32 value_gen = 0;
130 __u32 value_mss = 0;
131
132 if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) {
133 log_err("Can't clear results");
134 goto err;
135 }
136
137 if (bpf_map_update_elem(results_fd, &key_gen, &value_gen, 0) < 0) {
138 log_err("Can't clear results");
139 goto err;
140 }
141
142 if (bpf_map_update_elem(results_fd, &key_mss, &value_mss, 0) < 0) {
143 log_err("Can't clear results");
144 goto err;
145 }
146
147 client = connect_to_server(addr, len);
148 if (client == -1)
149 goto err;
150
151 srv_client = accept(server_fd, NULL, 0);
152 if (srv_client == -1) {
153 log_err("Can't accept connection");
154 goto err;
155 }
156
157 if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) {
158 log_err("Can't lookup result");
159 goto err;
160 }
161
162 if (value == 0) {
163 log_err("Didn't match syncookie: %u", value);
164 goto err;
165 }
166
167 if (bpf_map_lookup_elem(results_fd, &key_gen, &value_gen) < 0) {
168 log_err("Can't lookup result");
169 goto err;
170 }
171
172 if (xdp && value_gen == 0) {
173 // SYN packets do not get passed through generic XDP, skip the
174 // rest of the test.
175 printf("Skipping XDP cookie check\n");
176 goto out;
177 }
178
179 if (bpf_map_lookup_elem(results_fd, &key_mss, &value_mss) < 0) {
180 log_err("Can't lookup result");
181 goto err;
182 }
183
184 if (value != value_gen) {
185 log_err("BPF generated cookie does not match kernel one");
186 goto err;
187 }
188
189 if (value_mss < 536 || value_mss > USHRT_MAX) {
190 log_err("Unexpected MSS retrieved");
191 goto err;
192 }
193
194 goto out;
195
196 err:
197 ret = 1;
198 out:
199 close(client);
200 close(srv_client);
201 return ret;
202 }
203
/*
 * Store the local port of the bound socket @server_fd (network byte order)
 * into *@port.
 *
 * Uses sockaddr_storage so that both AF_INET and AF_INET6 sockets are read
 * in full. This avoids relying on getsockname() truncating an IPv6 address
 * into a sockaddr_in, and on sin_port/sin6_port sharing an offset.
 *
 * Returns true on success. Returns false (after logging) if getsockname()
 * fails or the socket's family is neither AF_INET nor AF_INET6; *@port is
 * left untouched in that case.
 */
static bool get_port(int server_fd, in_port_t *port)
{
	struct sockaddr_storage ss;
	socklen_t len = sizeof(ss);

	if (getsockname(server_fd, (struct sockaddr *)&ss, &len)) {
		log_err("Failed to get server addr");
		return false;
	}

	switch (ss.ss_family) {
	case AF_INET:
		*port = ((const struct sockaddr_in *)&ss)->sin_port;
		return true;
	case AF_INET6:
		*port = ((const struct sockaddr_in6 *)&ss)->sin6_port;
		return true;
	default:
		log_err("Unexpected address family %d", (int)ss.ss_family);
		return false;
	}
}
218
main(int argc,char ** argv)219 int main(int argc, char **argv)
220 {
221 struct sockaddr_in addr4;
222 struct sockaddr_in6 addr6;
223 struct sockaddr_in addr4dual;
224 struct sockaddr_in6 addr6dual;
225 int server = -1;
226 int server_v6 = -1;
227 int server_dual = -1;
228 int results = -1;
229 int err = 0;
230 bool xdp;
231
232 if (argc < 2) {
233 fprintf(stderr, "Usage: %s prog_id\n", argv[0]);
234 exit(1);
235 }
236
237 /* Use libbpf 1.0 API mode */
238 libbpf_set_strict_mode(LIBBPF_STRICT_ALL);
239
240 results = get_map_fd_by_prog_id(atoi(argv[1]), &xdp);
241 if (results < 0) {
242 log_err("Can't get map");
243 goto err;
244 }
245
246 memset(&addr4, 0, sizeof(addr4));
247 addr4.sin_family = AF_INET;
248 addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
249 addr4.sin_port = 0;
250 memcpy(&addr4dual, &addr4, sizeof(addr4dual));
251
252 memset(&addr6, 0, sizeof(addr6));
253 addr6.sin6_family = AF_INET6;
254 addr6.sin6_addr = in6addr_loopback;
255 addr6.sin6_port = 0;
256
257 memset(&addr6dual, 0, sizeof(addr6dual));
258 addr6dual.sin6_family = AF_INET6;
259 addr6dual.sin6_addr = in6addr_any;
260 addr6dual.sin6_port = 0;
261
262 server = start_server((const struct sockaddr *)&addr4, sizeof(addr4),
263 false);
264 if (server == -1 || !get_port(server, &addr4.sin_port))
265 goto err;
266
267 server_v6 = start_server((const struct sockaddr *)&addr6,
268 sizeof(addr6), false);
269 if (server_v6 == -1 || !get_port(server_v6, &addr6.sin6_port))
270 goto err;
271
272 server_dual = start_server((const struct sockaddr *)&addr6dual,
273 sizeof(addr6dual), true);
274 if (server_dual == -1 || !get_port(server_dual, &addr4dual.sin_port))
275 goto err;
276
277 if (run_test(server, results, xdp,
278 (const struct sockaddr *)&addr4, sizeof(addr4)))
279 goto err;
280
281 if (run_test(server_v6, results, xdp,
282 (const struct sockaddr *)&addr6, sizeof(addr6)))
283 goto err;
284
285 if (run_test(server_dual, results, xdp,
286 (const struct sockaddr *)&addr4dual, sizeof(addr4dual)))
287 goto err;
288
289 printf("ok\n");
290 goto out;
291 err:
292 err = 1;
293 out:
294 close(server);
295 close(server_v6);
296 close(server_dual);
297 close(results);
298 return err;
299 }
300