1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <fcntl.h>
8 
9 /* Zephyr headers */
10 #include <logging/log.h>
/* Register the "net_spair" log module, filtered at the level configured for
 * the socket subsystem (CONFIG_NET_SOCKETS_LOG_LEVEL).
 */
LOG_MODULE_REGISTER(net_spair, CONFIG_NET_SOCKETS_LOG_LEVEL);
12 
13 #include <kernel.h>
14 #include <net/socket.h>
15 #include <syscall_handler.h>
16 #include <sys/__assert.h>
17 #include <sys/fdtable.h>
18 
19 #include "sockets_internal.h"
20 
/** Result values carried by the @ref spair poll signals */
enum {
	/** operation has been canceled because an endpoint was deleted */
	SPAIR_SIG_CANCEL = 0,
	/** @ref spair.recv_q has been updated */
	SPAIR_SIG_DATA = 1,
};

/** Bits stored in @ref spair.flags */
enum {
	/** socket is non-blocking */
	SPAIR_FLAG_NONBLOCK = 1 << 0,
};

/** Flags of a freshly created endpoint (blocking mode) */
#define SPAIR_FLAGS_DEFAULT 0
31 
/**
 * Socketpair endpoint structure
 *
 * This structure represents one half of a socketpair (an 'endpoint').
 *
 * The implementation strives for compatibility with socketpair(2).
 *
 * Resources contained within this structure are said to be 'local', while
 * resources contained within the other half of the socketpair (or other
 * endpoint) are said to be 'remote'.
 *
 * Theory of operation:
 * - each end of a socketpair owns a @a recv_q
 * - since there is no write queue, data is either written or not
 * - read and write operations may return partial transfers
 * - read operations may block if the local @a recv_q is empty
 * - write operations may block if the remote @a recv_q is full
 * - each endpoint may be blocking or non-blocking
 */
__net_socket struct spair {
	/** the remote endpoint file descriptor, or -1 once disconnected.
	 * Right after spair_new() (before the endpoints are connected) it
	 * temporarily holds this endpoint's own descriptor.
	 */
	int remote;
	uint32_t flags; /**< status and option bits (SPAIR_FLAG_*) */
	struct k_sem sem; /**< semaphore for exclusive structure access */
	struct k_pipe recv_q; /**< receive queue of local endpoint */
	/** indicates local @a recv_q isn't empty, or (with SPAIR_SIG_CANCEL)
	 * that the remote endpoint was deleted
	 */
	struct k_poll_signal readable;
	/** indicates local @a recv_q isn't full, or (with SPAIR_SIG_CANCEL)
	 * that this endpoint was deleted
	 */
	struct k_poll_signal writeable;
	/** backing storage for @a recv_q */
	uint8_t buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE];
};
63 
/* Forward declaration: the vtable is defined at the end of this file, but the
 * helpers below already need its address to look endpoints up with
 * z_get_fd_obj() and to register new descriptors with z_finalize_fd().
 */
static const struct socket_op_vtable spair_fd_op_vtable;
66 
67 #undef sock_is_nonblock
68 /** Determine if a @ref spair is in non-blocking mode */
sock_is_nonblock(const struct spair * spair)69 static inline bool sock_is_nonblock(const struct spair *spair)
70 {
71 	return !!(spair->flags & SPAIR_FLAG_NONBLOCK);
72 }
73 
74 /** Determine if a @ref spair is connected */
sock_is_connected(const struct spair * spair)75 static inline bool sock_is_connected(const struct spair *spair)
76 {
77 	const struct spair *remote = z_get_fd_obj(spair->remote,
78 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
79 
80 	if (remote == NULL) {
81 		return false;
82 	}
83 
84 	return true;
85 }
86 
/* Drop any same-named macro (presumably from sockets_internal.h) so this
 * typed helper is used within this file.
 */
#undef sock_is_eof
/**
 * Determine if a @ref spair has encountered end-of-file
 *
 * An endpoint is at EOF exactly when it is no longer connected to a peer.
 */
static inline bool sock_is_eof(const struct spair *spair)
{
	const bool connected = sock_is_connected(spair);

	return connected ? false : true;
}
93 
94 /**
95  * Determine bytes available to write
96  *
97  * Specifically, this function calculates the number of bytes that may be
98  * written to a given @ref spair without blocking.
99  */
spair_write_avail(struct spair * spair)100 static inline size_t spair_write_avail(struct spair *spair)
101 {
102 	struct spair *const remote = z_get_fd_obj(spair->remote,
103 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
104 
105 	if (remote == NULL) {
106 		return 0;
107 	}
108 
109 	return k_pipe_write_avail(&remote->recv_q);
110 }
111 
112 /**
113  * Determine bytes available to read
114  *
115  * Specifically, this function calculates the number of bytes that may be
116  * read from a given @ref spair without blocking.
117  */
spair_read_avail(struct spair * spair)118 static inline size_t spair_read_avail(struct spair *spair)
119 {
120 	return k_pipe_read_avail(&spair->recv_q);
121 }
122 
/** Exchange the 32-bit values pointed to by @p a and @p b */
static inline void swap32(uint32_t *a, uint32_t *b)
{
	const uint32_t saved_a = *a;

	*a = *b;
	*b = saved_a;
}
132 
133 /**
134  * Delete @param spair
135  *
136  * This function deletes one endpoint of a socketpair.
137  *
138  * Theory of operation:
139  * - we have a socketpair with two endpoints: A and B
140  * - we have two threads: T1 and T2
141  * - T1 operates on endpoint A
142  * - T2 operates on endpoint B
143  *
144  * There are two possible cases where a blocking operation must be notified
145  * when one endpoint is closed:
146  * -# T1 is blocked reading from A and T2 closes B
147  *    T1 waits on A's write signal. T2 triggers the remote
148  *    @ref spair.readable
149  * -# T1 is blocked writing to A and T2 closes B
150  *    T1 is waits on B's read signal. T2 triggers the local
151  *    @ref spair.writeable.
152  *
153  * If the remote endpoint is already closed, the former operation does not
154  * take place. Otherwise, the @ref spair.remote of the local endpoint is
155  * set to -1.
156  *
157  * If no threads are blocking on A, then the signals have no effect.
158  *
159  * The memeory associated with the local endpoint is cleared and freed.
160  */
spair_delete(struct spair * spair)161 static void spair_delete(struct spair *spair)
162 {
163 	int res;
164 	struct spair *remote = NULL;
165 	bool have_remote_sem = false;
166 
167 	if (spair == NULL) {
168 		return;
169 	}
170 
171 	if (spair->remote != -1) {
172 		remote = z_get_fd_obj(spair->remote,
173 			(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
174 
175 		if (remote != NULL) {
176 			res = k_sem_take(&remote->sem, K_FOREVER);
177 			if (res == 0) {
178 				have_remote_sem = true;
179 				remote->remote = -1;
180 				res = k_poll_signal_raise(&remote->readable,
181 					SPAIR_SIG_CANCEL);
182 				__ASSERT(res == 0,
183 					"k_poll_signal_raise() failed: %d",
184 					res);
185 			}
186 		}
187 	}
188 
189 	spair->remote = -1;
190 
191 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_CANCEL);
192 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
193 
194 	/* ensure no private information is released to the memory pool */
195 	memset(spair, 0, sizeof(*spair));
196 #ifdef CONFIG_USERSPACE
197 	k_object_free(spair);
198 #else
199 	k_free(spair);
200 #endif
201 
202 	if (remote != NULL && have_remote_sem) {
203 		k_sem_give(&remote->sem);
204 	}
205 }
206 
207 /**
208  * Create a @ref spair (1/2 of a socketpair)
209  *
210  * The idea is to call this twice, but store the "local" side in the
211  * @ref spair.remote field initially.
212  *
213  * If both allocations are successful, then swap the @ref spair.remote
214  * fields in the two @ref spair instances.
215  */
spair_new(void)216 static struct spair *spair_new(void)
217 {
218 	struct spair *spair;
219 	int res;
220 
221 #ifdef CONFIG_USERSPACE
222 	struct z_object *zo = z_dynamic_object_create(sizeof(*spair));
223 
224 	if (zo == NULL) {
225 		spair = NULL;
226 	} else {
227 		spair = zo->name;
228 		zo->type = K_OBJ_NET_SOCKET;
229 	}
230 #else
231 	spair = k_malloc(sizeof(*spair));
232 #endif
233 	if (spair == NULL) {
234 		errno = ENOMEM;
235 		goto out;
236 	}
237 	memset(spair, 0, sizeof(*spair));
238 
239 	/* initialize any non-zero default values */
240 	spair->remote = -1;
241 	spair->flags = SPAIR_FLAGS_DEFAULT;
242 
243 	k_sem_init(&spair->sem, 1, 1);
244 	k_pipe_init(&spair->recv_q, spair->buf, sizeof(spair->buf));
245 	k_poll_signal_init(&spair->readable);
246 	k_poll_signal_init(&spair->writeable);
247 
248 	/* A new socket is always writeable after creation */
249 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
250 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
251 
252 	spair->remote = z_reserve_fd();
253 	if (spair->remote == -1) {
254 		errno = ENFILE;
255 		goto cleanup;
256 	}
257 
258 	z_finalize_fd(spair->remote, spair,
259 		      (const struct fd_op_vtable *)&spair_fd_op_vtable);
260 
261 	goto out;
262 
263 cleanup:
264 	spair_delete(spair);
265 	spair = NULL;
266 
267 out:
268 	return spair;
269 }
270 
z_impl_zsock_socketpair(int family,int type,int proto,int * sv)271 int z_impl_zsock_socketpair(int family, int type, int proto, int *sv)
272 {
273 	int res;
274 	size_t i;
275 	struct spair *obj[2] = {};
276 
277 	if (family != AF_UNIX) {
278 		errno = EAFNOSUPPORT;
279 		res = -1;
280 		goto errout;
281 	}
282 
283 	if (type != SOCK_STREAM) {
284 		errno = EPROTOTYPE;
285 		res = -1;
286 		goto errout;
287 	}
288 
289 	if (proto != 0) {
290 		errno = EPROTONOSUPPORT;
291 		res = -1;
292 		goto errout;
293 	}
294 
295 	if (sv == NULL) {
296 		/* not listed in normative spec, but mimics Linux behaviour */
297 		errno = EFAULT;
298 		res = -1;
299 		goto errout;
300 	}
301 
302 	for (i = 0; i < 2; ++i) {
303 		obj[i] = spair_new();
304 		if (!obj[i]) {
305 			res = -1;
306 			goto cleanup;
307 		}
308 	}
309 
310 	/* connect the two endpoints */
311 	swap32(&obj[0]->remote, &obj[1]->remote);
312 
313 	for (i = 0; i < 2; ++i) {
314 		sv[i] = obj[i]->remote;
315 		k_sem_give(&obj[0]->sem);
316 	}
317 
318 	return 0;
319 
320 cleanup:
321 	for (i = 0; i < 2; ++i) {
322 		spair_delete(obj[i]);
323 	}
324 
325 errout:
326 	return res;
327 }
328 
#ifdef CONFIG_USERSPACE
/**
 * Syscall verification handler for zsock_socketpair()
 *
 * Checks that the caller may write two ints at @p sv, creates the pair into a
 * kernel-side buffer, then copies the descriptors out to user memory.
 */
int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv)
{
	int fds[2];
	int ret;

	if (sv == NULL || Z_SYSCALL_MEMORY_WRITE(sv, sizeof(fds)) != 0) {
		/* not listed in normative spec, but mimics linux behaviour */
		errno = EFAULT;
		return -1;
	}

	ret = z_impl_zsock_socketpair(family, type, proto, fds);
	if (ret != 0) {
		return ret;
	}

	Z_OOPS(z_user_to_copy(sv, fds, sizeof(fds)));

	return ret;
}

#include <syscalls/zsock_socketpair_mrsh.c>
#endif /* CONFIG_USERSPACE */
353 
/**
 * Write data to one end of a @ref spair
 *
 * Data written on one file descriptor of a socketpair can be read at the
 * other end using common POSIX calls such as read(2) or recv(2).
 *
 * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
 * this function will return immediately. If no data was written on a
 * non-blocking file descriptor, then -1 will be returned and @ref errno will
 * be set to @ref EAGAIN.
 *
 * Blocking write operations occur when the @ref O_NONBLOCK flag is @em not
 * set and there is no free space at all in the @em remote @ref spair.recv_q.
 * When only some space is free, a partial write is performed instead.
 *
 * Such a blocking write will suspend execution of the current thread until
 * one of two possible results is received on the @em remote
 * @ref spair.writeable:
 *
 * 1) @ref SPAIR_SIG_DATA - data has been read from the @em remote
 *    @ref spair.recv_q. Thus, allowing more data to be written.
 *
 * 2) @ref SPAIR_SIG_CANCEL - the @em remote socketpair endpoint was closed.
 *    Receipt of this result is analogous to SIGPIPE from POSIX
 *    ("Write on a pipe with no one to read it."). In this case, the function
 *    will return -1 and set @ref errno to @ref EPIPE.
 *
 * Locking: the local semaphore is taken first, then the remote one. While
 * blocked, only the remote semaphore is released; the local semaphore stays
 * held for the whole wait.
 * NOTE(review): this looks able to deadlock. Closing the remote endpoint
 * (spair_close() -> spair_delete()) holds the remote semaphore and then waits
 * for this endpoint's semaphore before it raises the remote
 * @ref spair.writeable this function is waiting on. Other operations on the
 * local endpoint are also stalled until the write completes. Confirm before
 * relying on concurrent close or full-duplex use of one endpoint.
 *
 * @param obj the address of an @ref spair object cast to `void *`
 * @param buffer the buffer to write
 * @param count the number of bytes to write from @p buffer
 *
 * @return on success, a number > 0 representing the number of bytes written
 *         (possibly fewer than @p count)
 * @return -1 on error, with @ref errno set appropriately (EINVAL for invalid
 *         arguments, EAGAIN, EPIPE, or a negated kernel error code).
 */
static ssize_t spair_write(void *obj, const void *buffer, size_t count)
{
	int res;
	int key;
	size_t avail;
	bool is_nonblock;
	size_t bytes_written;
	bool have_local_sem = false;
	bool have_remote_sem = false;
	bool will_block = false;
	struct spair *const spair = (struct spair *)obj;
	struct spair *remote = NULL;

	if (obj == NULL || buffer == NULL || count == 0) {
		errno = EINVAL;
		res = -1;
		goto out;
	}

	/* sample the blocking mode and try the local lock without waiting,
	 * with interrupts locked around both
	 */
	key = irq_lock();
	is_nonblock = sock_is_nonblock(spair);
	res = k_sem_take(&spair->sem, K_NO_WAIT);
	irq_unlock(key);
	if (res < 0) {
		if (is_nonblock) {
			errno = EAGAIN;
			res = -1;
			goto out;
		}

		res = k_sem_take(&spair->sem, K_FOREVER);
		if (res < 0) {
			errno = -res;
			res = -1;
			goto out;
		}
		/* the mode may have changed while waiting for the lock */
		is_nonblock = sock_is_nonblock(spair);
	}

	have_local_sem = true;

	/* a peer that no longer resolves means the other end was closed */
	remote = z_get_fd_obj(spair->remote,
		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);

	if (remote == NULL) {
		errno = EPIPE;
		res = -1;
		goto out;
	}

	/* the remote lock guards the remote recv_q that is about to be filled */
	res = k_sem_take(&remote->sem, K_NO_WAIT);
	if (res < 0) {
		if (is_nonblock) {
			errno = EAGAIN;
			res = -1;
			goto out;
		}
		res = k_sem_take(&remote->sem, K_FOREVER);
		if (res < 0) {
			errno = -res;
			res = -1;
			goto out;
		}
	}

	have_remote_sem = true;

	avail = spair_write_avail(spair);

	if (avail == 0) {
		if (is_nonblock) {
			errno = EAGAIN;
			res = -1;
			goto out;
		}
		will_block = true;
	}

	if (will_block) {
		/* an ISR cannot wait */
		if (k_is_in_isr()) {
			errno = EAGAIN;
			res = -1;
			goto out;
		}

		for (int signaled = false, result = -1; !signaled;
			result = -1) {

			struct k_poll_event events[] = {
				K_POLL_EVENT_INITIALIZER(
					K_POLL_TYPE_SIGNAL,
					K_POLL_MODE_NOTIFY_ONLY,
					&remote->writeable),
			};

			/* release only the remote lock while waiting; the
			 * local lock stays held (see NOTE(review) above)
			 */
			k_sem_give(&remote->sem);
			have_remote_sem = false;

			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
			if (res < 0) {
				errno = -res;
				res = -1;
				goto out;
			}

			/* the peer may have been closed while waiting */
			remote = z_get_fd_obj(spair->remote,
				(const struct fd_op_vtable *)
				&spair_fd_op_vtable, 0);

			if (remote == NULL) {
				errno = EPIPE;
				res = -1;
				goto out;
			}

			res = k_sem_take(&remote->sem, K_FOREVER);
			if (res < 0) {
				errno = -res;
				res = -1;
				goto out;
			}

			have_remote_sem = true;

			/* the signal may have been reset again by a later
			 * write; if so, keep waiting
			 */
			k_poll_signal_check(&remote->writeable, &signaled,
					    &result);
			if (!signaled) {
				continue;
			}

			switch (result) {
				case SPAIR_SIG_DATA: {
					break;
				}

				case SPAIR_SIG_CANCEL: {
					errno = EPIPE;
					res = -1;
					goto out;
				}

				default: {
					__ASSERT(false,
						"unrecognized result: %d",
						result);
					/* signaled is set, so this ends the
					 * loop and proceeds to the write
					 */
					continue;
				}
			}

			/* SPAIR_SIG_DATA was received */
			break;
		}
	}

	/* min_xfer of 1 with K_NO_WAIT: transfer as much as currently fits */
	res = k_pipe_put(&remote->recv_q, (void *)buffer, count,
			 &bytes_written, 1, K_NO_WAIT);
	__ASSERT(res == 0, "k_pipe_put() failed: %d", res);

	/* the remote recv_q is now full: later writers must wait */
	if (spair_write_avail(spair) == 0) {
		k_poll_signal_reset(&remote->writeable);
	}

	/* wake readers blocked on the remote endpoint */
	res = k_poll_signal_raise(&remote->readable, SPAIR_SIG_DATA);
	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);

	res = bytes_written;

out:

	if (remote != NULL && have_remote_sem) {
		k_sem_give(&remote->sem);
	}
	if (spair != NULL && have_local_sem) {
		k_sem_give(&spair->sem);
	}

	return res;
}
565 
566 /**
567  * Read data from one end of a @ref spair
568  *
569  * Data written on one file descriptor of a socketpair (with e.g. write(2) or
570  * send(2)) can be read at the other end using common POSIX calls such as
571  * read(2) or recv(2).
572  *
573  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
574  * this function will return immediately. If no data was read from a
575  * non-blocking file descriptor, then -1 will be returned and @ref errno will
576  * be set to @ref EAGAIN.
577  *
578  * Blocking read operations occur when the @ref O_NONBLOCK flag is @em not set
579  * and there are no bytes to read in the @em local @ref spair.pipe.
580  *
581  * Such a blocking read will suspend execution of the current thread until
582  * one of two possible results is received on the @em local
583  * @ref spair.readable:
584  *
585  * -# @ref SPAIR_SIG_DATA - data has been written to the @em local
586  *    @ref spair.pipe. Thus, allowing more data to be read.
587  *
588  * -# @ref SPAIR_SIG_CANCEL - read of the the @em local @spair.pipe
589  *    must be cancelled for some reason (e.g. the file descriptor will be
590  *    closed imminently). In this case, the function will return -1 and set
591  *    @ref errno to @ref EINTR.
592  *
593  * @param obj the address of an @ref spair object cast to `void *`
594  * @param buffer the buffer in which to read
595  * @param count the number of bytes to read
596  *
597  * @return on success, a number > 0 representing the number of bytes written
598  * @return -1 on error, with @ref errno set appropriately.
599  */
spair_read(void * obj,void * buffer,size_t count)600 static ssize_t spair_read(void *obj, void *buffer, size_t count)
601 {
602 	int res;
603 	int key;
604 	bool is_connected;
605 	size_t avail;
606 	bool is_nonblock;
607 	size_t bytes_read;
608 	bool have_local_sem = false;
609 	bool will_block = false;
610 	struct spair *const spair = (struct spair *)obj;
611 
612 	if (obj == NULL || buffer == NULL || count == 0) {
613 		errno = EINVAL;
614 		res = -1;
615 		goto out;
616 	}
617 
618 	key = irq_lock();
619 	is_nonblock = sock_is_nonblock(spair);
620 	res = k_sem_take(&spair->sem, K_NO_WAIT);
621 	irq_unlock(key);
622 	if (res < 0) {
623 		if (is_nonblock) {
624 			errno = EAGAIN;
625 			res = -1;
626 			goto out;
627 		}
628 
629 		res = k_sem_take(&spair->sem, K_FOREVER);
630 		if (res < 0) {
631 			errno = -res;
632 			res = -1;
633 			goto out;
634 		}
635 		is_nonblock = sock_is_nonblock(spair);
636 	}
637 
638 	have_local_sem = true;
639 
640 	is_connected = sock_is_connected(spair);
641 	avail = spair_read_avail(spair);
642 
643 	if (avail == 0) {
644 		if (!is_connected) {
645 			/* signal EOF */
646 			res = 0;
647 			goto out;
648 		}
649 
650 		if (is_nonblock) {
651 			errno = EAGAIN;
652 			res = -1;
653 			goto out;
654 		}
655 
656 		will_block = true;
657 	}
658 
659 	if (will_block) {
660 		if (k_is_in_isr()) {
661 			errno = EAGAIN;
662 			res = -1;
663 			goto out;
664 		}
665 
666 		for (int signaled = false, result = -1; !signaled;
667 			result = -1) {
668 
669 			struct k_poll_event events[] = {
670 				K_POLL_EVENT_INITIALIZER(
671 					K_POLL_TYPE_SIGNAL,
672 					K_POLL_MODE_NOTIFY_ONLY,
673 					&spair->readable
674 				),
675 			};
676 
677 			k_sem_give(&spair->sem);
678 			have_local_sem = false;
679 
680 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
681 			__ASSERT(res == 0, "k_poll() failed: %d", res);
682 
683 			res = k_sem_take(&spair->sem, K_FOREVER);
684 			__ASSERT(res == 0, "failed to take local sem: %d", res);
685 
686 			have_local_sem = true;
687 
688 			k_poll_signal_check(&spair->readable, &signaled,
689 					    &result);
690 			if (!signaled) {
691 				continue;
692 			}
693 
694 			switch (result) {
695 				case SPAIR_SIG_DATA: {
696 					break;
697 				}
698 
699 				case SPAIR_SIG_CANCEL: {
700 					errno = EPIPE;
701 					res = -1;
702 					goto out;
703 				}
704 
705 				default: {
706 					__ASSERT(false,
707 						"unrecognized result: %d",
708 						result);
709 					continue;
710 				}
711 			}
712 
713 			/* SPAIR_SIG_DATA was received */
714 			break;
715 		}
716 	}
717 
718 	res = k_pipe_get(&spair->recv_q, (void *)buffer, count, &bytes_read,
719 			 1, K_NO_WAIT);
720 	__ASSERT(res == 0, "k_pipe_get() failed: %d", res);
721 
722 	if (spair_read_avail(spair) == 0 && !sock_is_eof(spair)) {
723 		k_poll_signal_reset(&spair->readable);
724 	}
725 
726 	if (is_connected) {
727 		res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
728 		__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
729 	}
730 
731 	res = bytes_read;
732 
733 out:
734 
735 	if (spair != NULL && have_local_sem) {
736 		k_sem_give(&spair->sem);
737 	}
738 
739 	return res;
740 }
741 
zsock_poll_prepare_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)742 static int zsock_poll_prepare_ctx(struct spair *const spair,
743 				  struct zsock_pollfd *const pfd,
744 				  struct k_poll_event **pev,
745 				  struct k_poll_event *pev_end)
746 {
747 	int res;
748 
749 	struct spair *remote = NULL;
750 	bool have_remote_sem = false;
751 
752 	if (pfd->events & ZSOCK_POLLIN) {
753 
754 		/* Tell poll() to short-circuit wait */
755 		if (sock_is_eof(spair)) {
756 			res = -EALREADY;
757 			goto out;
758 		}
759 
760 		if (*pev == pev_end) {
761 			res = -ENOMEM;
762 			goto out;
763 		}
764 
765 		/* Wait until data has been written to the local end */
766 		(*pev)->obj = &spair->readable;
767 	}
768 
769 	if (pfd->events & ZSOCK_POLLOUT) {
770 
771 		/* Tell poll() to short-circuit wait */
772 		if (!sock_is_connected(spair)) {
773 			res = -EALREADY;
774 			goto out;
775 		}
776 
777 		if (*pev == pev_end) {
778 			res = -ENOMEM;
779 			goto out;
780 		}
781 
782 		remote = z_get_fd_obj(spair->remote,
783 			(const struct fd_op_vtable *)
784 			&spair_fd_op_vtable, 0);
785 
786 		__ASSERT(remote != NULL, "remote is NULL");
787 
788 		res = k_sem_take(&remote->sem, K_FOREVER);
789 		if (res < 0) {
790 			goto out;
791 		}
792 
793 		have_remote_sem = true;
794 
795 		/* Wait until the recv queue on the remote end is no longer full */
796 		(*pev)->obj = &remote->writeable;
797 	}
798 
799 	(*pev)->type = K_POLL_TYPE_SIGNAL;
800 	(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
801 	(*pev)->state = K_POLL_STATE_NOT_READY;
802 
803 	(*pev)++;
804 
805 	res = 0;
806 
807 out:
808 
809 	if (remote != NULL && have_remote_sem) {
810 		k_sem_give(&remote->sem);
811 	}
812 
813 	return res;
814 }
815 
zsock_poll_update_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev)816 static int zsock_poll_update_ctx(struct spair *const spair,
817 				 struct zsock_pollfd *const pfd,
818 				 struct k_poll_event **pev)
819 {
820 	int res;
821 	int signaled;
822 	int result;
823 	struct spair *remote = NULL;
824 	bool have_remote_sem = false;
825 
826 	if (pfd->events & ZSOCK_POLLOUT) {
827 		if (!sock_is_connected(spair)) {
828 			pfd->revents |= ZSOCK_POLLHUP;
829 			goto pollout_done;
830 		}
831 
832 		remote = z_get_fd_obj(spair->remote,
833 			(const struct fd_op_vtable *) &spair_fd_op_vtable, 0);
834 
835 		__ASSERT(remote != NULL, "remote is NULL");
836 
837 		res = k_sem_take(&remote->sem, K_FOREVER);
838 		if (res < 0) {
839 			/* if other end is deleted, this might occur */
840 			goto pollout_done;
841 		}
842 
843 		have_remote_sem = true;
844 
845 		if (spair_write_avail(spair) > 0) {
846 			pfd->revents |= ZSOCK_POLLOUT;
847 			goto pollout_done;
848 		}
849 
850 		/* check to see if op was canceled */
851 		signaled = false;
852 		k_poll_signal_check(&remote->writeable, &signaled, &result);
853 		if (signaled) {
854 			/* Cannot be SPAIR_SIG_DATA, because
855 			 * spair_write_avail() would have
856 			 * returned 0
857 			 */
858 			__ASSERT(result == SPAIR_SIG_CANCEL,
859 				"invalid result %d", result);
860 			pfd->revents |= ZSOCK_POLLHUP;
861 		}
862 	}
863 
864 pollout_done:
865 
866 	if (pfd->events & ZSOCK_POLLIN) {
867 		if (sock_is_eof(spair)) {
868 			pfd->revents |= ZSOCK_POLLIN;
869 			goto pollin_done;
870 		}
871 
872 		if (spair_read_avail(spair) > 0) {
873 			pfd->revents |= ZSOCK_POLLIN;
874 			goto pollin_done;
875 		}
876 
877 		/* check to see if op was canceled */
878 		signaled = false;
879 		k_poll_signal_check(&spair->readable, &signaled, &result);
880 		if (signaled) {
881 			/* Cannot be SPAIR_SIG_DATA, because
882 			 * spair_read_avail() would have
883 			 * returned 0
884 			 */
885 			__ASSERT(result == SPAIR_SIG_CANCEL,
886 					 "invalid result %d", result);
887 			pfd->revents |= ZSOCK_POLLIN;
888 		}
889 	}
890 
891 pollin_done:
892 	res = 0;
893 
894 	(*pev)++;
895 
896 	if (remote != NULL && have_remote_sem) {
897 		k_sem_give(&remote->sem);
898 	}
899 
900 	return res;
901 }
902 
spair_ioctl(void * obj,unsigned int request,va_list args)903 static int spair_ioctl(void *obj, unsigned int request, va_list args)
904 {
905 	int res;
906 	struct zsock_pollfd *pfd;
907 	struct k_poll_event **pev;
908 	struct k_poll_event *pev_end;
909 	int flags = 0;
910 	bool have_local_sem = false;
911 	struct spair *const spair = (struct spair *)obj;
912 
913 	if (spair == NULL) {
914 		errno = EINVAL;
915 		res = -1;
916 		goto out;
917 	}
918 
919 	/* The local sem is always taken in this function. If a subsequent
920 	 * function call requires the remote sem, it must acquire and free the
921 	 * remote sem.
922 	 */
923 	res = k_sem_take(&spair->sem, K_FOREVER);
924 	__ASSERT(res == 0, "failed to take local sem: %d", res);
925 
926 	have_local_sem = true;
927 
928 	switch (request) {
929 		case F_GETFL: {
930 			if (sock_is_nonblock(spair)) {
931 				flags |= O_NONBLOCK;
932 			}
933 
934 			res = flags;
935 			goto out;
936 		}
937 
938 		case F_SETFL: {
939 			flags = va_arg(args, int);
940 
941 			if (flags & O_NONBLOCK) {
942 				spair->flags |= SPAIR_FLAG_NONBLOCK;
943 			} else {
944 				spair->flags &= ~SPAIR_FLAG_NONBLOCK;
945 			}
946 
947 			res = 0;
948 			goto out;
949 		}
950 
951 		case ZFD_IOCTL_POLL_PREPARE: {
952 			pfd = va_arg(args, struct zsock_pollfd *);
953 			pev = va_arg(args, struct k_poll_event **);
954 			pev_end = va_arg(args, struct k_poll_event *);
955 
956 			res = zsock_poll_prepare_ctx(obj, pfd, pev, pev_end);
957 			goto out;
958 		}
959 
960 		case ZFD_IOCTL_POLL_UPDATE: {
961 			pfd = va_arg(args, struct zsock_pollfd *);
962 			pev = va_arg(args, struct k_poll_event **);
963 
964 			res = zsock_poll_update_ctx(obj, pfd, pev);
965 			goto out;
966 		}
967 
968 		default: {
969 			errno = EOPNOTSUPP;
970 			res = -1;
971 			goto out;
972 		}
973 	}
974 
975 out:
976 	if (spair != NULL && have_local_sem) {
977 		k_sem_give(&spair->sem);
978 	}
979 
980 	return res;
981 }
982 
/** bind(2) always fails on a socketpair endpoint with EISCONN */
static int spair_bind(void *obj, const struct sockaddr *addr,
		      socklen_t addrlen)
{
	(void)obj;
	(void)addr;
	(void)addrlen;

	errno = EISCONN;

	return -1;
}
993 
/** connect(2) always fails on a socketpair endpoint with EISCONN */
static int spair_connect(void *obj, const struct sockaddr *addr,
			 socklen_t addrlen)
{
	(void)obj;
	(void)addr;
	(void)addrlen;

	errno = EISCONN;

	return -1;
}
1004 
/** listen(2) always fails on a socketpair endpoint with EINVAL */
static int spair_listen(void *obj, int backlog)
{
	(void)obj;
	(void)backlog;

	errno = EINVAL;

	return -1;
}
1013 
/** accept(2) always fails on a socketpair endpoint with EOPNOTSUPP */
static int spair_accept(void *obj, struct sockaddr *addr,
			socklen_t *addrlen)
{
	(void)obj;
	(void)addr;
	(void)addrlen;

	errno = EOPNOTSUPP;

	return -1;
}
1024 
/**
 * sendto(2) on a socketpair endpoint
 *
 * @p flags, @p dest_addr and @p addrlen are ignored; the call behaves
 * exactly like spair_write().
 */
static ssize_t spair_sendto(void *obj, const void *buf, size_t len,
			    int flags, const struct sockaddr *dest_addr,
			    socklen_t addrlen)
{
	(void)flags;
	(void)dest_addr;
	(void)addrlen;

	const ssize_t written = spair_write(obj, buf, len);

	return written;
}
1035 
/**
 * sendmsg(2) on a socketpair endpoint
 *
 * The buffers in @p msg->msg_iov are written in order with spair_write().
 * The destination address, ancillary data and @p flags are ignored.
 *
 * spair_write() may transfer fewer bytes than requested, so each buffer is
 * written in a loop until it has been fully queued. Zero-length buffers are
 * skipped, because spair_write() rejects them with EINVAL. A non-blocking
 * endpoint refuses (with EMSGSIZE) a message that does not currently fit in
 * the remote receive queue.
 *
 * @return the number of bytes sent. If an error occurs after some bytes were
 *         already sent, that count is returned instead of -1, as with
 *         write(2).
 * @return -1 on error, with @ref errno set appropriately (EINVAL, EPIPE,
 *         EMSGSIZE, or an error from spair_write())
 */
static ssize_t spair_sendmsg(void *obj, const struct msghdr *msg,
			     int flags)
{
	ARG_UNUSED(flags);

	ssize_t res;
	size_t len = 0;
	size_t sent = 0;
	bool is_connected;
	size_t avail;
	bool is_nonblock;
	struct spair *const spair = (struct spair *)obj;

	if (spair == NULL || msg == NULL) {
		errno = EINVAL;
		res = -1;
		goto out;
	}

	is_connected = sock_is_connected(spair);
	avail = is_connected ? spair_write_avail(spair) : 0;
	is_nonblock = sock_is_nonblock(spair);

	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
		len += msg->msg_iov[i].iov_len;
	}

	if (!is_connected) {
		errno = EPIPE;
		res = -1;
		goto out;
	}

	if (len == 0) {
		res = 0;
		goto out;
	}

	if (len > avail && is_nonblock) {
		errno = EMSGSIZE;
		res = -1;
		goto out;
	}

	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
		const uint8_t *base = msg->msg_iov[i].iov_base;
		size_t remaining = msg->msg_iov[i].iov_len;

		/* keep going until the whole entry is queued; spair_write()
		 * may transfer only part of it per call
		 */
		while (remaining > 0) {
			ssize_t n = spair_write(spair, base, remaining);

			if (n <= 0) {
				/* report partial progress like write(2) */
				res = (sent > 0) ? (ssize_t)sent : n;
				goto out;
			}

			base += n;
			remaining -= (size_t)n;
			sent += (size_t)n;
		}
	}

	res = (ssize_t)sent;

out:
	return res;
}
1094 
/**
 * recvfrom(2) on a socketpair endpoint
 *
 * @p flags and @p src_addr are ignored. If @p addrlen is non-NULL it is set
 * to 0. The data transfer itself is performed by spair_read().
 */
static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len,
			      int flags, struct sockaddr *src_addr,
			      socklen_t *addrlen)
{
	ARG_UNUSED(flags);
	ARG_UNUSED(src_addr);

	/* Protocol (PF_UNIX) does not support addressing with connected
	 * sockets, so modifying src_addr would be unspecified behaviour.
	 * Leaving addrlen untouched would be ambiguous for callers expecting
	 * it to be updated, and modifying it is not said to be unspecified,
	 * so it is reset to zero. This mimics Linux's behaviour.
	 */
	if (addrlen) {
		*addrlen = 0;
	}

	return spair_read(obj, buf, max_len);
}
1118 
/** getsockopt(2): no options are supported, always fails with ENOPROTOOPT */
static int spair_getsockopt(void *obj, int level, int optname,
			    void *optval, socklen_t *optlen)
{
	(void)obj;
	(void)level;
	(void)optname;
	(void)optval;
	(void)optlen;

	errno = ENOPROTOOPT;

	return -1;
}
1131 
/** setsockopt(2): no options are supported, always fails with ENOPROTOOPT */
static int spair_setsockopt(void *obj, int level, int optname,
			    const void *optval, socklen_t optlen)
{
	(void)obj;
	(void)level;
	(void)optname;
	(void)optval;
	(void)optlen;

	errno = ENOPROTOOPT;

	return -1;
}
1144 
spair_close(void * obj)1145 static int spair_close(void *obj)
1146 {
1147 	struct spair *const spair = (struct spair *)obj;
1148 	int res;
1149 
1150 	res = k_sem_take(&spair->sem, K_FOREVER);
1151 	__ASSERT(res == 0, "failed to take local sem: %d", res);
1152 
1153 	/* disconnect the remote endpoint */
1154 	spair_delete(spair);
1155 
1156 	/* Note that the semaphore released already so need to do it here */
1157 
1158 	return 0;
1159 }
1160 
1161 static const struct socket_op_vtable spair_fd_op_vtable = {
1162 	.fd_vtable = {
1163 		.read = spair_read,
1164 		.write = spair_write,
1165 		.close = spair_close,
1166 		.ioctl = spair_ioctl,
1167 	},
1168 	.bind = spair_bind,
1169 	.connect = spair_connect,
1170 	.listen = spair_listen,
1171 	.accept = spair_accept,
1172 	.sendto = spair_sendto,
1173 	.sendmsg = spair_sendmsg,
1174 	.recvfrom = spair_recvfrom,
1175 	.getsockopt = spair_getsockopt,
1176 	.setsockopt = spair_setsockopt,
1177 };
1178