1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
24 
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28 
29 /* Per-socket status */
30 struct vsock_stat {
31 	struct list_head list;
32 	struct vsock_diag_msg msg;
33 };
34 
/*
 * Return a printable name for a socket type as reported in vdiag_type.
 * SOCK_SEQPACKET is listed because AF_VSOCK supports it (Linux >= 5.14);
 * without it such sockets would be reported as "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	switch (type) {
	case SOCK_DGRAM:
		return "DGRAM";
	case SOCK_STREAM:
		return "STREAM";
	case SOCK_SEQPACKET:
		return "SEQPACKET";
	default:
		return "INVALID TYPE";
	}
}
46 
sock_state_str(int state)47 static const char *sock_state_str(int state)
48 {
49 	switch (state) {
50 	case TCP_CLOSE:
51 		return "UNCONNECTED";
52 	case TCP_SYN_SENT:
53 		return "CONNECTING";
54 	case TCP_ESTABLISHED:
55 		return "CONNECTED";
56 	case TCP_CLOSING:
57 		return "DISCONNECTING";
58 	case TCP_LISTEN:
59 		return "LISTEN";
60 	default:
61 		return "INVALID STATE";
62 	}
63 }
64 
/*
 * Describe a vdiag_shutdown value: 1 is receive shutdown, 2 is send
 * shutdown and 3 is both.  Every other value, including 0, yields "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};
	const int count = (int)(sizeof(names) / sizeof(names[0]));

	if (shutdown < 0 || shutdown >= count)
		return "0";

	return names[shutdown];
}
78 
/*
 * Write a vsock address to @fp as "<cid>:<port>" with no trailing newline.
 * VMADDR_CID_ANY and VMADDR_PORT_ANY are rendered as "*".
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputc('*', fp);
}
91 
print_vsock_stat(FILE * fp,struct vsock_stat * st)92 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
93 {
94 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
95 	fprintf(fp, " ");
96 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
97 	fprintf(fp, " %s %s %s %u\n",
98 		sock_type_str(st->msg.vdiag_type),
99 		sock_state_str(st->msg.vdiag_state),
100 		sock_shutdown_str(st->msg.vdiag_shutdown),
101 		st->msg.vdiag_ino);
102 }
103 
print_vsock_stats(FILE * fp,struct list_head * head)104 static void print_vsock_stats(FILE *fp, struct list_head *head)
105 {
106 	struct vsock_stat *st;
107 
108 	list_for_each_entry(st, head, list)
109 		print_vsock_stat(fp, st);
110 }
111 
find_vsock_stat(struct list_head * head,int fd)112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
113 {
114 	struct vsock_stat *st;
115 	struct stat stat;
116 
117 	if (fstat(fd, &stat) < 0) {
118 		perror("fstat");
119 		exit(EXIT_FAILURE);
120 	}
121 
122 	list_for_each_entry(st, head, list)
123 		if (st->msg.vdiag_ino == stat.st_ino)
124 			return st;
125 
126 	fprintf(stderr, "cannot find fd %d\n", fd);
127 	exit(EXIT_FAILURE);
128 }
129 
/*
 * Terminate with a dump of @head on stderr unless the list is empty.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		/* EXIT_FAILURE, matching every other failure path in this file */
		exit(EXIT_FAILURE);
	}
}
138 
check_num_sockets(struct list_head * head,int expected)139 static void check_num_sockets(struct list_head *head, int expected)
140 {
141 	struct list_head *node;
142 	int n = 0;
143 
144 	list_for_each(node, head)
145 		n++;
146 
147 	if (n != expected) {
148 		fprintf(stderr, "expected %d sockets, found %d\n",
149 			expected, n);
150 		print_vsock_stats(stderr, head);
151 		exit(EXIT_FAILURE);
152 	}
153 }
154 
check_socket_state(struct vsock_stat * st,__u8 state)155 static void check_socket_state(struct vsock_stat *st, __u8 state)
156 {
157 	if (st->msg.vdiag_state != state) {
158 		fprintf(stderr, "expected socket state %#x, got %#x\n",
159 			state, st->msg.vdiag_state);
160 		exit(EXIT_FAILURE);
161 	}
162 }
163 
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request for AF_VSOCK sockets on the
 * netlink socket @fd, with every vdiag_states bit set.  sendmsg() is
 * retried on EINTR; any other failure terminates the process.
 */
static void send_req(int fd)
{
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,	/* all state bits set */
		},
	};
	struct sockaddr_nl dest = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &dest,
		.msg_namelen = sizeof(dest),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	while (sendmsg(fd, &msg, 0) < 0) {
		if (errno != EINTR) {
			perror("sendmsg");
			exit(EXIT_FAILURE);
		}
	}
}
206 
/*
 * Receive one datagram from @fd into @buf (at most @len bytes) and return
 * its length.  recvmsg() is retried on EINTR; any other failure terminates
 * the process, so the return value is never negative.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl src = { .nl_family = AF_NETLINK };
	struct iovec iov = { .iov_base = buf, .iov_len = len };
	struct msghdr msg = {
		.msg_name = &src,
		.msg_namelen = sizeof(src),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t n = recvmsg(fd, &msg, 0);

		if (n >= 0)
			return n;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
235 
add_vsock_stat(struct list_head * sockets,const struct vsock_diag_msg * resp)236 static void add_vsock_stat(struct list_head *sockets,
237 			   const struct vsock_diag_msg *resp)
238 {
239 	struct vsock_stat *st;
240 
241 	st = malloc(sizeof(*st));
242 	if (!st) {
243 		perror("malloc");
244 		exit(EXIT_FAILURE);
245 	}
246 
247 	st->msg = *resp;
248 	list_add_tail(&st->list, sockets);
249 }
250 
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends an AF_VSOCK dump request with
 * send_req() and appends one struct vsock_stat per SOCK_DIAG_BY_FAMILY
 * message to @sockets.  Reading stops at NLMSG_DONE, or when a receive
 * returns 0 bytes.  The process is terminated on a short read, an
 * NLMSG_ERROR reply, an unexpected message type or a truncated
 * vsock_diag_msg.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* long elements give the buffer the alignment struct nlmsghdr needs */
	long buf[8192 / sizeof(long)];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		const struct nlmsghdr *h;
		ssize_t ret;

		ret = recv_resp(fd, buf, sizeof(buf));
		if (ret == 0)
			goto done;
		/*
		 * recv_resp() exits instead of returning a negative value, so
		 * the conversion to size_t is lossless; it makes the
		 * signed/unsigned comparison explicit.
		 */
		if ((size_t)ret < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", ret);
			exit(EXIT_FAILURE);
		}

		h = (struct nlmsghdr *)buf;

		/* One datagram may carry several netlink messages. */
		while (NLMSG_OK(h, ret)) {
			if (h->nlmsg_type == NLMSG_DONE)
				goto done;

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
					fprintf(stderr, "NLMSG_ERROR\n");
				else {
					/* err->error is a negative errno value */
					errno = -err->error;
					perror("NLMSG_ERROR");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}
			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));

			/* Advances h and decrements ret by the aligned length */
			h = NLMSG_NEXT(h, ret);
		}
	}

done:
	close(fd);
}
318 
free_sock_stat(struct list_head * sockets)319 static void free_sock_stat(struct list_head *sockets)
320 {
321 	struct vsock_stat *st;
322 	struct vsock_stat *next;
323 
324 	list_for_each_entry_safe(st, next, sockets, list)
325 		free(st);
326 }
327 
test_no_sockets(const struct test_opts * opts)328 static void test_no_sockets(const struct test_opts *opts)
329 {
330 	LIST_HEAD(sockets);
331 
332 	read_vsock_stat(&sockets);
333 
334 	check_no_sockets(&sockets);
335 
336 	free_sock_stat(&sockets);
337 }
338 
test_listen_socket_server(const struct test_opts * opts)339 static void test_listen_socket_server(const struct test_opts *opts)
340 {
341 	union {
342 		struct sockaddr sa;
343 		struct sockaddr_vm svm;
344 	} addr = {
345 		.svm = {
346 			.svm_family = AF_VSOCK,
347 			.svm_port = 1234,
348 			.svm_cid = VMADDR_CID_ANY,
349 		},
350 	};
351 	LIST_HEAD(sockets);
352 	struct vsock_stat *st;
353 	int fd;
354 
355 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356 
357 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358 		perror("bind");
359 		exit(EXIT_FAILURE);
360 	}
361 
362 	if (listen(fd, 1) < 0) {
363 		perror("listen");
364 		exit(EXIT_FAILURE);
365 	}
366 
367 	read_vsock_stat(&sockets);
368 
369 	check_num_sockets(&sockets, 1);
370 	st = find_vsock_stat(&sockets, fd);
371 	check_socket_state(st, TCP_LISTEN);
372 
373 	close(fd);
374 	free_sock_stat(&sockets);
375 }
376 
test_connect_client(const struct test_opts * opts)377 static void test_connect_client(const struct test_opts *opts)
378 {
379 	int fd;
380 	LIST_HEAD(sockets);
381 	struct vsock_stat *st;
382 
383 	fd = vsock_stream_connect(opts->peer_cid, 1234);
384 	if (fd < 0) {
385 		perror("connect");
386 		exit(EXIT_FAILURE);
387 	}
388 
389 	read_vsock_stat(&sockets);
390 
391 	check_num_sockets(&sockets, 1);
392 	st = find_vsock_stat(&sockets, fd);
393 	check_socket_state(st, TCP_ESTABLISHED);
394 
395 	control_expectln("DONE");
396 	control_writeln("DONE");
397 
398 	close(fd);
399 	free_sock_stat(&sockets);
400 }
401 
test_connect_server(const struct test_opts * opts)402 static void test_connect_server(const struct test_opts *opts)
403 {
404 	struct vsock_stat *st;
405 	LIST_HEAD(sockets);
406 	int client_fd;
407 
408 	client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
409 	if (client_fd < 0) {
410 		perror("accept");
411 		exit(EXIT_FAILURE);
412 	}
413 
414 	read_vsock_stat(&sockets);
415 
416 	check_num_sockets(&sockets, 1);
417 	st = find_vsock_stat(&sockets, client_fd);
418 	check_socket_state(st, TCP_ESTABLISHED);
419 
420 	control_writeln("DONE");
421 	control_expectln("DONE");
422 
423 	close(client_fd);
424 	free_sock_stat(&sockets);
425 }
426 
427 static struct test_case test_cases[] = {
428 	{
429 		.name = "No sockets",
430 		.run_server = test_no_sockets,
431 	},
432 	{
433 		.name = "Listen socket",
434 		.run_server = test_listen_socket_server,
435 	},
436 	{
437 		.name = "Connect",
438 		.run_client = test_connect_client,
439 		.run_server = test_connect_server,
440 	},
441 	{},
442 };
443 
/* No short options are defined; all options are long options. */
static const char optstring[] = "";

/*
 * Long options parsed by main().  Entries are { name, has_arg, flag, val };
 * flag is NULL, so getopt_long() returns val when the option matches.
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },	/* terminator */
};
483 
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock_diag.ko tests.  Must be launched in both\n"
		"guest and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	/* The text contains no conversion specifiers, so fputs() is equivalent */
	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
514 
main(int argc,char ** argv)515 int main(int argc, char **argv)
516 {
517 	const char *control_host = NULL;
518 	const char *control_port = NULL;
519 	struct test_opts opts = {
520 		.mode = TEST_MODE_UNSET,
521 		.peer_cid = VMADDR_CID_ANY,
522 	};
523 
524 	init_signals();
525 
526 	for (;;) {
527 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
528 
529 		if (opt == -1)
530 			break;
531 
532 		switch (opt) {
533 		case 'H':
534 			control_host = optarg;
535 			break;
536 		case 'm':
537 			if (strcmp(optarg, "client") == 0)
538 				opts.mode = TEST_MODE_CLIENT;
539 			else if (strcmp(optarg, "server") == 0)
540 				opts.mode = TEST_MODE_SERVER;
541 			else {
542 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
543 				return EXIT_FAILURE;
544 			}
545 			break;
546 		case 'p':
547 			opts.peer_cid = parse_cid(optarg);
548 			break;
549 		case 'P':
550 			control_port = optarg;
551 			break;
552 		case 'l':
553 			list_tests(test_cases);
554 			break;
555 		case 's':
556 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
557 				  optarg);
558 			break;
559 		case '?':
560 		default:
561 			usage();
562 		}
563 	}
564 
565 	if (!control_port)
566 		usage();
567 	if (opts.mode == TEST_MODE_UNSET)
568 		usage();
569 	if (opts.peer_cid == VMADDR_CID_ANY)
570 		usage();
571 
572 	if (!control_host) {
573 		if (opts.mode != TEST_MODE_SERVER)
574 			usage();
575 		control_host = "0.0.0.0";
576 	}
577 
578 	control_init(control_host, control_port,
579 		     opts.mode == TEST_MODE_SERVER);
580 
581 	run_tests(test_cases, &opts);
582 
583 	control_cleanup();
584 	return EXIT_SUCCESS;
585 }
586