1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/kernel.h>
8 #include <zephyr/net/socket.h>
9 #include <zephyr/posix/fcntl.h>
10 #include <zephyr/internal/syscall_handler.h>
11 #include <zephyr/sys/__assert.h>
12 #include <zephyr/sys/fdtable.h>
13 
14 #include "sockets_internal.h"
15 
/** Result values carried by the @ref spair poll signals */
enum {
	/** the pending operation has been canceled */
	SPAIR_SIG_CANCEL = 0,
	/** a @ref spair.recv_q has data to read or room to write */
	SPAIR_SIG_DATA = 1,
};
20 
/** Bits stored in @ref spair.flags */
enum {
	/** the endpoint is in non-blocking mode */
	SPAIR_FLAG_NONBLOCK = 0x01,
};

/** Value of @ref spair.flags for a freshly created endpoint (no bits set) */
#define SPAIR_FLAGS_DEFAULT 0
26 
27 /**
28  * Socketpair endpoint structure
29  *
30  * This structure represents one half of a socketpair (an 'endpoint').
31  *
32  * The implementation strives for compatibility with socketpair(2).
33  *
34  * Resources contained within this structure are said to be 'local', while
35  * resources contained within the other half of the socketpair (or other
36  * endpoint) are said to be 'remote'.
37  *
38  * Theory of operation:
39  * - each end of a socketpair owns a @a recv_q
40  * - since there is no write queue, data is either written or not
41  * - read and write operations may return partial transfers
42  * - read operations may block if the local @a recv_q is empty
43  * - write operations may block if the remote @a recv_q is full
44  * - each endpoint may be blocking or non-blocking
45  */
__net_socket struct spair {
	int remote; /**< the remote endpoint file descriptor, or -1 when there is none */
	uint32_t flags; /**< status and option bits (SPAIR_FLAG_*) */
	struct k_sem sem; /**< semaphore for exclusive structure access */
	struct k_pipe recv_q; /**< receive queue of local endpoint */
	/** indicates local @a recv_q isn't empty; also raised with SPAIR_SIG_CANCEL on close */
	struct k_poll_signal readable;
	/** indicates local @a recv_q isn't full; also raised with SPAIR_SIG_CANCEL on close */
	struct k_poll_signal writeable;
	/** backing storage for @a recv_q */
	uint8_t buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE];
};
58 
#ifdef CONFIG_NET_SOCKETPAIR_STATIC
/* Static pool of endpoints, sized for two struct spair objects per socketpair */
K_MEM_SLAB_DEFINE_STATIC(spair_slab, sizeof(struct spair), CONFIG_NET_SOCKETPAIR_MAX * 2,
			 __alignof__(struct spair));
#endif /* CONFIG_NET_SOCKETPAIR_STATIC */
63 
/* forward declaration: the table is defined at the end of this file, but the
 * helpers below pass its address to zvfs_get_fd_obj() when looking up endpoints
 */
static const struct socket_op_vtable spair_fd_op_vtable;
66 
67 #undef sock_is_nonblock
68 /** Determine if a @ref spair is in non-blocking mode */
sock_is_nonblock(const struct spair * spair)69 static inline bool sock_is_nonblock(const struct spair *spair)
70 {
71 	return !!(spair->flags & SPAIR_FLAG_NONBLOCK);
72 }
73 
74 /** Determine if a @ref spair is connected */
sock_is_connected(const struct spair * spair)75 static inline bool sock_is_connected(const struct spair *spair)
76 {
77 	const struct spair *remote = zvfs_get_fd_obj(spair->remote,
78 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
79 
80 	if (remote == NULL) {
81 		return false;
82 	}
83 
84 	return true;
85 }
86 
#undef sock_is_eof
/**
 * Report whether a @ref spair has reached end-of-file
 *
 * End-of-file is defined here as "not connected", see sock_is_connected().
 *
 * @param spair the endpoint to query
 *
 * @return true if the endpoint has no peer, false otherwise
 */
static inline bool sock_is_eof(const struct spair *spair)
{
	const bool connected = sock_is_connected(spair);

	return !connected;
}
93 
94 /**
95  * Determine bytes available to write
96  *
97  * Specifically, this function calculates the number of bytes that may be
98  * written to a given @ref spair without blocking.
99  */
spair_write_avail(struct spair * spair)100 static inline size_t spair_write_avail(struct spair *spair)
101 {
102 	struct spair *const remote = zvfs_get_fd_obj(spair->remote,
103 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
104 
105 	if (remote == NULL) {
106 		return 0;
107 	}
108 
109 	return k_pipe_write_avail(&remote->recv_q);
110 }
111 
112 /**
113  * Determine bytes available to read
114  *
115  * Specifically, this function calculates the number of bytes that may be
116  * read from a given @ref spair without blocking.
117  */
spair_read_avail(struct spair * spair)118 static inline size_t spair_read_avail(struct spair *spair)
119 {
120 	return k_pipe_read_avail(&spair->recv_q);
121 }
122 
/**
 * Exchange the values of two 32-bit unsigned integers
 *
 * @param a first value; receives the value of @p b
 * @param b second value; receives the value of @p a
 */
static inline void swap32(uint32_t *a, uint32_t *b)
{
	const uint32_t held = *a;

	*a = *b;
	*b = held;
}
132 
133 /**
134  * Delete @param spair
135  *
136  * This function deletes one endpoint of a socketpair.
137  *
138  * Theory of operation:
139  * - we have a socketpair with two endpoints: A and B
140  * - we have two threads: T1 and T2
141  * - T1 operates on endpoint A
142  * - T2 operates on endpoint B
143  *
144  * There are two possible cases where a blocking operation must be notified
145  * when one endpoint is closed:
146  * -# T1 is blocked reading from A and T2 closes B
147  *    T1 waits on A's write signal. T2 triggers the remote
148  *    @ref spair.readable
149  * -# T1 is blocked writing to A and T2 closes B
150  *    T1 is waits on B's read signal. T2 triggers the local
151  *    @ref spair.writeable.
152  *
153  * If the remote endpoint is already closed, the former operation does not
154  * take place. Otherwise, the @ref spair.remote of the local endpoint is
155  * set to -1.
156  *
157  * If no threads are blocking on A, then the signals have no effect.
158  *
159  * The memory associated with the local endpoint is cleared and freed.
160  */
spair_delete(struct spair * spair)161 static void spair_delete(struct spair *spair)
162 {
163 	int res;
164 	struct spair *remote = NULL;
165 	bool have_remote_sem = false;
166 
167 	if (spair == NULL) {
168 		return;
169 	}
170 
171 	if (spair->remote != -1) {
172 		remote = zvfs_get_fd_obj(spair->remote,
173 			(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
174 
175 		if (remote != NULL) {
176 			res = k_sem_take(&remote->sem, K_FOREVER);
177 			if (res == 0) {
178 				have_remote_sem = true;
179 				remote->remote = -1;
180 				res = k_poll_signal_raise(&remote->readable,
181 					SPAIR_SIG_CANCEL);
182 				__ASSERT(res == 0,
183 					"k_poll_signal_raise() failed: %d",
184 					res);
185 			}
186 		}
187 	}
188 
189 	spair->remote = -1;
190 
191 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_CANCEL);
192 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
193 
194 	if (remote != NULL && have_remote_sem) {
195 		k_sem_give(&remote->sem);
196 	}
197 
198 	/* ensure no private information is released to the memory pool */
199 	memset(spair, 0, sizeof(*spair));
200 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
201 	k_mem_slab_free(&spair_slab, (void *)spair);
202 #elif CONFIG_USERSPACE
203 	k_object_free(spair);
204 #else
205 	k_free(spair);
206 #endif
207 }
208 
209 /**
210  * Create a @ref spair (1/2 of a socketpair)
211  *
212  * The idea is to call this twice, but store the "local" side in the
213  * @ref spair.remote field initially.
214  *
215  * If both allocations are successful, then swap the @ref spair.remote
216  * fields in the two @ref spair instances.
217  */
spair_new(void)218 static struct spair *spair_new(void)
219 {
220 	struct spair *spair;
221 	int res;
222 
223 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
224 
225 	res = k_mem_slab_alloc(&spair_slab, (void **) &spair, K_NO_WAIT);
226 	if (res != 0) {
227 		spair = NULL;
228 	}
229 
230 #elif CONFIG_USERSPACE
231 	struct k_object *zo = k_object_create_dynamic(sizeof(*spair));
232 
233 	if (zo == NULL) {
234 		spair = NULL;
235 	} else {
236 		spair = zo->name;
237 		zo->type = K_OBJ_NET_SOCKET;
238 	}
239 #else
240 	spair = k_malloc(sizeof(*spair));
241 #endif
242 	if (spair == NULL) {
243 		errno = ENOMEM;
244 		goto out;
245 	}
246 	memset(spair, 0, sizeof(*spair));
247 
248 	/* initialize any non-zero default values */
249 	spair->remote = -1;
250 	spair->flags = SPAIR_FLAGS_DEFAULT;
251 
252 	k_sem_init(&spair->sem, 1, 1);
253 	k_pipe_init(&spair->recv_q, spair->buf, sizeof(spair->buf));
254 	k_poll_signal_init(&spair->readable);
255 	k_poll_signal_init(&spair->writeable);
256 
257 	/* A new socket is always writeable after creation */
258 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
259 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
260 
261 	spair->remote = zvfs_reserve_fd();
262 	if (spair->remote == -1) {
263 		errno = ENFILE;
264 		goto cleanup;
265 	}
266 
267 	zvfs_finalize_typed_fd(spair->remote, spair,
268 			       (const struct fd_op_vtable *)&spair_fd_op_vtable, ZVFS_MODE_IFSOCK);
269 
270 	goto out;
271 
272 cleanup:
273 	spair_delete(spair);
274 	spair = NULL;
275 
276 out:
277 	return spair;
278 }
279 
z_impl_zsock_socketpair(int family,int type,int proto,int * sv)280 int z_impl_zsock_socketpair(int family, int type, int proto, int *sv)
281 {
282 	int res;
283 	size_t i;
284 	struct spair *obj[2] = {};
285 
286 	SYS_PORT_TRACING_OBJ_FUNC_ENTER(socket, socketpair, family, type, proto, sv);
287 
288 	if (family != AF_UNIX) {
289 		errno = EAFNOSUPPORT;
290 		res = -1;
291 		goto errout;
292 	}
293 
294 	if (type != SOCK_STREAM) {
295 		errno = EPROTOTYPE;
296 		res = -1;
297 		goto errout;
298 	}
299 
300 	if (proto != 0) {
301 		errno = EPROTONOSUPPORT;
302 		res = -1;
303 		goto errout;
304 	}
305 
306 	if (sv == NULL) {
307 		/* not listed in normative spec, but mimics Linux behaviour */
308 		errno = EFAULT;
309 		res = -1;
310 		goto errout;
311 	}
312 
313 	for (i = 0; i < 2; ++i) {
314 		obj[i] = spair_new();
315 		if (!obj[i]) {
316 			res = -1;
317 			goto cleanup;
318 		}
319 	}
320 
321 	/* connect the two endpoints */
322 	swap32(&obj[0]->remote, &obj[1]->remote);
323 
324 	for (i = 0; i < 2; ++i) {
325 		sv[i] = obj[i]->remote;
326 		k_sem_give(&obj[0]->sem);
327 	}
328 
329 	SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, socketpair, sv[0], sv[1], 0);
330 
331 	return 0;
332 
333 cleanup:
334 	for (i = 0; i < 2; ++i) {
335 		spair_delete(obj[i]);
336 	}
337 
338 errout:
339 	SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, socketpair, -1, -1, -errno);
340 
341 	return res;
342 }
343 
#ifdef CONFIG_USERSPACE
/**
 * System call verification wrapper for z_impl_zsock_socketpair()
 *
 * The descriptors are produced into a local array and only copied out to
 * @p sv once the implementation has succeeded.
 *
 * @param family forwarded to z_impl_zsock_socketpair()
 * @param type forwarded to z_impl_zsock_socketpair()
 * @param proto forwarded to z_impl_zsock_socketpair()
 * @param[out] sv destination for the two file descriptors
 *
 * @return the result of z_impl_zsock_socketpair()
 * @return -1 with @ref errno set to @ref EFAULT if @p sv is NULL or is
 *         rejected by K_SYSCALL_MEMORY_WRITE()
 */
int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv)
{
	int fds[2];
	int ret;

	if (sv == NULL || K_SYSCALL_MEMORY_WRITE(sv, sizeof(fds)) != 0) {
		/* not listed in normative spec, but mimics linux behaviour */
		errno = EFAULT;
		return -1;
	}

	ret = z_impl_zsock_socketpair(family, type, proto, fds);
	if (ret != 0) {
		return ret;
	}

	K_OOPS(k_usermode_to_copy(sv, fds, sizeof(fds)));

	return 0;
}

#include <zephyr/syscalls/zsock_socketpair_mrsh.c>
#endif /* CONFIG_USERSPACE */
368 
369 /**
370  * Write data to one end of a @ref spair
371  *
372  * Data written on one file descriptor of a socketpair can be read at the
373  * other end using common POSIX calls such as read(2) or recv(2).
374  *
375  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
376  * this function will return immediately. If no data was written on a
377  * non-blocking file descriptor, then -1 will be returned and @ref errno will
378  * be set to @ref EAGAIN.
379  *
380  * Blocking write operations occur when the @ref O_NONBLOCK flag is @em not
381  * set and there is insufficient space in the @em remote @ref spair.pipe.
382  *
383  * Such a blocking write will suspend execution of the current thread until
384  * one of two possible results is received on the @em remote
385  * @ref spair.writeable:
386  *
387  * 1) @ref SPAIR_SIG_DATA - data has been read from the @em remote
388  *    @ref spair.pipe. Thus, allowing more data to be written.
389  *
390  * 2) @ref SPAIR_SIG_CANCEL - the @em remote socketpair endpoint was closed
391  *    Receipt of this result is analogous to SIGPIPE from POSIX
392  *    ("Write on a pipe with no one to read it."). In this case, the function
393  *    will return -1 and set @ref errno to @ref EPIPE.
394  *
395  * @param obj the address of an @ref spair object cast to `void *`
396  * @param buffer the buffer to write
397  * @param count the number of bytes to write from @p buffer
398  *
399  * @return on success, a number > 0 representing the number of bytes written
400  * @return -1 on error, with @ref errno set appropriately.
401  */
spair_write(void * obj,const void * buffer,size_t count)402 static ssize_t spair_write(void *obj, const void *buffer, size_t count)
403 {
404 	int res;
405 	size_t avail;
406 	bool is_nonblock;
407 	size_t bytes_written;
408 	bool have_local_sem = false;
409 	bool have_remote_sem = false;
410 	bool will_block = false;
411 	struct spair *const spair = (struct spair *)obj;
412 	struct spair *remote = NULL;
413 
414 	if (obj == NULL || buffer == NULL || count == 0) {
415 		errno = EINVAL;
416 		res = -1;
417 		goto out;
418 	}
419 
420 	res = k_sem_take(&spair->sem, K_NO_WAIT);
421 	is_nonblock = sock_is_nonblock(spair);
422 	if (res < 0) {
423 		if (is_nonblock) {
424 			errno = EAGAIN;
425 			res = -1;
426 			goto out;
427 		}
428 
429 		res = k_sem_take(&spair->sem, K_FOREVER);
430 		if (res < 0) {
431 			errno = -res;
432 			res = -1;
433 			goto out;
434 		}
435 		is_nonblock = sock_is_nonblock(spair);
436 	}
437 
438 	have_local_sem = true;
439 
440 	remote = zvfs_get_fd_obj(spair->remote,
441 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
442 
443 	if (remote == NULL) {
444 		errno = EPIPE;
445 		res = -1;
446 		goto out;
447 	}
448 
449 	res = k_sem_take(&remote->sem, K_NO_WAIT);
450 	if (res < 0) {
451 		if (is_nonblock) {
452 			errno = EAGAIN;
453 			res = -1;
454 			goto out;
455 		}
456 		res = k_sem_take(&remote->sem, K_FOREVER);
457 		if (res < 0) {
458 			errno = -res;
459 			res = -1;
460 			goto out;
461 		}
462 	}
463 
464 	have_remote_sem = true;
465 
466 	avail = spair_write_avail(spair);
467 
468 	if (avail == 0) {
469 		if (is_nonblock) {
470 			errno = EAGAIN;
471 			res = -1;
472 			goto out;
473 		}
474 		will_block = true;
475 	}
476 
477 	if (will_block) {
478 		if (k_is_in_isr()) {
479 			errno = EAGAIN;
480 			res = -1;
481 			goto out;
482 		}
483 
484 		for (int signaled = false, result = -1; !signaled;
485 			result = -1) {
486 
487 			struct k_poll_event events[] = {
488 				K_POLL_EVENT_INITIALIZER(
489 					K_POLL_TYPE_SIGNAL,
490 					K_POLL_MODE_NOTIFY_ONLY,
491 					&remote->writeable),
492 			};
493 
494 			k_sem_give(&remote->sem);
495 			have_remote_sem = false;
496 
497 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
498 			if (res < 0) {
499 				errno = -res;
500 				res = -1;
501 				goto out;
502 			}
503 
504 			remote = zvfs_get_fd_obj(spair->remote,
505 				(const struct fd_op_vtable *)
506 				&spair_fd_op_vtable, 0);
507 
508 			if (remote == NULL) {
509 				errno = EPIPE;
510 				res = -1;
511 				goto out;
512 			}
513 
514 			res = k_sem_take(&remote->sem, K_FOREVER);
515 			if (res < 0) {
516 				errno = -res;
517 				res = -1;
518 				goto out;
519 			}
520 
521 			have_remote_sem = true;
522 
523 			k_poll_signal_check(&remote->writeable, &signaled,
524 					    &result);
525 			if (!signaled) {
526 				continue;
527 			}
528 
529 			switch (result) {
530 				case SPAIR_SIG_DATA: {
531 					break;
532 				}
533 
534 				case SPAIR_SIG_CANCEL: {
535 					errno = EPIPE;
536 					res = -1;
537 					goto out;
538 				}
539 
540 				default: {
541 					__ASSERT(false,
542 						"unrecognized result: %d",
543 						result);
544 					continue;
545 				}
546 			}
547 
548 			/* SPAIR_SIG_DATA was received */
549 			break;
550 		}
551 	}
552 
553 	res = k_pipe_put(&remote->recv_q, (void *)buffer, count,
554 			 &bytes_written, 1, K_NO_WAIT);
555 	__ASSERT(res == 0, "k_pipe_put() failed: %d", res);
556 
557 	if (spair_write_avail(spair) == 0) {
558 		k_poll_signal_reset(&remote->writeable);
559 	}
560 
561 	res = k_poll_signal_raise(&remote->readable, SPAIR_SIG_DATA);
562 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
563 
564 	res = bytes_written;
565 
566 out:
567 
568 	if (remote != NULL && have_remote_sem) {
569 		k_sem_give(&remote->sem);
570 	}
571 	if (spair != NULL && have_local_sem) {
572 		k_sem_give(&spair->sem);
573 	}
574 
575 	return res;
576 }
577 
578 /**
579  * Read data from one end of a @ref spair
580  *
581  * Data written on one file descriptor of a socketpair (with e.g. write(2) or
582  * send(2)) can be read at the other end using common POSIX calls such as
583  * read(2) or recv(2).
584  *
585  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
586  * this function will return immediately. If no data was read from a
587  * non-blocking file descriptor, then -1 will be returned and @ref errno will
588  * be set to @ref EAGAIN.
589  *
590  * Blocking read operations occur when the @ref O_NONBLOCK flag is @em not set
591  * and there are no bytes to read in the @em local @ref spair.pipe.
592  *
593  * Such a blocking read will suspend execution of the current thread until
594  * one of two possible results is received on the @em local
595  * @ref spair.readable:
596  *
597  * -# @ref SPAIR_SIG_DATA - data has been written to the @em local
598  *    @ref spair.pipe. Thus, allowing more data to be read.
599  *
600  * -# @ref SPAIR_SIG_CANCEL - read of the @em local @spair.pipe
601  *    must be cancelled for some reason (e.g. the file descriptor will be
602  *    closed imminently). In this case, the function will return -1 and set
603  *    @ref errno to @ref EINTR.
604  *
605  * @param obj the address of an @ref spair object cast to `void *`
606  * @param buffer the buffer in which to read
607  * @param count the number of bytes to read
608  *
609  * @return on success, a number > 0 representing the number of bytes written
610  * @return -1 on error, with @ref errno set appropriately.
611  */
spair_read(void * obj,void * buffer,size_t count)612 static ssize_t spair_read(void *obj, void *buffer, size_t count)
613 {
614 	int res;
615 	bool is_connected;
616 	size_t avail;
617 	bool is_nonblock;
618 	size_t bytes_read;
619 	bool have_local_sem = false;
620 	bool will_block = false;
621 	struct spair *const spair = (struct spair *)obj;
622 
623 	if (obj == NULL || buffer == NULL || count == 0) {
624 		errno = EINVAL;
625 		res = -1;
626 		goto out;
627 	}
628 
629 	res = k_sem_take(&spair->sem, K_NO_WAIT);
630 	is_nonblock = sock_is_nonblock(spair);
631 	if (res < 0) {
632 		if (is_nonblock) {
633 			errno = EAGAIN;
634 			res = -1;
635 			goto out;
636 		}
637 
638 		res = k_sem_take(&spair->sem, K_FOREVER);
639 		if (res < 0) {
640 			errno = -res;
641 			res = -1;
642 			goto out;
643 		}
644 		is_nonblock = sock_is_nonblock(spair);
645 	}
646 
647 	have_local_sem = true;
648 
649 	is_connected = sock_is_connected(spair);
650 	avail = spair_read_avail(spair);
651 
652 	if (avail == 0) {
653 		if (!is_connected) {
654 			/* signal EOF */
655 			res = 0;
656 			goto out;
657 		}
658 
659 		if (is_nonblock) {
660 			errno = EAGAIN;
661 			res = -1;
662 			goto out;
663 		}
664 
665 		will_block = true;
666 	}
667 
668 	if (will_block) {
669 		if (k_is_in_isr()) {
670 			errno = EAGAIN;
671 			res = -1;
672 			goto out;
673 		}
674 
675 		for (int signaled = false, result = -1; !signaled;
676 			result = -1) {
677 
678 			struct k_poll_event events[] = {
679 				K_POLL_EVENT_INITIALIZER(
680 					K_POLL_TYPE_SIGNAL,
681 					K_POLL_MODE_NOTIFY_ONLY,
682 					&spair->readable
683 				),
684 			};
685 
686 			k_sem_give(&spair->sem);
687 			have_local_sem = false;
688 
689 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
690 			__ASSERT(res == 0, "k_poll() failed: %d", res);
691 
692 			res = k_sem_take(&spair->sem, K_FOREVER);
693 			__ASSERT(res == 0, "failed to take local sem: %d", res);
694 
695 			have_local_sem = true;
696 
697 			k_poll_signal_check(&spair->readable, &signaled,
698 					    &result);
699 			if (!signaled) {
700 				continue;
701 			}
702 
703 			switch (result) {
704 				case SPAIR_SIG_DATA: {
705 					break;
706 				}
707 
708 				case SPAIR_SIG_CANCEL: {
709 					errno = EPIPE;
710 					res = -1;
711 					goto out;
712 				}
713 
714 				default: {
715 					__ASSERT(false,
716 						"unrecognized result: %d",
717 						result);
718 					continue;
719 				}
720 			}
721 
722 			/* SPAIR_SIG_DATA was received */
723 			break;
724 		}
725 	}
726 
727 	res = k_pipe_get(&spair->recv_q, (void *)buffer, count, &bytes_read,
728 			 1, K_NO_WAIT);
729 	__ASSERT(res == 0, "k_pipe_get() failed: %d", res);
730 
731 	if (spair_read_avail(spair) == 0 && !sock_is_eof(spair)) {
732 		k_poll_signal_reset(&spair->readable);
733 	}
734 
735 	if (is_connected) {
736 		res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
737 		__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
738 	}
739 
740 	res = bytes_read;
741 
742 out:
743 
744 	if (spair != NULL && have_local_sem) {
745 		k_sem_give(&spair->sem);
746 	}
747 
748 	return res;
749 }
750 
zsock_poll_prepare_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)751 static int zsock_poll_prepare_ctx(struct spair *const spair,
752 				  struct zsock_pollfd *const pfd,
753 				  struct k_poll_event **pev,
754 				  struct k_poll_event *pev_end)
755 {
756 	int res;
757 
758 	struct spair *remote = NULL;
759 	bool have_remote_sem = false;
760 
761 	if (pfd->events & ZSOCK_POLLIN) {
762 
763 		/* Tell poll() to short-circuit wait */
764 		if (sock_is_eof(spair)) {
765 			res = -EALREADY;
766 			goto out;
767 		}
768 
769 		if (*pev == pev_end) {
770 			res = -ENOMEM;
771 			goto out;
772 		}
773 
774 		/* Wait until data has been written to the local end */
775 		(*pev)->obj = &spair->readable;
776 	}
777 
778 	if (pfd->events & ZSOCK_POLLOUT) {
779 
780 		/* Tell poll() to short-circuit wait */
781 		if (!sock_is_connected(spair)) {
782 			res = -EALREADY;
783 			goto out;
784 		}
785 
786 		if (*pev == pev_end) {
787 			res = -ENOMEM;
788 			goto out;
789 		}
790 
791 		remote = zvfs_get_fd_obj(spair->remote,
792 			(const struct fd_op_vtable *)
793 			&spair_fd_op_vtable, 0);
794 
795 		__ASSERT(remote != NULL, "remote is NULL");
796 
797 		res = k_sem_take(&remote->sem, K_FOREVER);
798 		if (res < 0) {
799 			goto out;
800 		}
801 
802 		have_remote_sem = true;
803 
804 		/* Wait until the recv queue on the remote end is no longer full */
805 		(*pev)->obj = &remote->writeable;
806 	}
807 
808 	(*pev)->type = K_POLL_TYPE_SIGNAL;
809 	(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
810 	(*pev)->state = K_POLL_STATE_NOT_READY;
811 
812 	(*pev)++;
813 
814 	res = 0;
815 
816 out:
817 
818 	if (remote != NULL && have_remote_sem) {
819 		k_sem_give(&remote->sem);
820 	}
821 
822 	return res;
823 }
824 
zsock_poll_update_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev)825 static int zsock_poll_update_ctx(struct spair *const spair,
826 				 struct zsock_pollfd *const pfd,
827 				 struct k_poll_event **pev)
828 {
829 	int res;
830 	int signaled;
831 	int result;
832 	struct spair *remote = NULL;
833 	bool have_remote_sem = false;
834 
835 	if (pfd->events & ZSOCK_POLLOUT) {
836 		if (!sock_is_connected(spair)) {
837 			pfd->revents |= ZSOCK_POLLHUP;
838 			goto pollout_done;
839 		}
840 
841 		remote = zvfs_get_fd_obj(spair->remote,
842 			(const struct fd_op_vtable *) &spair_fd_op_vtable, 0);
843 
844 		__ASSERT(remote != NULL, "remote is NULL");
845 
846 		res = k_sem_take(&remote->sem, K_FOREVER);
847 		if (res < 0) {
848 			/* if other end is deleted, this might occur */
849 			goto pollout_done;
850 		}
851 
852 		have_remote_sem = true;
853 
854 		if (spair_write_avail(spair) > 0) {
855 			pfd->revents |= ZSOCK_POLLOUT;
856 			goto pollout_done;
857 		}
858 
859 		/* check to see if op was canceled */
860 		signaled = false;
861 		k_poll_signal_check(&remote->writeable, &signaled, &result);
862 		if (signaled) {
863 			/* Cannot be SPAIR_SIG_DATA, because
864 			 * spair_write_avail() would have
865 			 * returned 0
866 			 */
867 			__ASSERT(result == SPAIR_SIG_CANCEL,
868 				"invalid result %d", result);
869 			pfd->revents |= ZSOCK_POLLHUP;
870 		}
871 	}
872 
873 pollout_done:
874 
875 	if (pfd->events & ZSOCK_POLLIN) {
876 		if (sock_is_eof(spair)) {
877 			pfd->revents |= ZSOCK_POLLIN;
878 			goto pollin_done;
879 		}
880 
881 		if (spair_read_avail(spair) > 0) {
882 			pfd->revents |= ZSOCK_POLLIN;
883 			goto pollin_done;
884 		}
885 
886 		/* check to see if op was canceled */
887 		signaled = false;
888 		k_poll_signal_check(&spair->readable, &signaled, &result);
889 		if (signaled) {
890 			/* Cannot be SPAIR_SIG_DATA, because
891 			 * spair_read_avail() would have
892 			 * returned 0
893 			 */
894 			__ASSERT(result == SPAIR_SIG_CANCEL,
895 					 "invalid result %d", result);
896 			pfd->revents |= ZSOCK_POLLIN;
897 		}
898 	}
899 
900 pollin_done:
901 	res = 0;
902 
903 	(*pev)++;
904 
905 	if (remote != NULL && have_remote_sem) {
906 		k_sem_give(&remote->sem);
907 	}
908 
909 	return res;
910 }
911 
spair_ioctl(void * obj,unsigned int request,va_list args)912 static int spair_ioctl(void *obj, unsigned int request, va_list args)
913 {
914 	int res;
915 	struct zsock_pollfd *pfd;
916 	struct k_poll_event **pev;
917 	struct k_poll_event *pev_end;
918 	int flags = 0;
919 	bool have_local_sem = false;
920 	struct spair *const spair = (struct spair *)obj;
921 
922 	if (spair == NULL) {
923 		errno = EINVAL;
924 		res = -1;
925 		goto out;
926 	}
927 
928 	/* The local sem is always taken in this function. If a subsequent
929 	 * function call requires the remote sem, it must acquire and free the
930 	 * remote sem.
931 	 */
932 	res = k_sem_take(&spair->sem, K_FOREVER);
933 	__ASSERT(res == 0, "failed to take local sem: %d", res);
934 
935 	have_local_sem = true;
936 
937 	switch (request) {
938 		case F_GETFL: {
939 			if (sock_is_nonblock(spair)) {
940 				flags |= O_NONBLOCK;
941 			}
942 
943 			res = flags;
944 			goto out;
945 		}
946 
947 		case F_SETFL: {
948 			flags = va_arg(args, int);
949 
950 			if (flags & O_NONBLOCK) {
951 				spair->flags |= SPAIR_FLAG_NONBLOCK;
952 			} else {
953 				spair->flags &= ~SPAIR_FLAG_NONBLOCK;
954 			}
955 
956 			res = 0;
957 			goto out;
958 		}
959 
960 		case ZFD_IOCTL_FIONBIO: {
961 			spair->flags |= SPAIR_FLAG_NONBLOCK;
962 			res = 0;
963 			goto out;
964 		}
965 
966 		case ZFD_IOCTL_FIONREAD: {
967 			int *nbytes;
968 
969 			nbytes = va_arg(args, int *);
970 			*nbytes = spair_read_avail(spair);
971 
972 			res = 0;
973 			goto out;
974 		}
975 
976 		case ZFD_IOCTL_POLL_PREPARE: {
977 			pfd = va_arg(args, struct zsock_pollfd *);
978 			pev = va_arg(args, struct k_poll_event **);
979 			pev_end = va_arg(args, struct k_poll_event *);
980 
981 			res = zsock_poll_prepare_ctx(obj, pfd, pev, pev_end);
982 			goto out;
983 		}
984 
985 		case ZFD_IOCTL_POLL_UPDATE: {
986 			pfd = va_arg(args, struct zsock_pollfd *);
987 			pev = va_arg(args, struct k_poll_event **);
988 
989 			res = zsock_poll_update_ctx(obj, pfd, pev);
990 			goto out;
991 		}
992 
993 		default: {
994 			errno = EOPNOTSUPP;
995 			res = -1;
996 			goto out;
997 		}
998 	}
999 
1000 out:
1001 	if (spair != NULL && have_local_sem) {
1002 		k_sem_give(&spair->sem);
1003 	}
1004 
1005 	return res;
1006 }
1007 
/**
 * bind() handler for socketpair endpoints
 *
 * All arguments are ignored.
 *
 * @return -1 always, with @ref errno set to @ref EISCONN
 */
static int spair_bind(void *obj, const struct sockaddr *addr, socklen_t addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EISCONN;

	return -1;
}
1018 
/**
 * connect() handler for socketpair endpoints
 *
 * All arguments are ignored.
 *
 * @return -1 always, with @ref errno set to @ref EISCONN
 */
static int spair_connect(void *obj, const struct sockaddr *addr, socklen_t addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EISCONN;

	return -1;
}
1029 
/**
 * listen() handler for socketpair endpoints
 *
 * All arguments are ignored.
 *
 * @return -1 always, with @ref errno set to @ref EINVAL
 */
static int spair_listen(void *obj, int backlog)
{
	ARG_UNUSED(backlog);
	ARG_UNUSED(obj);

	errno = EINVAL;

	return -1;
}
1038 
/**
 * accept() handler for socketpair endpoints
 *
 * All arguments are ignored.
 *
 * @return -1 always, with @ref errno set to @ref EOPNOTSUPP
 */
static int spair_accept(void *obj, struct sockaddr *addr, socklen_t *addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EOPNOTSUPP;

	return -1;
}
1049 
/**
 * sendto() handler for socketpair endpoints
 *
 * The flags and the destination address are ignored; the data is handed to
 * spair_write() unchanged.
 *
 * @param obj the address of an @ref spair object cast to `void *`
 * @param buf the buffer to write
 * @param len the number of bytes to write from @p buf
 * @param flags ignored
 * @param dest_addr ignored
 * @param addrlen ignored
 *
 * @return the result of spair_write()
 */
static ssize_t spair_sendto(void *obj, const void *buf, size_t len, int flags,
			    const struct sockaddr *dest_addr, socklen_t addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(dest_addr);
	ARG_UNUSED(flags);

	return spair_write(obj, buf, len);
}
1060 
/**
 * sendmsg() handler for socketpair endpoints
 *
 * The buffers of @p msg are written to the peer in order. Because
 * spair_write() may transfer only part of a buffer, each buffer is written
 * repeatedly until it has been sent completely. Zero-length buffers are
 * skipped.
 *
 * On a non-blocking endpoint the message is rejected up front if it does not
 * fit in the space currently available at the peer.
 *
 * @param obj the address of an @ref spair object cast to `void *`
 * @param msg the message whose @ref msghdr.msg_iov buffers are to be sent
 * @param flags ignored
 *
 * @return the number of bytes sent; this is the total message length unless
 *         a write failed after some data had already been sent
 * @return 0 if the message is empty
 * @return -1 if nothing was sent, with @ref errno set to @ref EINVAL for a
 *         NULL argument, @ref EPIPE if there is no peer, @ref EMSGSIZE if a
 *         non-blocking endpoint cannot take the whole message, or the value
 *         left by spair_write()
 */
static ssize_t spair_sendmsg(void *obj, const struct msghdr *msg,
			     int flags)
{
	ARG_UNUSED(flags);

	ssize_t res;
	size_t len = 0;
	size_t sent = 0;
	bool is_connected;
	size_t avail;
	bool is_nonblock;
	struct spair *const spair = (struct spair *)obj;

	if (spair == NULL || msg == NULL) {
		errno = EINVAL;
		res = -1;
		goto out;
	}

	is_connected = sock_is_connected(spair);
	avail = is_connected ? spair_write_avail(spair) : 0;
	is_nonblock = sock_is_nonblock(spair);

	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
		len += msg->msg_iov[i].iov_len;
	}

	if (!is_connected) {
		errno = EPIPE;
		res = -1;
		goto out;
	}

	if (len == 0) {
		res = 0;
		goto out;
	}

	if (len > avail && is_nonblock) {
		errno = EMSGSIZE;
		res = -1;
		goto out;
	}

	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
		const uint8_t *chunk = msg->msg_iov[i].iov_base;
		size_t remaining = msg->msg_iov[i].iov_len;

		/* spair_write() may be partial: keep going until this buffer
		 * is fully sent. An empty buffer never enters the loop, so it
		 * cannot trip the EINVAL check in spair_write().
		 */
		while (remaining > 0) {
			res = spair_write(spair, chunk, remaining);
			if (res <= 0) {
				/* report what was already delivered, if any */
				if (sent > 0) {
					res = (ssize_t)sent;
				}
				goto out;
			}

			chunk += res;
			remaining -= (size_t)res;
			sent += (size_t)res;
		}
	}

	res = (ssize_t)sent;

out:
	return res;
}
1119 
/**
 * recvfrom() handler for socketpair endpoints
 *
 * The flags and the source address buffer are ignored; the data is obtained
 * through spair_read().
 *
 * @param obj the address of an @ref spair object cast to `void *`
 * @param buf the buffer in which to read
 * @param max_len the number of bytes to read
 * @param flags ignored
 * @param src_addr ignored and left untouched
 * @param[out] addrlen set to 0 if not NULL
 *
 * @return the result of spair_read()
 */
static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len, int flags,
			      struct sockaddr *src_addr, socklen_t *addrlen)
{
	ARG_UNUSED(src_addr);
	ARG_UNUSED(flags);

	/* Protocol (PF_UNIX) does not support addressing with connected
	 * sockets and, therefore, it is unspecified behaviour to modify
	 * src_addr. However, it would be ambiguous to leave addrlen
	 * untouched if the user expects it to be updated. It is not
	 * mentioned that modifying addrlen is unspecified. Therefore
	 * we choose to eliminate ambiguity.
	 *
	 * Setting it to zero mimics Linux's behaviour.
	 */
	if (addrlen != NULL) {
		*addrlen = 0;
	}

	return spair_read(obj, buf, max_len);
}
1143 
/**
 * getsockopt() handler for socketpair endpoints
 *
 * All arguments are ignored; no option can be queried.
 *
 * @return -1 always, with @ref errno set to @ref ENOPROTOOPT
 */
static int spair_getsockopt(void *obj, int level, int optname, void *optval,
			    socklen_t *optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(optname);
	ARG_UNUSED(level);
	ARG_UNUSED(obj);

	errno = ENOPROTOOPT;

	return -1;
}
1156 
/**
 * setsockopt() handler for socketpair endpoints
 *
 * All arguments are ignored; no option can be set.
 *
 * @return -1 always, with @ref errno set to @ref ENOPROTOOPT
 */
static int spair_setsockopt(void *obj, int level, int optname, const void *optval,
			    socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(optname);
	ARG_UNUSED(level);
	ARG_UNUSED(obj);

	errno = ENOPROTOOPT;

	return -1;
}
1169 
spair_close(void * obj)1170 static int spair_close(void *obj)
1171 {
1172 	struct spair *const spair = (struct spair *)obj;
1173 	int res;
1174 
1175 	res = k_sem_take(&spair->sem, K_FOREVER);
1176 	__ASSERT(res == 0, "failed to take local sem: %d", res);
1177 
1178 	/* disconnect the remote endpoint */
1179 	spair_delete(spair);
1180 
1181 	/* Note that the semaphore released already so need to do it here */
1182 
1183 	return 0;
1184 }
1185 
/* Operation table for socketpair endpoints. Its address is also the key passed
 * to zvfs_get_fd_obj() to check that a file descriptor refers to a struct spair.
 */
static const struct socket_op_vtable spair_fd_op_vtable = {
	/* generic file descriptor operations */
	.fd_vtable = {
		.read = spair_read,
		.write = spair_write,
		.close = spair_close,
		.ioctl = spair_ioctl,
	},
	/* socket operations; bind/connect/listen/accept and the sockopt
	 * handlers always fail with a fixed errno
	 */
	.bind = spair_bind,
	.connect = spair_connect,
	.listen = spair_listen,
	.accept = spair_accept,
	.sendto = spair_sendto,
	.sendmsg = spair_sendmsg,
	.recvfrom = spair_recvfrom,
	.getsockopt = spair_getsockopt,
	.setsockopt = spair_setsockopt,
};
1203