1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/kernel.h>
8 #include <zephyr/net/socket.h>
9 #include <zephyr/internal/syscall_handler.h>
10 #include <zephyr/sys/__assert.h>
11 #include <zephyr/sys/fdtable.h>
12 
13 #include "sockets_internal.h"
14 
/** Result values carried by @ref spair.readable and @ref spair.writeable */
enum {
	/** raised by spair_delete() to cancel a blocked operation */
	SPAIR_SIG_CANCEL = 0,
	/** the state of a @ref spair.recv_q has changed */
	SPAIR_SIG_DATA = 1,
};

/** Bits stored in @ref spair.flags */
enum {
	/** endpoint operates in non-blocking mode */
	SPAIR_FLAG_NONBLOCK = 1 << 0,
};

/** @ref spair.flags of a freshly created endpoint (blocking mode) */
#define SPAIR_FLAGS_DEFAULT 0
25 
26 /**
27  * Socketpair endpoint structure
28  *
29  * This structure represents one half of a socketpair (an 'endpoint').
30  *
31  * The implementation strives for compatibility with socketpair(2).
32  *
33  * Resources contained within this structure are said to be 'local', while
34  * resources contained within the other half of the socketpair (or other
35  * endpoint) are said to be 'remote'.
36  *
37  * Theory of operation:
38  * - each end of a socketpair owns a @a recv_q
39  * - since there is no write queue, data is either written or not
40  * - read and write operations may return partial transfers
41  * - read operations may block if the local @a recv_q is empty
42  * - write operations may block if the remote @a recv_q is full
43  * - each endpoint may be blocking or non-blocking
44  */
45 __net_socket struct spair {
46 	int remote; /**< the remote endpoint file descriptor */
47 	uint32_t flags; /**< status and option bits */
48 	struct k_sem sem; /**< semaphore for exclusive structure access */
49 	struct ring_buf recv_q;
50 	/** indicates local @a recv_q isn't empty */
51 	struct k_poll_signal readable;
52 	/** indicates local @a recv_q isn't full */
53 	struct k_poll_signal writeable;
54 	/** buffer for @a recv_q recv_q */
55 	uint8_t buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE];
56 };
57 
#ifdef CONFIG_NET_SOCKETPAIR_STATIC
/* Static pool holding CONFIG_NET_SOCKETPAIR_MAX pairs, i.e. two endpoints each */
K_MEM_SLAB_DEFINE_STATIC(spair_slab, sizeof(struct spair), CONFIG_NET_SOCKETPAIR_MAX * 2,
			 __alignof__(struct spair));
#endif /* CONFIG_NET_SOCKETPAIR_STATIC */

/* forward declaration: the vtable is defined at the end of this file, but it
 * is needed earlier as the argument to zvfs_get_fd_obj() when resolving
 * spair.remote into an endpoint
 */
static const struct socket_op_vtable spair_fd_op_vtable;
65 
66 #undef sock_is_nonblock
67 /** Determine if a @ref spair is in non-blocking mode */
sock_is_nonblock(const struct spair * spair)68 static inline bool sock_is_nonblock(const struct spair *spair)
69 {
70 	return !!(spair->flags & SPAIR_FLAG_NONBLOCK);
71 }
72 
73 /** Determine if a @ref spair is connected */
sock_is_connected(const struct spair * spair)74 static inline bool sock_is_connected(const struct spair *spair)
75 {
76 	const struct spair *remote = zvfs_get_fd_obj(spair->remote,
77 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
78 
79 	if (remote == NULL) {
80 		return false;
81 	}
82 
83 	return true;
84 }
85 
#undef sock_is_eof
/**
 * Check whether @p spair has reached end-of-file.
 *
 * End-of-file is simply the absence of a connected remote endpoint.
 */
static inline bool sock_is_eof(const struct spair *spair)
{
	const bool connected = sock_is_connected(spair);

	return !connected;
}
92 
93 /**
94  * Determine bytes available to write
95  *
96  * Specifically, this function calculates the number of bytes that may be
97  * written to a given @ref spair without blocking.
98  */
spair_write_avail(struct spair * spair)99 static inline size_t spair_write_avail(struct spair *spair)
100 {
101 	struct spair *const remote = zvfs_get_fd_obj(spair->remote,
102 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
103 
104 	if (remote == NULL) {
105 		return 0;
106 	}
107 
108 	return ring_buf_space_get(&remote->recv_q);
109 }
110 
111 /**
112  * Determine bytes available to read
113  *
114  * Specifically, this function calculates the number of bytes that may be
115  * read from a given @ref spair without blocking.
116  */
spair_read_avail(struct spair * spair)117 static inline size_t spair_read_avail(struct spair *spair)
118 {
119 	return ring_buf_size_get(&spair->recv_q);
120 }
121 
/** Exchange the values stored at @p a and @p b */
static inline void swap32(uint32_t *a, uint32_t *b)
{
	const uint32_t tmp = *a;

	*a = *b;
	*b = tmp;
}
131 
132 /**
133  * Delete @param spair
134  *
135  * This function deletes one endpoint of a socketpair.
136  *
137  * Theory of operation:
138  * - we have a socketpair with two endpoints: A and B
139  * - we have two threads: T1 and T2
140  * - T1 operates on endpoint A
141  * - T2 operates on endpoint B
142  *
143  * There are two possible cases where a blocking operation must be notified
144  * when one endpoint is closed:
145  * -# T1 is blocked reading from A and T2 closes B
146  *    T1 waits on A's write signal. T2 triggers the remote
147  *    @ref spair.readable
148  * -# T1 is blocked writing to A and T2 closes B
149  *    T1 is waits on B's read signal. T2 triggers the local
150  *    @ref spair.writeable.
151  *
152  * If the remote endpoint is already closed, the former operation does not
153  * take place. Otherwise, the @ref spair.remote of the local endpoint is
154  * set to -1.
155  *
156  * If no threads are blocking on A, then the signals have no effect.
157  *
158  * The memory associated with the local endpoint is cleared and freed.
159  */
spair_delete(struct spair * spair)160 static void spair_delete(struct spair *spair)
161 {
162 	int res;
163 	struct spair *remote = NULL;
164 	bool have_remote_sem = false;
165 
166 	if (spair == NULL) {
167 		return;
168 	}
169 
170 	if (spair->remote != -1) {
171 		remote = zvfs_get_fd_obj(spair->remote,
172 			(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
173 
174 		if (remote != NULL) {
175 			res = k_sem_take(&remote->sem, K_FOREVER);
176 			if (res == 0) {
177 				have_remote_sem = true;
178 				remote->remote = -1;
179 				res = k_poll_signal_raise(&remote->readable,
180 					SPAIR_SIG_CANCEL);
181 				__ASSERT(res == 0,
182 					"k_poll_signal_raise() failed: %d",
183 					res);
184 			}
185 		}
186 	}
187 
188 	spair->remote = -1;
189 
190 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_CANCEL);
191 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
192 
193 	if (remote != NULL && have_remote_sem) {
194 		k_sem_give(&remote->sem);
195 	}
196 
197 	/* ensure no private information is released to the memory pool */
198 	memset(spair, 0, sizeof(*spair));
199 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
200 	k_mem_slab_free(&spair_slab, (void *)spair);
201 #elif CONFIG_USERSPACE
202 	k_object_free(spair);
203 #else
204 	k_free(spair);
205 #endif
206 }
207 
208 /**
209  * Create a @ref spair (1/2 of a socketpair)
210  *
211  * The idea is to call this twice, but store the "local" side in the
212  * @ref spair.remote field initially.
213  *
214  * If both allocations are successful, then swap the @ref spair.remote
215  * fields in the two @ref spair instances.
216  */
spair_new(void)217 static struct spair *spair_new(void)
218 {
219 	struct spair *spair;
220 	int res;
221 
222 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
223 
224 	res = k_mem_slab_alloc(&spair_slab, (void **) &spair, K_NO_WAIT);
225 	if (res != 0) {
226 		spair = NULL;
227 	}
228 
229 #elif CONFIG_USERSPACE
230 	struct k_object *zo = k_object_create_dynamic(sizeof(*spair));
231 
232 	if (zo == NULL) {
233 		spair = NULL;
234 	} else {
235 		spair = zo->name;
236 		zo->type = K_OBJ_NET_SOCKET;
237 	}
238 #else
239 	spair = k_malloc(sizeof(*spair));
240 #endif
241 	if (spair == NULL) {
242 		errno = ENOMEM;
243 		goto out;
244 	}
245 	memset(spair, 0, sizeof(*spair));
246 
247 	/* initialize any non-zero default values */
248 	spair->remote = -1;
249 	spair->flags = SPAIR_FLAGS_DEFAULT;
250 
251 	k_sem_init(&spair->sem, 1, 1);
252 	ring_buf_init(&spair->recv_q, sizeof(spair->buf), spair->buf);
253 	k_poll_signal_init(&spair->readable);
254 	k_poll_signal_init(&spair->writeable);
255 
256 	/* A new socket is always writeable after creation */
257 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
258 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
259 
260 	spair->remote = zvfs_reserve_fd();
261 	if (spair->remote == -1) {
262 		errno = ENFILE;
263 		goto cleanup;
264 	}
265 
266 	zvfs_finalize_typed_fd(spair->remote, spair,
267 			       (const struct fd_op_vtable *)&spair_fd_op_vtable, ZVFS_MODE_IFSOCK);
268 
269 	goto out;
270 
271 cleanup:
272 	spair_delete(spair);
273 	spair = NULL;
274 
275 out:
276 	return spair;
277 }
278 
z_impl_zsock_socketpair(int family,int type,int proto,int * sv)279 int z_impl_zsock_socketpair(int family, int type, int proto, int *sv)
280 {
281 	int res;
282 	size_t i;
283 	struct spair *obj[2] = {};
284 
285 	SYS_PORT_TRACING_FUNC_ENTER(socket, socketpair, family, type, proto, sv);
286 
287 	if (family != NET_AF_UNIX) {
288 		errno = EAFNOSUPPORT;
289 		res = -1;
290 		goto errout;
291 	}
292 
293 	if (type != NET_SOCK_STREAM) {
294 		errno = EPROTOTYPE;
295 		res = -1;
296 		goto errout;
297 	}
298 
299 	if (proto != 0) {
300 		errno = EPROTONOSUPPORT;
301 		res = -1;
302 		goto errout;
303 	}
304 
305 	if (sv == NULL) {
306 		/* not listed in normative spec, but mimics Linux behaviour */
307 		errno = EFAULT;
308 		res = -1;
309 		goto errout;
310 	}
311 
312 	for (i = 0; i < 2; ++i) {
313 		obj[i] = spair_new();
314 		if (!obj[i]) {
315 			res = -1;
316 			goto cleanup;
317 		}
318 	}
319 
320 	/* connect the two endpoints */
321 	swap32(&obj[0]->remote, &obj[1]->remote);
322 
323 	for (i = 0; i < 2; ++i) {
324 		sv[i] = obj[i]->remote;
325 		k_sem_give(&obj[0]->sem);
326 	}
327 
328 	SYS_PORT_TRACING_FUNC_EXIT(socket, socketpair, sv[0], sv[1], 0);
329 
330 	return 0;
331 
332 cleanup:
333 	for (i = 0; i < 2; ++i) {
334 		spair_delete(obj[i]);
335 	}
336 
337 errout:
338 	SYS_PORT_TRACING_FUNC_EXIT(socket, socketpair, -1, -1, -errno);
339 
340 	return res;
341 }
342 
#ifdef CONFIG_USERSPACE
/**
 * Syscall verification handler for zsock_socketpair().
 *
 * The descriptors are created into a local array and copied out to @p sv
 * only after @p sv has been checked to be writable by the caller.
 */
int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv)
{
	int fds[2];
	int ret;

	if (sv == NULL || K_SYSCALL_MEMORY_WRITE(sv, sizeof(fds)) != 0) {
		/* not listed in normative spec, but mimics linux behaviour */
		errno = EFAULT;
		return -1;
	}

	ret = z_impl_zsock_socketpair(family, type, proto, fds);
	if (ret == 0) {
		K_OOPS(k_usermode_to_copy(sv, fds, sizeof(fds)));
	}

	return ret;
}

#include <zephyr/syscalls/zsock_socketpair_mrsh.c>
#endif /* CONFIG_USERSPACE */
367 
368 /**
369  * Write data to one end of a @ref spair
370  *
371  * Data written on one file descriptor of a socketpair can be read at the
372  * other end using common POSIX calls such as read(2) or recv(2).
373  *
374  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
375  * this function will return immediately. If no data was written on a
376  * non-blocking file descriptor, then -1 will be returned and @ref errno will
377  * be set to @ref EAGAIN.
378  *
379  * Blocking write operations occur when the @ref O_NONBLOCK flag is @em not
380  * set and there is insufficient space in the @em remote @ref spair.pipe.
381  *
382  * Such a blocking write will suspend execution of the current thread until
383  * one of two possible results is received on the @em remote
384  * @ref spair.writeable:
385  *
386  * 1) @ref SPAIR_SIG_DATA - data has been read from the @em remote
387  *    @ref spair.pipe. Thus, allowing more data to be written.
388  *
389  * 2) @ref SPAIR_SIG_CANCEL - the @em remote socketpair endpoint was closed
390  *    Receipt of this result is analogous to SIGPIPE from POSIX
391  *    ("Write on a pipe with no one to read it."). In this case, the function
392  *    will return -1 and set @ref errno to @ref EPIPE.
393  *
394  * @param obj the address of an @ref spair object cast to `void *`
395  * @param buffer the buffer to write
396  * @param count the number of bytes to write from @p buffer
397  *
398  * @return on success, a number > 0 representing the number of bytes written
399  * @return -1 on error, with @ref errno set appropriately.
400  */
spair_write(void * obj,const void * buffer,size_t count)401 static ssize_t spair_write(void *obj, const void *buffer, size_t count)
402 {
403 	int res;
404 	size_t avail;
405 	bool is_nonblock;
406 	size_t bytes_written;
407 	bool have_local_sem = false;
408 	bool have_remote_sem = false;
409 	bool will_block = false;
410 	struct spair *const spair = (struct spair *)obj;
411 	struct spair *remote = NULL;
412 
413 	if (obj == NULL || buffer == NULL || count == 0) {
414 		errno = EINVAL;
415 		res = -1;
416 		goto out;
417 	}
418 
419 	res = k_sem_take(&spair->sem, K_NO_WAIT);
420 	is_nonblock = sock_is_nonblock(spair);
421 	if (res < 0) {
422 		if (is_nonblock) {
423 			errno = EAGAIN;
424 			res = -1;
425 			goto out;
426 		}
427 
428 		res = k_sem_take(&spair->sem, K_FOREVER);
429 		if (res < 0) {
430 			errno = -res;
431 			res = -1;
432 			goto out;
433 		}
434 		is_nonblock = sock_is_nonblock(spair);
435 	}
436 
437 	have_local_sem = true;
438 
439 	remote = zvfs_get_fd_obj(spair->remote,
440 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
441 
442 	if (remote == NULL) {
443 		errno = EPIPE;
444 		res = -1;
445 		goto out;
446 	}
447 
448 	res = k_sem_take(&remote->sem, K_NO_WAIT);
449 	if (res < 0) {
450 		if (is_nonblock) {
451 			errno = EAGAIN;
452 			res = -1;
453 			goto out;
454 		}
455 		res = k_sem_take(&remote->sem, K_FOREVER);
456 		if (res < 0) {
457 			errno = -res;
458 			res = -1;
459 			goto out;
460 		}
461 	}
462 
463 	have_remote_sem = true;
464 
465 	avail = spair_write_avail(spair);
466 
467 	if (avail == 0) {
468 		if (is_nonblock) {
469 			errno = EAGAIN;
470 			res = -1;
471 			goto out;
472 		}
473 		will_block = true;
474 	}
475 
476 	if (will_block) {
477 		if (k_is_in_isr()) {
478 			errno = EAGAIN;
479 			res = -1;
480 			goto out;
481 		}
482 
483 		for (int signaled = false, result = -1; !signaled;
484 			result = -1) {
485 
486 			struct k_poll_event events[] = {
487 				K_POLL_EVENT_INITIALIZER(
488 					K_POLL_TYPE_SIGNAL,
489 					K_POLL_MODE_NOTIFY_ONLY,
490 					&remote->writeable),
491 			};
492 
493 			k_sem_give(&remote->sem);
494 			have_remote_sem = false;
495 
496 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
497 			if (res < 0) {
498 				errno = -res;
499 				res = -1;
500 				goto out;
501 			}
502 
503 			remote = zvfs_get_fd_obj(spair->remote,
504 				(const struct fd_op_vtable *)
505 				&spair_fd_op_vtable, 0);
506 
507 			if (remote == NULL) {
508 				errno = EPIPE;
509 				res = -1;
510 				goto out;
511 			}
512 
513 			res = k_sem_take(&remote->sem, K_FOREVER);
514 			if (res < 0) {
515 				errno = -res;
516 				res = -1;
517 				goto out;
518 			}
519 
520 			have_remote_sem = true;
521 
522 			k_poll_signal_check(&remote->writeable, &signaled,
523 					    &result);
524 			if (!signaled) {
525 				continue;
526 			}
527 
528 			switch (result) {
529 				case SPAIR_SIG_DATA: {
530 					break;
531 				}
532 
533 				case SPAIR_SIG_CANCEL: {
534 					errno = EPIPE;
535 					res = -1;
536 					goto out;
537 				}
538 
539 				default: {
540 					__ASSERT(false,
541 						"unrecognized result: %d",
542 						result);
543 					continue;
544 				}
545 			}
546 
547 			/* SPAIR_SIG_DATA was received */
548 			break;
549 		}
550 	}
551 
552 	bytes_written = ring_buf_put(&remote->recv_q, (void *)buffer, count);
553 	if (spair_write_avail(spair) == 0) {
554 		k_poll_signal_reset(&remote->writeable);
555 	}
556 
557 	res = k_poll_signal_raise(&remote->readable, SPAIR_SIG_DATA);
558 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
559 
560 	res = bytes_written;
561 
562 out:
563 
564 	if (remote != NULL && have_remote_sem) {
565 		k_sem_give(&remote->sem);
566 	}
567 	if (spair != NULL && have_local_sem) {
568 		k_sem_give(&spair->sem);
569 	}
570 
571 	return res;
572 }
573 
574 /**
575  * Read data from one end of a @ref spair
576  *
577  * Data written on one file descriptor of a socketpair (with e.g. write(2) or
578  * send(2)) can be read at the other end using common POSIX calls such as
579  * read(2) or recv(2).
580  *
581  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
582  * this function will return immediately. If no data was read from a
583  * non-blocking file descriptor, then -1 will be returned and @ref errno will
584  * be set to @ref EAGAIN.
585  *
586  * Blocking read operations occur when the @ref O_NONBLOCK flag is @em not set
587  * and there are no bytes to read in the @em local @ref spair.pipe.
588  *
589  * Such a blocking read will suspend execution of the current thread until
590  * one of two possible results is received on the @em local
591  * @ref spair.readable:
592  *
593  * -# @ref SPAIR_SIG_DATA - data has been written to the @em local
594  *    @ref spair.pipe. Thus, allowing more data to be read.
595  *
596  * -# @ref SPAIR_SIG_CANCEL - read of the @em local @spair.pipe
597  *    must be cancelled for some reason (e.g. the file descriptor will be
598  *    closed imminently). In this case, the function will return -1 and set
599  *    @ref errno to @ref EINTR.
600  *
601  * @param obj the address of an @ref spair object cast to `void *`
602  * @param buffer the buffer in which to read
603  * @param count the number of bytes to read
604  *
605  * @return on success, a number > 0 representing the number of bytes written
606  * @return -1 on error, with @ref errno set appropriately.
607  */
spair_read(void * obj,void * buffer,size_t count)608 static ssize_t spair_read(void *obj, void *buffer, size_t count)
609 {
610 	int res;
611 	bool is_connected;
612 	size_t avail;
613 	bool is_nonblock;
614 	size_t bytes_read;
615 	bool have_local_sem = false;
616 	bool will_block = false;
617 	struct spair *const spair = (struct spair *)obj;
618 
619 	if (obj == NULL || buffer == NULL || count == 0) {
620 		errno = EINVAL;
621 		res = -1;
622 		goto out;
623 	}
624 
625 	res = k_sem_take(&spair->sem, K_NO_WAIT);
626 	is_nonblock = sock_is_nonblock(spair);
627 	if (res < 0) {
628 		if (is_nonblock) {
629 			errno = EAGAIN;
630 			res = -1;
631 			goto out;
632 		}
633 
634 		res = k_sem_take(&spair->sem, K_FOREVER);
635 		if (res < 0) {
636 			errno = -res;
637 			res = -1;
638 			goto out;
639 		}
640 		is_nonblock = sock_is_nonblock(spair);
641 	}
642 
643 	have_local_sem = true;
644 
645 	is_connected = sock_is_connected(spair);
646 	avail = spair_read_avail(spair);
647 
648 	if (avail == 0) {
649 		if (!is_connected) {
650 			/* signal EOF */
651 			res = 0;
652 			goto out;
653 		}
654 
655 		if (is_nonblock) {
656 			errno = EAGAIN;
657 			res = -1;
658 			goto out;
659 		}
660 
661 		will_block = true;
662 	}
663 
664 	if (will_block) {
665 		if (k_is_in_isr()) {
666 			errno = EAGAIN;
667 			res = -1;
668 			goto out;
669 		}
670 
671 		for (int signaled = false, result = -1; !signaled;
672 			result = -1) {
673 
674 			struct k_poll_event events[] = {
675 				K_POLL_EVENT_INITIALIZER(
676 					K_POLL_TYPE_SIGNAL,
677 					K_POLL_MODE_NOTIFY_ONLY,
678 					&spair->readable
679 				),
680 			};
681 
682 			k_sem_give(&spair->sem);
683 			have_local_sem = false;
684 
685 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
686 			__ASSERT(res == 0, "k_poll() failed: %d", res);
687 
688 			res = k_sem_take(&spair->sem, K_FOREVER);
689 			__ASSERT(res == 0, "failed to take local sem: %d", res);
690 
691 			have_local_sem = true;
692 
693 			k_poll_signal_check(&spair->readable, &signaled,
694 					    &result);
695 			if (!signaled) {
696 				continue;
697 			}
698 
699 			switch (result) {
700 				case SPAIR_SIG_DATA: {
701 					break;
702 				}
703 
704 				case SPAIR_SIG_CANCEL: {
705 					errno = EPIPE;
706 					res = -1;
707 					goto out;
708 				}
709 
710 				default: {
711 					__ASSERT(false,
712 						"unrecognized result: %d",
713 						result);
714 					continue;
715 				}
716 			}
717 
718 			/* SPAIR_SIG_DATA was received */
719 			break;
720 		}
721 	}
722 
723 	bytes_read = ring_buf_get(&spair->recv_q, (void *)buffer, count);
724 	if (spair_read_avail(spair) == 0 && !sock_is_eof(spair)) {
725 		k_poll_signal_reset(&spair->readable);
726 	}
727 
728 	if (is_connected) {
729 		res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
730 		__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
731 	}
732 
733 	res = bytes_read;
734 
735 out:
736 
737 	if (spair != NULL && have_local_sem) {
738 		k_sem_give(&spair->sem);
739 	}
740 
741 	return res;
742 }
743 
zsock_poll_prepare_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)744 static int zsock_poll_prepare_ctx(struct spair *const spair,
745 				  struct zsock_pollfd *const pfd,
746 				  struct k_poll_event **pev,
747 				  struct k_poll_event *pev_end)
748 {
749 	int res;
750 
751 	struct spair *remote = NULL;
752 	bool have_remote_sem = false;
753 
754 	if (pfd->events & ZSOCK_POLLIN) {
755 
756 		/* Tell poll() to short-circuit wait */
757 		if (sock_is_eof(spair)) {
758 			res = -EALREADY;
759 			goto out;
760 		}
761 
762 		if (*pev == pev_end) {
763 			res = -ENOMEM;
764 			goto out;
765 		}
766 
767 		/* Wait until data has been written to the local end */
768 		(*pev)->obj = &spair->readable;
769 	}
770 
771 	if (pfd->events & ZSOCK_POLLOUT) {
772 
773 		/* Tell poll() to short-circuit wait */
774 		if (!sock_is_connected(spair)) {
775 			res = -EALREADY;
776 			goto out;
777 		}
778 
779 		if (*pev == pev_end) {
780 			res = -ENOMEM;
781 			goto out;
782 		}
783 
784 		remote = zvfs_get_fd_obj(spair->remote,
785 			(const struct fd_op_vtable *)
786 			&spair_fd_op_vtable, 0);
787 
788 		__ASSERT(remote != NULL, "remote is NULL");
789 
790 		res = k_sem_take(&remote->sem, K_FOREVER);
791 		if (res < 0) {
792 			goto out;
793 		}
794 
795 		have_remote_sem = true;
796 
797 		/* Wait until the recv queue on the remote end is no longer full */
798 		(*pev)->obj = &remote->writeable;
799 	}
800 
801 	(*pev)->type = K_POLL_TYPE_SIGNAL;
802 	(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
803 	(*pev)->state = K_POLL_STATE_NOT_READY;
804 
805 	(*pev)++;
806 
807 	res = 0;
808 
809 out:
810 
811 	if (remote != NULL && have_remote_sem) {
812 		k_sem_give(&remote->sem);
813 	}
814 
815 	return res;
816 }
817 
zsock_poll_update_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev)818 static int zsock_poll_update_ctx(struct spair *const spair,
819 				 struct zsock_pollfd *const pfd,
820 				 struct k_poll_event **pev)
821 {
822 	int res;
823 	int signaled;
824 	int result;
825 	struct spair *remote = NULL;
826 	bool have_remote_sem = false;
827 
828 	if (pfd->events & ZSOCK_POLLOUT) {
829 		if (!sock_is_connected(spair)) {
830 			pfd->revents |= ZSOCK_POLLHUP;
831 			goto pollout_done;
832 		}
833 
834 		remote = zvfs_get_fd_obj(spair->remote,
835 			(const struct fd_op_vtable *) &spair_fd_op_vtable, 0);
836 
837 		__ASSERT(remote != NULL, "remote is NULL");
838 
839 		res = k_sem_take(&remote->sem, K_FOREVER);
840 		if (res < 0) {
841 			/* if other end is deleted, this might occur */
842 			goto pollout_done;
843 		}
844 
845 		have_remote_sem = true;
846 
847 		if (spair_write_avail(spair) > 0) {
848 			pfd->revents |= ZSOCK_POLLOUT;
849 			goto pollout_done;
850 		}
851 
852 		/* check to see if op was canceled */
853 		signaled = false;
854 		k_poll_signal_check(&remote->writeable, &signaled, &result);
855 		if (signaled) {
856 			/* Cannot be SPAIR_SIG_DATA, because
857 			 * spair_write_avail() would have
858 			 * returned 0
859 			 */
860 			__ASSERT(result == SPAIR_SIG_CANCEL,
861 				"invalid result %d", result);
862 			pfd->revents |= ZSOCK_POLLHUP;
863 		}
864 	}
865 
866 pollout_done:
867 
868 	if (pfd->events & ZSOCK_POLLIN) {
869 		if (sock_is_eof(spair)) {
870 			pfd->revents |= ZSOCK_POLLIN;
871 			goto pollin_done;
872 		}
873 
874 		if (spair_read_avail(spair) > 0) {
875 			pfd->revents |= ZSOCK_POLLIN;
876 			goto pollin_done;
877 		}
878 
879 		/* check to see if op was canceled */
880 		signaled = false;
881 		k_poll_signal_check(&spair->readable, &signaled, &result);
882 		if (signaled) {
883 			/* Cannot be SPAIR_SIG_DATA, because
884 			 * spair_read_avail() would have
885 			 * returned 0
886 			 */
887 			__ASSERT(result == SPAIR_SIG_CANCEL,
888 					 "invalid result %d", result);
889 			pfd->revents |= ZSOCK_POLLIN;
890 		}
891 	}
892 
893 pollin_done:
894 	res = 0;
895 
896 	(*pev)++;
897 
898 	if (remote != NULL && have_remote_sem) {
899 		k_sem_give(&remote->sem);
900 	}
901 
902 	return res;
903 }
904 
spair_ioctl(void * obj,unsigned int request,va_list args)905 static int spair_ioctl(void *obj, unsigned int request, va_list args)
906 {
907 	int res;
908 	struct zsock_pollfd *pfd;
909 	struct k_poll_event **pev;
910 	struct k_poll_event *pev_end;
911 	int flags = 0;
912 	bool have_local_sem = false;
913 	struct spair *const spair = (struct spair *)obj;
914 
915 	if (spair == NULL) {
916 		errno = EINVAL;
917 		res = -1;
918 		goto out;
919 	}
920 
921 	/* The local sem is always taken in this function. If a subsequent
922 	 * function call requires the remote sem, it must acquire and free the
923 	 * remote sem.
924 	 */
925 	res = k_sem_take(&spair->sem, K_FOREVER);
926 	__ASSERT(res == 0, "failed to take local sem: %d", res);
927 
928 	have_local_sem = true;
929 
930 	switch (request) {
931 		case ZVFS_F_GETFL: {
932 			if (sock_is_nonblock(spair)) {
933 				flags |= ZVFS_O_NONBLOCK;
934 			}
935 
936 			res = flags;
937 			goto out;
938 		}
939 
940 		case ZVFS_F_SETFL: {
941 			flags = va_arg(args, int);
942 
943 			if (flags & ZVFS_O_NONBLOCK) {
944 				spair->flags |= SPAIR_FLAG_NONBLOCK;
945 			} else {
946 				spair->flags &= ~SPAIR_FLAG_NONBLOCK;
947 			}
948 
949 			res = 0;
950 			goto out;
951 		}
952 
953 		case ZFD_IOCTL_FIONBIO: {
954 			spair->flags |= SPAIR_FLAG_NONBLOCK;
955 			res = 0;
956 			goto out;
957 		}
958 
959 		case ZFD_IOCTL_FIONREAD: {
960 			int *nbytes;
961 
962 			nbytes = va_arg(args, int *);
963 			*nbytes = spair_read_avail(spair);
964 
965 			res = 0;
966 			goto out;
967 		}
968 
969 		case ZFD_IOCTL_POLL_PREPARE: {
970 			pfd = va_arg(args, struct zsock_pollfd *);
971 			pev = va_arg(args, struct k_poll_event **);
972 			pev_end = va_arg(args, struct k_poll_event *);
973 
974 			res = zsock_poll_prepare_ctx(obj, pfd, pev, pev_end);
975 			goto out;
976 		}
977 
978 		case ZFD_IOCTL_POLL_UPDATE: {
979 			pfd = va_arg(args, struct zsock_pollfd *);
980 			pev = va_arg(args, struct k_poll_event **);
981 
982 			res = zsock_poll_update_ctx(obj, pfd, pev);
983 			goto out;
984 		}
985 
986 		default: {
987 			errno = EOPNOTSUPP;
988 			res = -1;
989 			goto out;
990 		}
991 	}
992 
993 out:
994 	if (spair != NULL && have_local_sem) {
995 		k_sem_give(&spair->sem);
996 	}
997 
998 	return res;
999 }
1000 
spair_bind(void * obj,const struct net_sockaddr * addr,net_socklen_t addrlen)1001 static int spair_bind(void *obj, const struct net_sockaddr *addr,
1002 		      net_socklen_t addrlen)
1003 {
1004 	ARG_UNUSED(obj);
1005 	ARG_UNUSED(addr);
1006 	ARG_UNUSED(addrlen);
1007 
1008 	errno = EISCONN;
1009 	return -1;
1010 }
1011 
spair_connect(void * obj,const struct net_sockaddr * addr,net_socklen_t addrlen)1012 static int spair_connect(void *obj, const struct net_sockaddr *addr,
1013 			 net_socklen_t addrlen)
1014 {
1015 	ARG_UNUSED(obj);
1016 	ARG_UNUSED(addr);
1017 	ARG_UNUSED(addrlen);
1018 
1019 	errno = EISCONN;
1020 	return -1;
1021 }
1022 
spair_listen(void * obj,int backlog)1023 static int spair_listen(void *obj, int backlog)
1024 {
1025 	ARG_UNUSED(obj);
1026 	ARG_UNUSED(backlog);
1027 
1028 	errno = EINVAL;
1029 	return -1;
1030 }
1031 
spair_accept(void * obj,struct net_sockaddr * addr,net_socklen_t * addrlen)1032 static int spair_accept(void *obj, struct net_sockaddr *addr,
1033 			net_socklen_t *addrlen)
1034 {
1035 	ARG_UNUSED(obj);
1036 	ARG_UNUSED(addr);
1037 	ARG_UNUSED(addrlen);
1038 
1039 	errno = EOPNOTSUPP;
1040 	return -1;
1041 }
1042 
spair_sendto(void * obj,const void * buf,size_t len,int flags,const struct net_sockaddr * dest_addr,net_socklen_t addrlen)1043 static ssize_t spair_sendto(void *obj, const void *buf, size_t len,
1044 			    int flags, const struct net_sockaddr *dest_addr,
1045 			    net_socklen_t addrlen)
1046 {
1047 	ARG_UNUSED(flags);
1048 	ARG_UNUSED(dest_addr);
1049 	ARG_UNUSED(addrlen);
1050 
1051 	return spair_write(obj, buf, len);
1052 }
1053 
spair_sendmsg(void * obj,const struct net_msghdr * msg,int flags)1054 static ssize_t spair_sendmsg(void *obj, const struct net_msghdr *msg,
1055 			     int flags)
1056 {
1057 	ARG_UNUSED(flags);
1058 
1059 	int res;
1060 	size_t len = 0;
1061 	bool is_connected;
1062 	size_t avail;
1063 	bool is_nonblock;
1064 	struct spair *const spair = (struct spair *)obj;
1065 
1066 	if (spair == NULL || msg == NULL) {
1067 		errno = EINVAL;
1068 		res = -1;
1069 		goto out;
1070 	}
1071 
1072 	is_connected = sock_is_connected(spair);
1073 	avail = is_connected ? spair_write_avail(spair) : 0;
1074 	is_nonblock = sock_is_nonblock(spair);
1075 
1076 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1077 		/* check & msg->msg_iov[i]? */
1078 		/* check & msg->msg_iov[i].iov_base? */
1079 		len += msg->msg_iov[i].iov_len;
1080 	}
1081 
1082 	if (!is_connected) {
1083 		errno = EPIPE;
1084 		res = -1;
1085 		goto out;
1086 	}
1087 
1088 	if (len == 0) {
1089 		res = 0;
1090 		goto out;
1091 	}
1092 
1093 	if (len > avail && is_nonblock) {
1094 		errno = EMSGSIZE;
1095 		res = -1;
1096 		goto out;
1097 	}
1098 
1099 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1100 		res = spair_write(spair, msg->msg_iov[i].iov_base,
1101 			msg->msg_iov[i].iov_len);
1102 		if (res == -1) {
1103 			goto out;
1104 		}
1105 	}
1106 
1107 	res = len;
1108 
1109 out:
1110 	return res;
1111 }
1112 
spair_recvfrom(void * obj,void * buf,size_t max_len,int flags,struct net_sockaddr * src_addr,net_socklen_t * addrlen)1113 static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len,
1114 			      int flags, struct net_sockaddr *src_addr,
1115 			      net_socklen_t *addrlen)
1116 {
1117 	(void)flags;
1118 	(void)src_addr;
1119 	(void)addrlen;
1120 
1121 	if (addrlen != NULL) {
1122 		/* Protocol (PF_UNIX) does not support addressing with connected
1123 		 * sockets and, therefore, it is unspecified behaviour to modify
1124 		 * src_addr. However, it would be ambiguous to leave addrlen
1125 		 * untouched if the user expects it to be updated. It is not
1126 		 * mentioned that modifying addrlen is unspecified. Therefore
1127 		 * we choose to eliminate ambiguity.
1128 		 *
1129 		 * Setting it to zero mimics Linux's behaviour.
1130 		 */
1131 		*addrlen = 0;
1132 	}
1133 
1134 	return spair_read(obj, buf, max_len);
1135 }
1136 
spair_getsockopt(void * obj,int level,int optname,void * optval,net_socklen_t * optlen)1137 static int spair_getsockopt(void *obj, int level, int optname,
1138 			    void *optval, net_socklen_t *optlen)
1139 {
1140 	ARG_UNUSED(obj);
1141 	ARG_UNUSED(level);
1142 	ARG_UNUSED(optname);
1143 	ARG_UNUSED(optval);
1144 	ARG_UNUSED(optlen);
1145 
1146 	errno = ENOPROTOOPT;
1147 	return -1;
1148 }
1149 
spair_setsockopt(void * obj,int level,int optname,const void * optval,net_socklen_t optlen)1150 static int spair_setsockopt(void *obj, int level, int optname,
1151 			    const void *optval, net_socklen_t optlen)
1152 {
1153 	ARG_UNUSED(obj);
1154 	ARG_UNUSED(level);
1155 	ARG_UNUSED(optname);
1156 	ARG_UNUSED(optval);
1157 	ARG_UNUSED(optlen);
1158 
1159 	errno = ENOPROTOOPT;
1160 	return -1;
1161 }
1162 
spair_close(void * obj,int fd)1163 static int spair_close(void *obj, int fd)
1164 {
1165 	struct spair *const spair = (struct spair *)obj;
1166 	int res;
1167 
1168 	ARG_UNUSED(fd);
1169 
1170 	res = k_sem_take(&spair->sem, K_FOREVER);
1171 	__ASSERT(res == 0, "failed to take local sem: %d", res);
1172 
1173 	/* disconnect the remote endpoint */
1174 	spair_delete(spair);
1175 
1176 	/* Note that the semaphore released already so need to do it here */
1177 
1178 	return 0;
1179 }
1180 
/* Operations shared by every socketpair endpoint. Its address is also passed
 * to zvfs_get_fd_obj() when resolving spair.remote into an endpoint.
 */
static const struct socket_op_vtable spair_fd_op_vtable = {
	.fd_vtable = {
		.read = spair_read,
		.write = spair_write,
		.close2 = spair_close,
		.ioctl = spair_ioctl,
	},
	.bind = spair_bind,
	.connect = spair_connect,
	.listen = spair_listen,
	.accept = spair_accept,
	.sendto = spair_sendto,
	.sendmsg = spair_sendmsg,
	.recvfrom = spair_recvfrom,
	.getsockopt = spair_getsockopt,
	.setsockopt = spair_setsockopt,
};
1198