1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/kernel.h>
8 #include <zephyr/net/socket.h>
9 #include <zephyr/posix/fcntl.h>
10 #include <zephyr/internal/syscall_handler.h>
11 #include <zephyr/sys/__assert.h>
12 #include <zephyr/sys/fdtable.h>
13 
14 #include "sockets_internal.h"
15 
/** Result values carried by the k_poll_signal members of struct spair */
enum {
	/** the pending operation has been canceled (an endpoint was deleted) */
	SPAIR_SIG_CANCEL = 0,
	/** @ref spair.recv_q has been updated */
	SPAIR_SIG_DATA = 1,
};

/** Bits stored in @ref spair.flags */
enum {
	/** endpoint operates in non-blocking mode */
	SPAIR_FLAG_NONBLOCK = 1 << 0,
};

/** @ref spair.flags value of a freshly created endpoint */
#define SPAIR_FLAGS_DEFAULT 0
26 
/**
 * Socketpair endpoint structure
 *
 * This structure represents one half of a socketpair (an 'endpoint').
 *
 * The implementation strives for compatibility with socketpair(2).
 *
 * Resources contained within this structure are said to be 'local', while
 * resources contained within the other half of the socketpair (or other
 * endpoint) are said to be 'remote'.
 *
 * Theory of operation:
 * - each end of a socketpair owns a @a recv_q
 * - a write on one endpoint stores bytes directly in the remote @a recv_q;
 *   since there is no separate write queue, data is either written or not
 * - read and write operations may return partial transfers
 * - read operations may block if the local @a recv_q is empty
 * - write operations may block if the remote @a recv_q is full
 * - each endpoint may be blocking or non-blocking
 */
__net_socket struct spair {
	/**
	 * the remote endpoint file descriptor; holds the endpoint's own fd
	 * while the pair is being created and -1 once disconnected
	 */
	int remote;
	uint32_t flags; /**< status and option bits (SPAIR_FLAG_*) */
	struct k_sem sem; /**< semaphore for exclusive structure access */
	/** bytes written by the remote endpoint, waiting to be read locally */
	struct ring_buf recv_q;
	/**
	 * indicates local @a recv_q isn't empty. Raised with SPAIR_SIG_DATA by
	 * remote writers and with SPAIR_SIG_CANCEL when the remote endpoint is
	 * deleted.
	 */
	struct k_poll_signal readable;
	/**
	 * indicates local @a recv_q isn't full; blocked remote writers wait on
	 * it. Raised with SPAIR_SIG_DATA after local reads and with
	 * SPAIR_SIG_CANCEL when this endpoint is deleted.
	 */
	struct k_poll_signal writeable;
	/** backing storage for @a recv_q */
	uint8_t buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE];
};
58 
#ifdef CONFIG_NET_SOCKETPAIR_STATIC
/* Fixed pool of endpoints: two per socketpair, CONFIG_NET_SOCKETPAIR_MAX pairs */
K_MEM_SLAB_DEFINE_STATIC(spair_slab, sizeof(struct spair), CONFIG_NET_SOCKETPAIR_MAX * 2,
			 __alignof__(struct spair));
#endif /* CONFIG_NET_SOCKETPAIR_STATIC */

/*
 * forward declaration: the helpers below resolve endpoints from file
 * descriptors with zvfs_get_fd_obj() against this vtable
 */
static const struct socket_op_vtable spair_fd_op_vtable;
66 
67 #undef sock_is_nonblock
68 /** Determine if a @ref spair is in non-blocking mode */
sock_is_nonblock(const struct spair * spair)69 static inline bool sock_is_nonblock(const struct spair *spair)
70 {
71 	return !!(spair->flags & SPAIR_FLAG_NONBLOCK);
72 }
73 
74 /** Determine if a @ref spair is connected */
sock_is_connected(const struct spair * spair)75 static inline bool sock_is_connected(const struct spair *spair)
76 {
77 	const struct spair *remote = zvfs_get_fd_obj(spair->remote,
78 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
79 
80 	if (remote == NULL) {
81 		return false;
82 	}
83 
84 	return true;
85 }
86 
#undef sock_is_eof
/** A @ref spair is at end-of-file exactly when its peer is no longer reachable. */
static inline bool sock_is_eof(const struct spair *spair)
{
	const bool connected = sock_is_connected(spair);

	return !connected;
}
93 
94 /**
95  * Determine bytes available to write
96  *
97  * Specifically, this function calculates the number of bytes that may be
98  * written to a given @ref spair without blocking.
99  */
spair_write_avail(struct spair * spair)100 static inline size_t spair_write_avail(struct spair *spair)
101 {
102 	struct spair *const remote = zvfs_get_fd_obj(spair->remote,
103 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
104 
105 	if (remote == NULL) {
106 		return 0;
107 	}
108 
109 	return ring_buf_space_get(&remote->recv_q);
110 }
111 
112 /**
113  * Determine bytes available to read
114  *
115  * Specifically, this function calculates the number of bytes that may be
116  * read from a given @ref spair without blocking.
117  */
spair_read_avail(struct spair * spair)118 static inline size_t spair_read_avail(struct spair *spair)
119 {
120 	return ring_buf_size_get(&spair->recv_q);
121 }
122 
/**
 * Exchange the 32-bit values pointed to by @p a and @p b.
 *
 * Passing the same pointer twice leaves the value unchanged.
 */
static inline void swap32(uint32_t *a, uint32_t *b)
{
	const uint32_t old_a = *a;

	*a = *b;
	*b = old_a;
}
132 
133 /**
134  * Delete @param spair
135  *
136  * This function deletes one endpoint of a socketpair.
137  *
138  * Theory of operation:
139  * - we have a socketpair with two endpoints: A and B
140  * - we have two threads: T1 and T2
141  * - T1 operates on endpoint A
142  * - T2 operates on endpoint B
143  *
144  * There are two possible cases where a blocking operation must be notified
145  * when one endpoint is closed:
146  * -# T1 is blocked reading from A and T2 closes B
147  *    T1 waits on A's write signal. T2 triggers the remote
148  *    @ref spair.readable
149  * -# T1 is blocked writing to A and T2 closes B
150  *    T1 is waits on B's read signal. T2 triggers the local
151  *    @ref spair.writeable.
152  *
153  * If the remote endpoint is already closed, the former operation does not
154  * take place. Otherwise, the @ref spair.remote of the local endpoint is
155  * set to -1.
156  *
157  * If no threads are blocking on A, then the signals have no effect.
158  *
159  * The memory associated with the local endpoint is cleared and freed.
160  */
spair_delete(struct spair * spair)161 static void spair_delete(struct spair *spair)
162 {
163 	int res;
164 	struct spair *remote = NULL;
165 	bool have_remote_sem = false;
166 
167 	if (spair == NULL) {
168 		return;
169 	}
170 
171 	if (spair->remote != -1) {
172 		remote = zvfs_get_fd_obj(spair->remote,
173 			(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
174 
175 		if (remote != NULL) {
176 			res = k_sem_take(&remote->sem, K_FOREVER);
177 			if (res == 0) {
178 				have_remote_sem = true;
179 				remote->remote = -1;
180 				res = k_poll_signal_raise(&remote->readable,
181 					SPAIR_SIG_CANCEL);
182 				__ASSERT(res == 0,
183 					"k_poll_signal_raise() failed: %d",
184 					res);
185 			}
186 		}
187 	}
188 
189 	spair->remote = -1;
190 
191 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_CANCEL);
192 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
193 
194 	if (remote != NULL && have_remote_sem) {
195 		k_sem_give(&remote->sem);
196 	}
197 
198 	/* ensure no private information is released to the memory pool */
199 	memset(spair, 0, sizeof(*spair));
200 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
201 	k_mem_slab_free(&spair_slab, (void *)spair);
202 #elif CONFIG_USERSPACE
203 	k_object_free(spair);
204 #else
205 	k_free(spair);
206 #endif
207 }
208 
209 /**
210  * Create a @ref spair (1/2 of a socketpair)
211  *
212  * The idea is to call this twice, but store the "local" side in the
213  * @ref spair.remote field initially.
214  *
215  * If both allocations are successful, then swap the @ref spair.remote
216  * fields in the two @ref spair instances.
217  */
spair_new(void)218 static struct spair *spair_new(void)
219 {
220 	struct spair *spair;
221 	int res;
222 
223 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
224 
225 	res = k_mem_slab_alloc(&spair_slab, (void **) &spair, K_NO_WAIT);
226 	if (res != 0) {
227 		spair = NULL;
228 	}
229 
230 #elif CONFIG_USERSPACE
231 	struct k_object *zo = k_object_create_dynamic(sizeof(*spair));
232 
233 	if (zo == NULL) {
234 		spair = NULL;
235 	} else {
236 		spair = zo->name;
237 		zo->type = K_OBJ_NET_SOCKET;
238 	}
239 #else
240 	spair = k_malloc(sizeof(*spair));
241 #endif
242 	if (spair == NULL) {
243 		errno = ENOMEM;
244 		goto out;
245 	}
246 	memset(spair, 0, sizeof(*spair));
247 
248 	/* initialize any non-zero default values */
249 	spair->remote = -1;
250 	spair->flags = SPAIR_FLAGS_DEFAULT;
251 
252 	k_sem_init(&spair->sem, 1, 1);
253 	ring_buf_init(&spair->recv_q, sizeof(spair->buf), spair->buf);
254 	k_poll_signal_init(&spair->readable);
255 	k_poll_signal_init(&spair->writeable);
256 
257 	/* A new socket is always writeable after creation */
258 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
259 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
260 
261 	spair->remote = zvfs_reserve_fd();
262 	if (spair->remote == -1) {
263 		errno = ENFILE;
264 		goto cleanup;
265 	}
266 
267 	zvfs_finalize_typed_fd(spair->remote, spair,
268 			       (const struct fd_op_vtable *)&spair_fd_op_vtable, ZVFS_MODE_IFSOCK);
269 
270 	goto out;
271 
272 cleanup:
273 	spair_delete(spair);
274 	spair = NULL;
275 
276 out:
277 	return spair;
278 }
279 
z_impl_zsock_socketpair(int family,int type,int proto,int * sv)280 int z_impl_zsock_socketpair(int family, int type, int proto, int *sv)
281 {
282 	int res;
283 	size_t i;
284 	struct spair *obj[2] = {};
285 
286 	SYS_PORT_TRACING_OBJ_FUNC_ENTER(socket, socketpair, family, type, proto, sv);
287 
288 	if (family != AF_UNIX) {
289 		errno = EAFNOSUPPORT;
290 		res = -1;
291 		goto errout;
292 	}
293 
294 	if (type != SOCK_STREAM) {
295 		errno = EPROTOTYPE;
296 		res = -1;
297 		goto errout;
298 	}
299 
300 	if (proto != 0) {
301 		errno = EPROTONOSUPPORT;
302 		res = -1;
303 		goto errout;
304 	}
305 
306 	if (sv == NULL) {
307 		/* not listed in normative spec, but mimics Linux behaviour */
308 		errno = EFAULT;
309 		res = -1;
310 		goto errout;
311 	}
312 
313 	for (i = 0; i < 2; ++i) {
314 		obj[i] = spair_new();
315 		if (!obj[i]) {
316 			res = -1;
317 			goto cleanup;
318 		}
319 	}
320 
321 	/* connect the two endpoints */
322 	swap32(&obj[0]->remote, &obj[1]->remote);
323 
324 	for (i = 0; i < 2; ++i) {
325 		sv[i] = obj[i]->remote;
326 		k_sem_give(&obj[0]->sem);
327 	}
328 
329 	SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, socketpair, sv[0], sv[1], 0);
330 
331 	return 0;
332 
333 cleanup:
334 	for (i = 0; i < 2; ++i) {
335 		spair_delete(obj[i]);
336 	}
337 
338 errout:
339 	SYS_PORT_TRACING_OBJ_FUNC_EXIT(socket, socketpair, -1, -1, -errno);
340 
341 	return res;
342 }
343 
#ifdef CONFIG_USERSPACE
/**
 * Syscall verification handler for zsock_socketpair().
 *
 * The descriptors are produced into a kernel-side array and copied out to
 * the caller's @p sv only after the implementation succeeds.
 */
int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv)
{
	int fds[2];
	int ret;

	if (sv == NULL || K_SYSCALL_MEMORY_WRITE(sv, sizeof(fds)) != 0) {
		/* not listed in normative spec, but mimics linux behaviour */
		errno = EFAULT;
		return -1;
	}

	ret = z_impl_zsock_socketpair(family, type, proto, fds);
	if (ret == 0) {
		K_OOPS(k_usermode_to_copy(sv, fds, sizeof(fds)));
	}

	return ret;
}

#include <zephyr/syscalls/zsock_socketpair_mrsh.c>
#endif /* CONFIG_USERSPACE */
368 
369 /**
370  * Write data to one end of a @ref spair
371  *
372  * Data written on one file descriptor of a socketpair can be read at the
373  * other end using common POSIX calls such as read(2) or recv(2).
374  *
375  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
376  * this function will return immediately. If no data was written on a
377  * non-blocking file descriptor, then -1 will be returned and @ref errno will
378  * be set to @ref EAGAIN.
379  *
380  * Blocking write operations occur when the @ref O_NONBLOCK flag is @em not
381  * set and there is insufficient space in the @em remote @ref spair.pipe.
382  *
383  * Such a blocking write will suspend execution of the current thread until
384  * one of two possible results is received on the @em remote
385  * @ref spair.writeable:
386  *
387  * 1) @ref SPAIR_SIG_DATA - data has been read from the @em remote
388  *    @ref spair.pipe. Thus, allowing more data to be written.
389  *
390  * 2) @ref SPAIR_SIG_CANCEL - the @em remote socketpair endpoint was closed
391  *    Receipt of this result is analogous to SIGPIPE from POSIX
392  *    ("Write on a pipe with no one to read it."). In this case, the function
393  *    will return -1 and set @ref errno to @ref EPIPE.
394  *
395  * @param obj the address of an @ref spair object cast to `void *`
396  * @param buffer the buffer to write
397  * @param count the number of bytes to write from @p buffer
398  *
399  * @return on success, a number > 0 representing the number of bytes written
400  * @return -1 on error, with @ref errno set appropriately.
401  */
spair_write(void * obj,const void * buffer,size_t count)402 static ssize_t spair_write(void *obj, const void *buffer, size_t count)
403 {
404 	int res;
405 	size_t avail;
406 	bool is_nonblock;
407 	size_t bytes_written;
408 	bool have_local_sem = false;
409 	bool have_remote_sem = false;
410 	bool will_block = false;
411 	struct spair *const spair = (struct spair *)obj;
412 	struct spair *remote = NULL;
413 
414 	if (obj == NULL || buffer == NULL || count == 0) {
415 		errno = EINVAL;
416 		res = -1;
417 		goto out;
418 	}
419 
420 	res = k_sem_take(&spair->sem, K_NO_WAIT);
421 	is_nonblock = sock_is_nonblock(spair);
422 	if (res < 0) {
423 		if (is_nonblock) {
424 			errno = EAGAIN;
425 			res = -1;
426 			goto out;
427 		}
428 
429 		res = k_sem_take(&spair->sem, K_FOREVER);
430 		if (res < 0) {
431 			errno = -res;
432 			res = -1;
433 			goto out;
434 		}
435 		is_nonblock = sock_is_nonblock(spair);
436 	}
437 
438 	have_local_sem = true;
439 
440 	remote = zvfs_get_fd_obj(spair->remote,
441 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
442 
443 	if (remote == NULL) {
444 		errno = EPIPE;
445 		res = -1;
446 		goto out;
447 	}
448 
449 	res = k_sem_take(&remote->sem, K_NO_WAIT);
450 	if (res < 0) {
451 		if (is_nonblock) {
452 			errno = EAGAIN;
453 			res = -1;
454 			goto out;
455 		}
456 		res = k_sem_take(&remote->sem, K_FOREVER);
457 		if (res < 0) {
458 			errno = -res;
459 			res = -1;
460 			goto out;
461 		}
462 	}
463 
464 	have_remote_sem = true;
465 
466 	avail = spair_write_avail(spair);
467 
468 	if (avail == 0) {
469 		if (is_nonblock) {
470 			errno = EAGAIN;
471 			res = -1;
472 			goto out;
473 		}
474 		will_block = true;
475 	}
476 
477 	if (will_block) {
478 		if (k_is_in_isr()) {
479 			errno = EAGAIN;
480 			res = -1;
481 			goto out;
482 		}
483 
484 		for (int signaled = false, result = -1; !signaled;
485 			result = -1) {
486 
487 			struct k_poll_event events[] = {
488 				K_POLL_EVENT_INITIALIZER(
489 					K_POLL_TYPE_SIGNAL,
490 					K_POLL_MODE_NOTIFY_ONLY,
491 					&remote->writeable),
492 			};
493 
494 			k_sem_give(&remote->sem);
495 			have_remote_sem = false;
496 
497 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
498 			if (res < 0) {
499 				errno = -res;
500 				res = -1;
501 				goto out;
502 			}
503 
504 			remote = zvfs_get_fd_obj(spair->remote,
505 				(const struct fd_op_vtable *)
506 				&spair_fd_op_vtable, 0);
507 
508 			if (remote == NULL) {
509 				errno = EPIPE;
510 				res = -1;
511 				goto out;
512 			}
513 
514 			res = k_sem_take(&remote->sem, K_FOREVER);
515 			if (res < 0) {
516 				errno = -res;
517 				res = -1;
518 				goto out;
519 			}
520 
521 			have_remote_sem = true;
522 
523 			k_poll_signal_check(&remote->writeable, &signaled,
524 					    &result);
525 			if (!signaled) {
526 				continue;
527 			}
528 
529 			switch (result) {
530 				case SPAIR_SIG_DATA: {
531 					break;
532 				}
533 
534 				case SPAIR_SIG_CANCEL: {
535 					errno = EPIPE;
536 					res = -1;
537 					goto out;
538 				}
539 
540 				default: {
541 					__ASSERT(false,
542 						"unrecognized result: %d",
543 						result);
544 					continue;
545 				}
546 			}
547 
548 			/* SPAIR_SIG_DATA was received */
549 			break;
550 		}
551 	}
552 
553 	bytes_written = ring_buf_put(&remote->recv_q, (void *)buffer, count);
554 	if (spair_write_avail(spair) == 0) {
555 		k_poll_signal_reset(&remote->writeable);
556 	}
557 
558 	res = k_poll_signal_raise(&remote->readable, SPAIR_SIG_DATA);
559 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
560 
561 	res = bytes_written;
562 
563 out:
564 
565 	if (remote != NULL && have_remote_sem) {
566 		k_sem_give(&remote->sem);
567 	}
568 	if (spair != NULL && have_local_sem) {
569 		k_sem_give(&spair->sem);
570 	}
571 
572 	return res;
573 }
574 
575 /**
576  * Read data from one end of a @ref spair
577  *
578  * Data written on one file descriptor of a socketpair (with e.g. write(2) or
579  * send(2)) can be read at the other end using common POSIX calls such as
580  * read(2) or recv(2).
581  *
582  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
583  * this function will return immediately. If no data was read from a
584  * non-blocking file descriptor, then -1 will be returned and @ref errno will
585  * be set to @ref EAGAIN.
586  *
587  * Blocking read operations occur when the @ref O_NONBLOCK flag is @em not set
588  * and there are no bytes to read in the @em local @ref spair.pipe.
589  *
590  * Such a blocking read will suspend execution of the current thread until
591  * one of two possible results is received on the @em local
592  * @ref spair.readable:
593  *
594  * -# @ref SPAIR_SIG_DATA - data has been written to the @em local
595  *    @ref spair.pipe. Thus, allowing more data to be read.
596  *
597  * -# @ref SPAIR_SIG_CANCEL - read of the @em local @spair.pipe
598  *    must be cancelled for some reason (e.g. the file descriptor will be
599  *    closed imminently). In this case, the function will return -1 and set
600  *    @ref errno to @ref EINTR.
601  *
602  * @param obj the address of an @ref spair object cast to `void *`
603  * @param buffer the buffer in which to read
604  * @param count the number of bytes to read
605  *
606  * @return on success, a number > 0 representing the number of bytes written
607  * @return -1 on error, with @ref errno set appropriately.
608  */
spair_read(void * obj,void * buffer,size_t count)609 static ssize_t spair_read(void *obj, void *buffer, size_t count)
610 {
611 	int res;
612 	bool is_connected;
613 	size_t avail;
614 	bool is_nonblock;
615 	size_t bytes_read;
616 	bool have_local_sem = false;
617 	bool will_block = false;
618 	struct spair *const spair = (struct spair *)obj;
619 
620 	if (obj == NULL || buffer == NULL || count == 0) {
621 		errno = EINVAL;
622 		res = -1;
623 		goto out;
624 	}
625 
626 	res = k_sem_take(&spair->sem, K_NO_WAIT);
627 	is_nonblock = sock_is_nonblock(spair);
628 	if (res < 0) {
629 		if (is_nonblock) {
630 			errno = EAGAIN;
631 			res = -1;
632 			goto out;
633 		}
634 
635 		res = k_sem_take(&spair->sem, K_FOREVER);
636 		if (res < 0) {
637 			errno = -res;
638 			res = -1;
639 			goto out;
640 		}
641 		is_nonblock = sock_is_nonblock(spair);
642 	}
643 
644 	have_local_sem = true;
645 
646 	is_connected = sock_is_connected(spair);
647 	avail = spair_read_avail(spair);
648 
649 	if (avail == 0) {
650 		if (!is_connected) {
651 			/* signal EOF */
652 			res = 0;
653 			goto out;
654 		}
655 
656 		if (is_nonblock) {
657 			errno = EAGAIN;
658 			res = -1;
659 			goto out;
660 		}
661 
662 		will_block = true;
663 	}
664 
665 	if (will_block) {
666 		if (k_is_in_isr()) {
667 			errno = EAGAIN;
668 			res = -1;
669 			goto out;
670 		}
671 
672 		for (int signaled = false, result = -1; !signaled;
673 			result = -1) {
674 
675 			struct k_poll_event events[] = {
676 				K_POLL_EVENT_INITIALIZER(
677 					K_POLL_TYPE_SIGNAL,
678 					K_POLL_MODE_NOTIFY_ONLY,
679 					&spair->readable
680 				),
681 			};
682 
683 			k_sem_give(&spair->sem);
684 			have_local_sem = false;
685 
686 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
687 			__ASSERT(res == 0, "k_poll() failed: %d", res);
688 
689 			res = k_sem_take(&spair->sem, K_FOREVER);
690 			__ASSERT(res == 0, "failed to take local sem: %d", res);
691 
692 			have_local_sem = true;
693 
694 			k_poll_signal_check(&spair->readable, &signaled,
695 					    &result);
696 			if (!signaled) {
697 				continue;
698 			}
699 
700 			switch (result) {
701 				case SPAIR_SIG_DATA: {
702 					break;
703 				}
704 
705 				case SPAIR_SIG_CANCEL: {
706 					errno = EPIPE;
707 					res = -1;
708 					goto out;
709 				}
710 
711 				default: {
712 					__ASSERT(false,
713 						"unrecognized result: %d",
714 						result);
715 					continue;
716 				}
717 			}
718 
719 			/* SPAIR_SIG_DATA was received */
720 			break;
721 		}
722 	}
723 
724 	bytes_read = ring_buf_get(&spair->recv_q, (void *)buffer, count);
725 	if (spair_read_avail(spair) == 0 && !sock_is_eof(spair)) {
726 		k_poll_signal_reset(&spair->readable);
727 	}
728 
729 	if (is_connected) {
730 		res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
731 		__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
732 	}
733 
734 	res = bytes_read;
735 
736 out:
737 
738 	if (spair != NULL && have_local_sem) {
739 		k_sem_give(&spair->sem);
740 	}
741 
742 	return res;
743 }
744 
zsock_poll_prepare_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)745 static int zsock_poll_prepare_ctx(struct spair *const spair,
746 				  struct zsock_pollfd *const pfd,
747 				  struct k_poll_event **pev,
748 				  struct k_poll_event *pev_end)
749 {
750 	int res;
751 
752 	struct spair *remote = NULL;
753 	bool have_remote_sem = false;
754 
755 	if (pfd->events & ZSOCK_POLLIN) {
756 
757 		/* Tell poll() to short-circuit wait */
758 		if (sock_is_eof(spair)) {
759 			res = -EALREADY;
760 			goto out;
761 		}
762 
763 		if (*pev == pev_end) {
764 			res = -ENOMEM;
765 			goto out;
766 		}
767 
768 		/* Wait until data has been written to the local end */
769 		(*pev)->obj = &spair->readable;
770 	}
771 
772 	if (pfd->events & ZSOCK_POLLOUT) {
773 
774 		/* Tell poll() to short-circuit wait */
775 		if (!sock_is_connected(spair)) {
776 			res = -EALREADY;
777 			goto out;
778 		}
779 
780 		if (*pev == pev_end) {
781 			res = -ENOMEM;
782 			goto out;
783 		}
784 
785 		remote = zvfs_get_fd_obj(spair->remote,
786 			(const struct fd_op_vtable *)
787 			&spair_fd_op_vtable, 0);
788 
789 		__ASSERT(remote != NULL, "remote is NULL");
790 
791 		res = k_sem_take(&remote->sem, K_FOREVER);
792 		if (res < 0) {
793 			goto out;
794 		}
795 
796 		have_remote_sem = true;
797 
798 		/* Wait until the recv queue on the remote end is no longer full */
799 		(*pev)->obj = &remote->writeable;
800 	}
801 
802 	(*pev)->type = K_POLL_TYPE_SIGNAL;
803 	(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
804 	(*pev)->state = K_POLL_STATE_NOT_READY;
805 
806 	(*pev)++;
807 
808 	res = 0;
809 
810 out:
811 
812 	if (remote != NULL && have_remote_sem) {
813 		k_sem_give(&remote->sem);
814 	}
815 
816 	return res;
817 }
818 
zsock_poll_update_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev)819 static int zsock_poll_update_ctx(struct spair *const spair,
820 				 struct zsock_pollfd *const pfd,
821 				 struct k_poll_event **pev)
822 {
823 	int res;
824 	int signaled;
825 	int result;
826 	struct spair *remote = NULL;
827 	bool have_remote_sem = false;
828 
829 	if (pfd->events & ZSOCK_POLLOUT) {
830 		if (!sock_is_connected(spair)) {
831 			pfd->revents |= ZSOCK_POLLHUP;
832 			goto pollout_done;
833 		}
834 
835 		remote = zvfs_get_fd_obj(spair->remote,
836 			(const struct fd_op_vtable *) &spair_fd_op_vtable, 0);
837 
838 		__ASSERT(remote != NULL, "remote is NULL");
839 
840 		res = k_sem_take(&remote->sem, K_FOREVER);
841 		if (res < 0) {
842 			/* if other end is deleted, this might occur */
843 			goto pollout_done;
844 		}
845 
846 		have_remote_sem = true;
847 
848 		if (spair_write_avail(spair) > 0) {
849 			pfd->revents |= ZSOCK_POLLOUT;
850 			goto pollout_done;
851 		}
852 
853 		/* check to see if op was canceled */
854 		signaled = false;
855 		k_poll_signal_check(&remote->writeable, &signaled, &result);
856 		if (signaled) {
857 			/* Cannot be SPAIR_SIG_DATA, because
858 			 * spair_write_avail() would have
859 			 * returned 0
860 			 */
861 			__ASSERT(result == SPAIR_SIG_CANCEL,
862 				"invalid result %d", result);
863 			pfd->revents |= ZSOCK_POLLHUP;
864 		}
865 	}
866 
867 pollout_done:
868 
869 	if (pfd->events & ZSOCK_POLLIN) {
870 		if (sock_is_eof(spair)) {
871 			pfd->revents |= ZSOCK_POLLIN;
872 			goto pollin_done;
873 		}
874 
875 		if (spair_read_avail(spair) > 0) {
876 			pfd->revents |= ZSOCK_POLLIN;
877 			goto pollin_done;
878 		}
879 
880 		/* check to see if op was canceled */
881 		signaled = false;
882 		k_poll_signal_check(&spair->readable, &signaled, &result);
883 		if (signaled) {
884 			/* Cannot be SPAIR_SIG_DATA, because
885 			 * spair_read_avail() would have
886 			 * returned 0
887 			 */
888 			__ASSERT(result == SPAIR_SIG_CANCEL,
889 					 "invalid result %d", result);
890 			pfd->revents |= ZSOCK_POLLIN;
891 		}
892 	}
893 
894 pollin_done:
895 	res = 0;
896 
897 	(*pev)++;
898 
899 	if (remote != NULL && have_remote_sem) {
900 		k_sem_give(&remote->sem);
901 	}
902 
903 	return res;
904 }
905 
spair_ioctl(void * obj,unsigned int request,va_list args)906 static int spair_ioctl(void *obj, unsigned int request, va_list args)
907 {
908 	int res;
909 	struct zsock_pollfd *pfd;
910 	struct k_poll_event **pev;
911 	struct k_poll_event *pev_end;
912 	int flags = 0;
913 	bool have_local_sem = false;
914 	struct spair *const spair = (struct spair *)obj;
915 
916 	if (spair == NULL) {
917 		errno = EINVAL;
918 		res = -1;
919 		goto out;
920 	}
921 
922 	/* The local sem is always taken in this function. If a subsequent
923 	 * function call requires the remote sem, it must acquire and free the
924 	 * remote sem.
925 	 */
926 	res = k_sem_take(&spair->sem, K_FOREVER);
927 	__ASSERT(res == 0, "failed to take local sem: %d", res);
928 
929 	have_local_sem = true;
930 
931 	switch (request) {
932 		case F_GETFL: {
933 			if (sock_is_nonblock(spair)) {
934 				flags |= O_NONBLOCK;
935 			}
936 
937 			res = flags;
938 			goto out;
939 		}
940 
941 		case F_SETFL: {
942 			flags = va_arg(args, int);
943 
944 			if (flags & O_NONBLOCK) {
945 				spair->flags |= SPAIR_FLAG_NONBLOCK;
946 			} else {
947 				spair->flags &= ~SPAIR_FLAG_NONBLOCK;
948 			}
949 
950 			res = 0;
951 			goto out;
952 		}
953 
954 		case ZFD_IOCTL_FIONBIO: {
955 			spair->flags |= SPAIR_FLAG_NONBLOCK;
956 			res = 0;
957 			goto out;
958 		}
959 
960 		case ZFD_IOCTL_FIONREAD: {
961 			int *nbytes;
962 
963 			nbytes = va_arg(args, int *);
964 			*nbytes = spair_read_avail(spair);
965 
966 			res = 0;
967 			goto out;
968 		}
969 
970 		case ZFD_IOCTL_POLL_PREPARE: {
971 			pfd = va_arg(args, struct zsock_pollfd *);
972 			pev = va_arg(args, struct k_poll_event **);
973 			pev_end = va_arg(args, struct k_poll_event *);
974 
975 			res = zsock_poll_prepare_ctx(obj, pfd, pev, pev_end);
976 			goto out;
977 		}
978 
979 		case ZFD_IOCTL_POLL_UPDATE: {
980 			pfd = va_arg(args, struct zsock_pollfd *);
981 			pev = va_arg(args, struct k_poll_event **);
982 
983 			res = zsock_poll_update_ctx(obj, pfd, pev);
984 			goto out;
985 		}
986 
987 		default: {
988 			errno = EOPNOTSUPP;
989 			res = -1;
990 			goto out;
991 		}
992 	}
993 
994 out:
995 	if (spair != NULL && have_local_sem) {
996 		k_sem_give(&spair->sem);
997 	}
998 
999 	return res;
1000 }
1001 
/**
 * bind(2) is rejected: a socketpair endpoint is connected from creation.
 *
 * @return -1 with errno set to EISCONN
 */
static int spair_bind(void *obj, const struct sockaddr *addr,
		      socklen_t addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EISCONN;

	return -1;
}
1012 
/**
 * connect(2) is rejected: the endpoint already has its peer.
 *
 * @return -1 with errno set to EISCONN
 */
static int spair_connect(void *obj, const struct sockaddr *addr,
			 socklen_t addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EISCONN;

	return -1;
}
1023 
/**
 * listen(2) is not applicable to a socketpair endpoint.
 *
 * @return -1 with errno set to EINVAL
 */
static int spair_listen(void *obj, int backlog)
{
	ARG_UNUSED(backlog);
	ARG_UNUSED(obj);

	errno = EINVAL;

	return -1;
}
1032 
/**
 * accept(2) is not supported on a socketpair endpoint.
 *
 * @return -1 with errno set to EOPNOTSUPP
 */
static int spair_accept(void *obj, struct sockaddr *addr,
			socklen_t *addrlen)
{
	ARG_UNUSED(addrlen);
	ARG_UNUSED(addr);
	ARG_UNUSED(obj);

	errno = EOPNOTSUPP;

	return -1;
}
1043 
spair_sendto(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)1044 static ssize_t spair_sendto(void *obj, const void *buf, size_t len,
1045 			    int flags, const struct sockaddr *dest_addr,
1046 				 socklen_t addrlen)
1047 {
1048 	ARG_UNUSED(flags);
1049 	ARG_UNUSED(dest_addr);
1050 	ARG_UNUSED(addrlen);
1051 
1052 	return spair_write(obj, buf, len);
1053 }
1054 
spair_sendmsg(void * obj,const struct msghdr * msg,int flags)1055 static ssize_t spair_sendmsg(void *obj, const struct msghdr *msg,
1056 			     int flags)
1057 {
1058 	ARG_UNUSED(flags);
1059 
1060 	int res;
1061 	size_t len = 0;
1062 	bool is_connected;
1063 	size_t avail;
1064 	bool is_nonblock;
1065 	struct spair *const spair = (struct spair *)obj;
1066 
1067 	if (spair == NULL || msg == NULL) {
1068 		errno = EINVAL;
1069 		res = -1;
1070 		goto out;
1071 	}
1072 
1073 	is_connected = sock_is_connected(spair);
1074 	avail = is_connected ? spair_write_avail(spair) : 0;
1075 	is_nonblock = sock_is_nonblock(spair);
1076 
1077 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1078 		/* check & msg->msg_iov[i]? */
1079 		/* check & msg->msg_iov[i].iov_base? */
1080 		len += msg->msg_iov[i].iov_len;
1081 	}
1082 
1083 	if (!is_connected) {
1084 		errno = EPIPE;
1085 		res = -1;
1086 		goto out;
1087 	}
1088 
1089 	if (len == 0) {
1090 		res = 0;
1091 		goto out;
1092 	}
1093 
1094 	if (len > avail && is_nonblock) {
1095 		errno = EMSGSIZE;
1096 		res = -1;
1097 		goto out;
1098 	}
1099 
1100 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1101 		res = spair_write(spair, msg->msg_iov[i].iov_base,
1102 			msg->msg_iov[i].iov_len);
1103 		if (res == -1) {
1104 			goto out;
1105 		}
1106 	}
1107 
1108 	res = len;
1109 
1110 out:
1111 	return res;
1112 }
1113 
spair_recvfrom(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)1114 static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len,
1115 			      int flags, struct sockaddr *src_addr,
1116 				   socklen_t *addrlen)
1117 {
1118 	(void)flags;
1119 	(void)src_addr;
1120 	(void)addrlen;
1121 
1122 	if (addrlen != NULL) {
1123 		/* Protocol (PF_UNIX) does not support addressing with connected
1124 		 * sockets and, therefore, it is unspecified behaviour to modify
1125 		 * src_addr. However, it would be ambiguous to leave addrlen
1126 		 * untouched if the user expects it to be updated. It is not
1127 		 * mentioned that modifying addrlen is unspecified. Therefore
1128 		 * we choose to eliminate ambiguity.
1129 		 *
1130 		 * Setting it to zero mimics Linux's behaviour.
1131 		 */
1132 		*addrlen = 0;
1133 	}
1134 
1135 	return spair_read(obj, buf, max_len);
1136 }
1137 
/**
 * getsockopt(2): no socket options are supported on a socketpair endpoint.
 *
 * @return -1 with errno set to ENOPROTOOPT
 */
static int spair_getsockopt(void *obj, int level, int optname,
			    void *optval, socklen_t *optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(optname);
	ARG_UNUSED(level);
	ARG_UNUSED(obj);

	errno = ENOPROTOOPT;

	return -1;
}
1150 
/**
 * setsockopt(2): no socket options are supported on a socketpair endpoint.
 *
 * @return -1 with errno set to ENOPROTOOPT
 */
static int spair_setsockopt(void *obj, int level, int optname,
			    const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(optname);
	ARG_UNUSED(level);
	ARG_UNUSED(obj);

	errno = ENOPROTOOPT;

	return -1;
}
1163 
spair_close(void * obj)1164 static int spair_close(void *obj)
1165 {
1166 	struct spair *const spair = (struct spair *)obj;
1167 	int res;
1168 
1169 	res = k_sem_take(&spair->sem, K_FOREVER);
1170 	__ASSERT(res == 0, "failed to take local sem: %d", res);
1171 
1172 	/* disconnect the remote endpoint */
1173 	spair_delete(spair);
1174 
1175 	/* Note that the semaphore released already so need to do it here */
1176 
1177 	return 0;
1178 }
1179 
1180 static const struct socket_op_vtable spair_fd_op_vtable = {
1181 	.fd_vtable = {
1182 		.read = spair_read,
1183 		.write = spair_write,
1184 		.close = spair_close,
1185 		.ioctl = spair_ioctl,
1186 	},
1187 	.bind = spair_bind,
1188 	.connect = spair_connect,
1189 	.listen = spair_listen,
1190 	.accept = spair_accept,
1191 	.sendto = spair_sendto,
1192 	.sendmsg = spair_sendmsg,
1193 	.recvfrom = spair_recvfrom,
1194 	.getsockopt = spair_getsockopt,
1195 	.setsockopt = spair_setsockopt,
1196 };
1197