1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/kernel.h>
8 #include <zephyr/net/socket.h>
9 #include <zephyr/posix/fcntl.h>
10 #include <zephyr/syscall_handler.h>
11 #include <zephyr/sys/__assert.h>
12 #include <zephyr/sys/fdtable.h>
13 
14 #include "sockets_internal.h"
15 
/** Result values carried by the readable / writeable poll signals */
enum {
	SPAIR_SIG_CANCEL = 0, /**< operation has been canceled */
	SPAIR_SIG_DATA = 1,   /**< @ref spair.recv_q has been updated */
};

/** Bits stored in @ref spair.flags */
enum {
	SPAIR_FLAG_NONBLOCK = 1 << 0, /**< socket is non-blocking */
};

/** Value of @ref spair.flags for a freshly created endpoint */
#define SPAIR_FLAGS_DEFAULT 0
26 
27 /**
28  * Socketpair endpoint structure
29  *
30  * This structure represents one half of a socketpair (an 'endpoint').
31  *
32  * The implementation strives for compatibility with socketpair(2).
33  *
34  * Resources contained within this structure are said to be 'local', while
35  * resources contained within the other half of the socketpair (or other
36  * endpoint) are said to be 'remote'.
37  *
38  * Theory of operation:
39  * - each end of a socketpair owns a @a recv_q
40  * - since there is no write queue, data is either written or not
41  * - read and write operations may return partial transfers
42  * - read operations may block if the local @a recv_q is empty
43  * - write operations may block if the remote @a recv_q is full
44  * - each endpoint may be blocking or non-blocking
45  */
46 __net_socket struct spair {
47 	int remote; /**< the remote endpoint file descriptor */
48 	uint32_t flags; /**< status and option bits */
49 	struct k_sem sem; /**< semaphore for exclusive structure access */
50 	struct k_pipe recv_q; /**< receive queue of local endpoint */
51 	/** indicates local @a recv_q isn't empty */
52 	struct k_poll_signal readable;
53 	/** indicates local @a recv_q isn't full */
54 	struct k_poll_signal writeable;
55 	/** buffer for @a recv_q recv_q */
56 	uint8_t buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE];
57 };
58 
#ifdef CONFIG_NET_SOCKETPAIR_STATIC
/* Statically reserved endpoints: each socketpair consumes two struct spair
 * objects, hence CONFIG_NET_SOCKETPAIR_MAX * 2 blocks.
 */
K_MEM_SLAB_DEFINE_STATIC(spair_slab, sizeof(struct spair), CONFIG_NET_SOCKETPAIR_MAX * 2,
			 __alignof__(struct spair));
#endif /* CONFIG_NET_SOCKETPAIR_STATIC */

/* forward declaration: needed by the fd-table lookups below, which use the
 * vtable to verify that a descriptor really refers to a socketpair endpoint
 */
static const struct socket_op_vtable spair_fd_op_vtable;
66 
67 #undef sock_is_nonblock
68 /** Determine if a @ref spair is in non-blocking mode */
sock_is_nonblock(const struct spair * spair)69 static inline bool sock_is_nonblock(const struct spair *spair)
70 {
71 	return !!(spair->flags & SPAIR_FLAG_NONBLOCK);
72 }
73 
74 /** Determine if a @ref spair is connected */
sock_is_connected(const struct spair * spair)75 static inline bool sock_is_connected(const struct spair *spair)
76 {
77 	const struct spair *remote = z_get_fd_obj(spair->remote,
78 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
79 
80 	if (remote == NULL) {
81 		return false;
82 	}
83 
84 	return true;
85 }
86 
#undef sock_is_eof
/**
 * Check whether reads on @p spair have reached end-of-file.
 *
 * For a socketpair, end-of-file is simply the absence of a connected peer.
 */
static inline bool sock_is_eof(const struct spair *spair)
{
	const bool connected = sock_is_connected(spair);

	return !connected;
}
93 
94 /**
95  * Determine bytes available to write
96  *
97  * Specifically, this function calculates the number of bytes that may be
98  * written to a given @ref spair without blocking.
99  */
spair_write_avail(struct spair * spair)100 static inline size_t spair_write_avail(struct spair *spair)
101 {
102 	struct spair *const remote = z_get_fd_obj(spair->remote,
103 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
104 
105 	if (remote == NULL) {
106 		return 0;
107 	}
108 
109 	return k_pipe_write_avail(&remote->recv_q);
110 }
111 
112 /**
113  * Determine bytes available to read
114  *
115  * Specifically, this function calculates the number of bytes that may be
116  * read from a given @ref spair without blocking.
117  */
spair_read_avail(struct spair * spair)118 static inline size_t spair_read_avail(struct spair *spair)
119 {
120 	return k_pipe_read_avail(&spair->recv_q);
121 }
122 
/**
 * Exchange the 32-bit values pointed to by @p a and @p b.
 *
 * Passing the same pointer twice leaves the value unchanged.
 */
static inline void swap32(uint32_t *a, uint32_t *b)
{
	const uint32_t saved = *a;

	*a = *b;
	*b = saved;
}
132 
133 /**
134  * Delete @param spair
135  *
136  * This function deletes one endpoint of a socketpair.
137  *
138  * Theory of operation:
139  * - we have a socketpair with two endpoints: A and B
140  * - we have two threads: T1 and T2
141  * - T1 operates on endpoint A
142  * - T2 operates on endpoint B
143  *
144  * There are two possible cases where a blocking operation must be notified
145  * when one endpoint is closed:
146  * -# T1 is blocked reading from A and T2 closes B
147  *    T1 waits on A's write signal. T2 triggers the remote
148  *    @ref spair.readable
149  * -# T1 is blocked writing to A and T2 closes B
150  *    T1 is waits on B's read signal. T2 triggers the local
151  *    @ref spair.writeable.
152  *
153  * If the remote endpoint is already closed, the former operation does not
154  * take place. Otherwise, the @ref spair.remote of the local endpoint is
155  * set to -1.
156  *
157  * If no threads are blocking on A, then the signals have no effect.
158  *
159  * The memory associated with the local endpoint is cleared and freed.
160  */
spair_delete(struct spair * spair)161 static void spair_delete(struct spair *spair)
162 {
163 	int res;
164 	struct spair *remote = NULL;
165 	bool have_remote_sem = false;
166 
167 	if (spair == NULL) {
168 		return;
169 	}
170 
171 	if (spair->remote != -1) {
172 		remote = z_get_fd_obj(spair->remote,
173 			(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
174 
175 		if (remote != NULL) {
176 			res = k_sem_take(&remote->sem, K_FOREVER);
177 			if (res == 0) {
178 				have_remote_sem = true;
179 				remote->remote = -1;
180 				res = k_poll_signal_raise(&remote->readable,
181 					SPAIR_SIG_CANCEL);
182 				__ASSERT(res == 0,
183 					"k_poll_signal_raise() failed: %d",
184 					res);
185 			}
186 		}
187 	}
188 
189 	spair->remote = -1;
190 
191 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_CANCEL);
192 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
193 
194 	if (remote != NULL && have_remote_sem) {
195 		k_sem_give(&remote->sem);
196 	}
197 
198 	/* ensure no private information is released to the memory pool */
199 	memset(spair, 0, sizeof(*spair));
200 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
201 	k_mem_slab_free(&spair_slab, (void *)spair);
202 #elif CONFIG_USERSPACE
203 	k_object_free(spair);
204 #else
205 	k_free(spair);
206 #endif
207 }
208 
209 /**
210  * Create a @ref spair (1/2 of a socketpair)
211  *
212  * The idea is to call this twice, but store the "local" side in the
213  * @ref spair.remote field initially.
214  *
215  * If both allocations are successful, then swap the @ref spair.remote
216  * fields in the two @ref spair instances.
217  */
spair_new(void)218 static struct spair *spair_new(void)
219 {
220 	struct spair *spair;
221 	int res;
222 
223 #ifdef CONFIG_NET_SOCKETPAIR_STATIC
224 
225 	res = k_mem_slab_alloc(&spair_slab, (void **) &spair, K_NO_WAIT);
226 	if (res != 0) {
227 		spair = NULL;
228 	}
229 
230 #elif CONFIG_USERSPACE
231 	struct z_object *zo = z_dynamic_object_create(sizeof(*spair));
232 
233 	if (zo == NULL) {
234 		spair = NULL;
235 	} else {
236 		spair = zo->name;
237 		zo->type = K_OBJ_NET_SOCKET;
238 	}
239 #else
240 	spair = k_malloc(sizeof(*spair));
241 #endif
242 	if (spair == NULL) {
243 		errno = ENOMEM;
244 		goto out;
245 	}
246 	memset(spair, 0, sizeof(*spair));
247 
248 	/* initialize any non-zero default values */
249 	spair->remote = -1;
250 	spair->flags = SPAIR_FLAGS_DEFAULT;
251 
252 	k_sem_init(&spair->sem, 1, 1);
253 	k_pipe_init(&spair->recv_q, spair->buf, sizeof(spair->buf));
254 	k_poll_signal_init(&spair->readable);
255 	k_poll_signal_init(&spair->writeable);
256 
257 	/* A new socket is always writeable after creation */
258 	res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
259 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
260 
261 	spair->remote = z_reserve_fd();
262 	if (spair->remote == -1) {
263 		errno = ENFILE;
264 		goto cleanup;
265 	}
266 
267 	z_finalize_fd(spair->remote, spair,
268 		      (const struct fd_op_vtable *)&spair_fd_op_vtable);
269 
270 	goto out;
271 
272 cleanup:
273 	spair_delete(spair);
274 	spair = NULL;
275 
276 out:
277 	return spair;
278 }
279 
z_impl_zsock_socketpair(int family,int type,int proto,int * sv)280 int z_impl_zsock_socketpair(int family, int type, int proto, int *sv)
281 {
282 	int res;
283 	size_t i;
284 	struct spair *obj[2] = {};
285 
286 	if (family != AF_UNIX) {
287 		errno = EAFNOSUPPORT;
288 		res = -1;
289 		goto errout;
290 	}
291 
292 	if (type != SOCK_STREAM) {
293 		errno = EPROTOTYPE;
294 		res = -1;
295 		goto errout;
296 	}
297 
298 	if (proto != 0) {
299 		errno = EPROTONOSUPPORT;
300 		res = -1;
301 		goto errout;
302 	}
303 
304 	if (sv == NULL) {
305 		/* not listed in normative spec, but mimics Linux behaviour */
306 		errno = EFAULT;
307 		res = -1;
308 		goto errout;
309 	}
310 
311 	for (i = 0; i < 2; ++i) {
312 		obj[i] = spair_new();
313 		if (!obj[i]) {
314 			res = -1;
315 			goto cleanup;
316 		}
317 	}
318 
319 	/* connect the two endpoints */
320 	swap32(&obj[0]->remote, &obj[1]->remote);
321 
322 	for (i = 0; i < 2; ++i) {
323 		sv[i] = obj[i]->remote;
324 		k_sem_give(&obj[0]->sem);
325 	}
326 
327 	return 0;
328 
329 cleanup:
330 	for (i = 0; i < 2; ++i) {
331 		spair_delete(obj[i]);
332 	}
333 
334 errout:
335 	return res;
336 }
337 
#ifdef CONFIG_USERSPACE
/**
 * System call verifier for zsock_socketpair().
 *
 * Validates that the caller may write two ints to @p sv, creates the pair
 * into a kernel-side buffer, then copies the descriptors out to user space.
 *
 * @return 0 on success, -1 with errno set on failure
 */
int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv)
{
	int fds[2];
	int ret;

	if (sv == NULL || Z_SYSCALL_MEMORY_WRITE(sv, sizeof(fds)) != 0) {
		/* not in the normative spec; EFAULT matches Linux */
		errno = EFAULT;
		return -1;
	}

	ret = z_impl_zsock_socketpair(family, type, proto, fds);
	if (ret == 0) {
		Z_OOPS(z_user_to_copy(sv, fds, sizeof(fds)));
	}

	return ret;
}

#include <syscalls/zsock_socketpair_mrsh.c>
#endif /* CONFIG_USERSPACE */
362 
363 /**
364  * Write data to one end of a @ref spair
365  *
366  * Data written on one file descriptor of a socketpair can be read at the
367  * other end using common POSIX calls such as read(2) or recv(2).
368  *
369  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
370  * this function will return immediately. If no data was written on a
371  * non-blocking file descriptor, then -1 will be returned and @ref errno will
372  * be set to @ref EAGAIN.
373  *
374  * Blocking write operations occur when the @ref O_NONBLOCK flag is @em not
375  * set and there is insufficient space in the @em remote @ref spair.pipe.
376  *
377  * Such a blocking write will suspend execution of the current thread until
378  * one of two possible results is received on the @em remote
379  * @ref spair.writeable:
380  *
381  * 1) @ref SPAIR_SIG_DATA - data has been read from the @em remote
382  *    @ref spair.pipe. Thus, allowing more data to be written.
383  *
384  * 2) @ref SPAIR_SIG_CANCEL - the @em remote socketpair endpoint was closed
385  *    Receipt of this result is analogous to SIGPIPE from POSIX
386  *    ("Write on a pipe with no one to read it."). In this case, the function
387  *    will return -1 and set @ref errno to @ref EPIPE.
388  *
389  * @param obj the address of an @ref spair object cast to `void *`
390  * @param buffer the buffer to write
391  * @param count the number of bytes to write from @p buffer
392  *
393  * @return on success, a number > 0 representing the number of bytes written
394  * @return -1 on error, with @ref errno set appropriately.
395  */
spair_write(void * obj,const void * buffer,size_t count)396 static ssize_t spair_write(void *obj, const void *buffer, size_t count)
397 {
398 	int res;
399 	size_t avail;
400 	bool is_nonblock;
401 	size_t bytes_written;
402 	bool have_local_sem = false;
403 	bool have_remote_sem = false;
404 	bool will_block = false;
405 	struct spair *const spair = (struct spair *)obj;
406 	struct spair *remote = NULL;
407 
408 	if (obj == NULL || buffer == NULL || count == 0) {
409 		errno = EINVAL;
410 		res = -1;
411 		goto out;
412 	}
413 
414 	res = k_sem_take(&spair->sem, K_NO_WAIT);
415 	is_nonblock = sock_is_nonblock(spair);
416 	if (res < 0) {
417 		if (is_nonblock) {
418 			errno = EAGAIN;
419 			res = -1;
420 			goto out;
421 		}
422 
423 		res = k_sem_take(&spair->sem, K_FOREVER);
424 		if (res < 0) {
425 			errno = -res;
426 			res = -1;
427 			goto out;
428 		}
429 		is_nonblock = sock_is_nonblock(spair);
430 	}
431 
432 	have_local_sem = true;
433 
434 	remote = z_get_fd_obj(spair->remote,
435 		(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
436 
437 	if (remote == NULL) {
438 		errno = EPIPE;
439 		res = -1;
440 		goto out;
441 	}
442 
443 	res = k_sem_take(&remote->sem, K_NO_WAIT);
444 	if (res < 0) {
445 		if (is_nonblock) {
446 			errno = EAGAIN;
447 			res = -1;
448 			goto out;
449 		}
450 		res = k_sem_take(&remote->sem, K_FOREVER);
451 		if (res < 0) {
452 			errno = -res;
453 			res = -1;
454 			goto out;
455 		}
456 	}
457 
458 	have_remote_sem = true;
459 
460 	avail = spair_write_avail(spair);
461 
462 	if (avail == 0) {
463 		if (is_nonblock) {
464 			errno = EAGAIN;
465 			res = -1;
466 			goto out;
467 		}
468 		will_block = true;
469 	}
470 
471 	if (will_block) {
472 		if (k_is_in_isr()) {
473 			errno = EAGAIN;
474 			res = -1;
475 			goto out;
476 		}
477 
478 		for (int signaled = false, result = -1; !signaled;
479 			result = -1) {
480 
481 			struct k_poll_event events[] = {
482 				K_POLL_EVENT_INITIALIZER(
483 					K_POLL_TYPE_SIGNAL,
484 					K_POLL_MODE_NOTIFY_ONLY,
485 					&remote->writeable),
486 			};
487 
488 			k_sem_give(&remote->sem);
489 			have_remote_sem = false;
490 
491 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
492 			if (res < 0) {
493 				errno = -res;
494 				res = -1;
495 				goto out;
496 			}
497 
498 			remote = z_get_fd_obj(spair->remote,
499 				(const struct fd_op_vtable *)
500 				&spair_fd_op_vtable, 0);
501 
502 			if (remote == NULL) {
503 				errno = EPIPE;
504 				res = -1;
505 				goto out;
506 			}
507 
508 			res = k_sem_take(&remote->sem, K_FOREVER);
509 			if (res < 0) {
510 				errno = -res;
511 				res = -1;
512 				goto out;
513 			}
514 
515 			have_remote_sem = true;
516 
517 			k_poll_signal_check(&remote->writeable, &signaled,
518 					    &result);
519 			if (!signaled) {
520 				continue;
521 			}
522 
523 			switch (result) {
524 				case SPAIR_SIG_DATA: {
525 					break;
526 				}
527 
528 				case SPAIR_SIG_CANCEL: {
529 					errno = EPIPE;
530 					res = -1;
531 					goto out;
532 				}
533 
534 				default: {
535 					__ASSERT(false,
536 						"unrecognized result: %d",
537 						result);
538 					continue;
539 				}
540 			}
541 
542 			/* SPAIR_SIG_DATA was received */
543 			break;
544 		}
545 	}
546 
547 	res = k_pipe_put(&remote->recv_q, (void *)buffer, count,
548 			 &bytes_written, 1, K_NO_WAIT);
549 	__ASSERT(res == 0, "k_pipe_put() failed: %d", res);
550 
551 	if (spair_write_avail(spair) == 0) {
552 		k_poll_signal_reset(&remote->writeable);
553 	}
554 
555 	res = k_poll_signal_raise(&remote->readable, SPAIR_SIG_DATA);
556 	__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
557 
558 	res = bytes_written;
559 
560 out:
561 
562 	if (remote != NULL && have_remote_sem) {
563 		k_sem_give(&remote->sem);
564 	}
565 	if (spair != NULL && have_local_sem) {
566 		k_sem_give(&spair->sem);
567 	}
568 
569 	return res;
570 }
571 
572 /**
573  * Read data from one end of a @ref spair
574  *
575  * Data written on one file descriptor of a socketpair (with e.g. write(2) or
576  * send(2)) can be read at the other end using common POSIX calls such as
577  * read(2) or recv(2).
578  *
579  * If the underlying file descriptor has the @ref O_NONBLOCK flag set then
580  * this function will return immediately. If no data was read from a
581  * non-blocking file descriptor, then -1 will be returned and @ref errno will
582  * be set to @ref EAGAIN.
583  *
584  * Blocking read operations occur when the @ref O_NONBLOCK flag is @em not set
585  * and there are no bytes to read in the @em local @ref spair.pipe.
586  *
587  * Such a blocking read will suspend execution of the current thread until
588  * one of two possible results is received on the @em local
589  * @ref spair.readable:
590  *
591  * -# @ref SPAIR_SIG_DATA - data has been written to the @em local
592  *    @ref spair.pipe. Thus, allowing more data to be read.
593  *
594  * -# @ref SPAIR_SIG_CANCEL - read of the the @em local @spair.pipe
595  *    must be cancelled for some reason (e.g. the file descriptor will be
596  *    closed imminently). In this case, the function will return -1 and set
597  *    @ref errno to @ref EINTR.
598  *
599  * @param obj the address of an @ref spair object cast to `void *`
600  * @param buffer the buffer in which to read
601  * @param count the number of bytes to read
602  *
603  * @return on success, a number > 0 representing the number of bytes written
604  * @return -1 on error, with @ref errno set appropriately.
605  */
spair_read(void * obj,void * buffer,size_t count)606 static ssize_t spair_read(void *obj, void *buffer, size_t count)
607 {
608 	int res;
609 	bool is_connected;
610 	size_t avail;
611 	bool is_nonblock;
612 	size_t bytes_read;
613 	bool have_local_sem = false;
614 	bool will_block = false;
615 	struct spair *const spair = (struct spair *)obj;
616 
617 	if (obj == NULL || buffer == NULL || count == 0) {
618 		errno = EINVAL;
619 		res = -1;
620 		goto out;
621 	}
622 
623 	res = k_sem_take(&spair->sem, K_NO_WAIT);
624 	is_nonblock = sock_is_nonblock(spair);
625 	if (res < 0) {
626 		if (is_nonblock) {
627 			errno = EAGAIN;
628 			res = -1;
629 			goto out;
630 		}
631 
632 		res = k_sem_take(&spair->sem, K_FOREVER);
633 		if (res < 0) {
634 			errno = -res;
635 			res = -1;
636 			goto out;
637 		}
638 		is_nonblock = sock_is_nonblock(spair);
639 	}
640 
641 	have_local_sem = true;
642 
643 	is_connected = sock_is_connected(spair);
644 	avail = spair_read_avail(spair);
645 
646 	if (avail == 0) {
647 		if (!is_connected) {
648 			/* signal EOF */
649 			res = 0;
650 			goto out;
651 		}
652 
653 		if (is_nonblock) {
654 			errno = EAGAIN;
655 			res = -1;
656 			goto out;
657 		}
658 
659 		will_block = true;
660 	}
661 
662 	if (will_block) {
663 		if (k_is_in_isr()) {
664 			errno = EAGAIN;
665 			res = -1;
666 			goto out;
667 		}
668 
669 		for (int signaled = false, result = -1; !signaled;
670 			result = -1) {
671 
672 			struct k_poll_event events[] = {
673 				K_POLL_EVENT_INITIALIZER(
674 					K_POLL_TYPE_SIGNAL,
675 					K_POLL_MODE_NOTIFY_ONLY,
676 					&spair->readable
677 				),
678 			};
679 
680 			k_sem_give(&spair->sem);
681 			have_local_sem = false;
682 
683 			res = k_poll(events, ARRAY_SIZE(events), K_FOREVER);
684 			__ASSERT(res == 0, "k_poll() failed: %d", res);
685 
686 			res = k_sem_take(&spair->sem, K_FOREVER);
687 			__ASSERT(res == 0, "failed to take local sem: %d", res);
688 
689 			have_local_sem = true;
690 
691 			k_poll_signal_check(&spair->readable, &signaled,
692 					    &result);
693 			if (!signaled) {
694 				continue;
695 			}
696 
697 			switch (result) {
698 				case SPAIR_SIG_DATA: {
699 					break;
700 				}
701 
702 				case SPAIR_SIG_CANCEL: {
703 					errno = EPIPE;
704 					res = -1;
705 					goto out;
706 				}
707 
708 				default: {
709 					__ASSERT(false,
710 						"unrecognized result: %d",
711 						result);
712 					continue;
713 				}
714 			}
715 
716 			/* SPAIR_SIG_DATA was received */
717 			break;
718 		}
719 	}
720 
721 	res = k_pipe_get(&spair->recv_q, (void *)buffer, count, &bytes_read,
722 			 1, K_NO_WAIT);
723 	__ASSERT(res == 0, "k_pipe_get() failed: %d", res);
724 
725 	if (spair_read_avail(spair) == 0 && !sock_is_eof(spair)) {
726 		k_poll_signal_reset(&spair->readable);
727 	}
728 
729 	if (is_connected) {
730 		res = k_poll_signal_raise(&spair->writeable, SPAIR_SIG_DATA);
731 		__ASSERT(res == 0, "k_poll_signal_raise() failed: %d", res);
732 	}
733 
734 	res = bytes_read;
735 
736 out:
737 
738 	if (spair != NULL && have_local_sem) {
739 		k_sem_give(&spair->sem);
740 	}
741 
742 	return res;
743 }
744 
zsock_poll_prepare_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)745 static int zsock_poll_prepare_ctx(struct spair *const spair,
746 				  struct zsock_pollfd *const pfd,
747 				  struct k_poll_event **pev,
748 				  struct k_poll_event *pev_end)
749 {
750 	int res;
751 
752 	struct spair *remote = NULL;
753 	bool have_remote_sem = false;
754 
755 	if (pfd->events & ZSOCK_POLLIN) {
756 
757 		/* Tell poll() to short-circuit wait */
758 		if (sock_is_eof(spair)) {
759 			res = -EALREADY;
760 			goto out;
761 		}
762 
763 		if (*pev == pev_end) {
764 			res = -ENOMEM;
765 			goto out;
766 		}
767 
768 		/* Wait until data has been written to the local end */
769 		(*pev)->obj = &spair->readable;
770 	}
771 
772 	if (pfd->events & ZSOCK_POLLOUT) {
773 
774 		/* Tell poll() to short-circuit wait */
775 		if (!sock_is_connected(spair)) {
776 			res = -EALREADY;
777 			goto out;
778 		}
779 
780 		if (*pev == pev_end) {
781 			res = -ENOMEM;
782 			goto out;
783 		}
784 
785 		remote = z_get_fd_obj(spair->remote,
786 			(const struct fd_op_vtable *)
787 			&spair_fd_op_vtable, 0);
788 
789 		__ASSERT(remote != NULL, "remote is NULL");
790 
791 		res = k_sem_take(&remote->sem, K_FOREVER);
792 		if (res < 0) {
793 			goto out;
794 		}
795 
796 		have_remote_sem = true;
797 
798 		/* Wait until the recv queue on the remote end is no longer full */
799 		(*pev)->obj = &remote->writeable;
800 	}
801 
802 	(*pev)->type = K_POLL_TYPE_SIGNAL;
803 	(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
804 	(*pev)->state = K_POLL_STATE_NOT_READY;
805 
806 	(*pev)++;
807 
808 	res = 0;
809 
810 out:
811 
812 	if (remote != NULL && have_remote_sem) {
813 		k_sem_give(&remote->sem);
814 	}
815 
816 	return res;
817 }
818 
zsock_poll_update_ctx(struct spair * const spair,struct zsock_pollfd * const pfd,struct k_poll_event ** pev)819 static int zsock_poll_update_ctx(struct spair *const spair,
820 				 struct zsock_pollfd *const pfd,
821 				 struct k_poll_event **pev)
822 {
823 	int res;
824 	int signaled;
825 	int result;
826 	struct spair *remote = NULL;
827 	bool have_remote_sem = false;
828 
829 	if (pfd->events & ZSOCK_POLLOUT) {
830 		if (!sock_is_connected(spair)) {
831 			pfd->revents |= ZSOCK_POLLHUP;
832 			goto pollout_done;
833 		}
834 
835 		remote = z_get_fd_obj(spair->remote,
836 			(const struct fd_op_vtable *) &spair_fd_op_vtable, 0);
837 
838 		__ASSERT(remote != NULL, "remote is NULL");
839 
840 		res = k_sem_take(&remote->sem, K_FOREVER);
841 		if (res < 0) {
842 			/* if other end is deleted, this might occur */
843 			goto pollout_done;
844 		}
845 
846 		have_remote_sem = true;
847 
848 		if (spair_write_avail(spair) > 0) {
849 			pfd->revents |= ZSOCK_POLLOUT;
850 			goto pollout_done;
851 		}
852 
853 		/* check to see if op was canceled */
854 		signaled = false;
855 		k_poll_signal_check(&remote->writeable, &signaled, &result);
856 		if (signaled) {
857 			/* Cannot be SPAIR_SIG_DATA, because
858 			 * spair_write_avail() would have
859 			 * returned 0
860 			 */
861 			__ASSERT(result == SPAIR_SIG_CANCEL,
862 				"invalid result %d", result);
863 			pfd->revents |= ZSOCK_POLLHUP;
864 		}
865 	}
866 
867 pollout_done:
868 
869 	if (pfd->events & ZSOCK_POLLIN) {
870 		if (sock_is_eof(spair)) {
871 			pfd->revents |= ZSOCK_POLLIN;
872 			goto pollin_done;
873 		}
874 
875 		if (spair_read_avail(spair) > 0) {
876 			pfd->revents |= ZSOCK_POLLIN;
877 			goto pollin_done;
878 		}
879 
880 		/* check to see if op was canceled */
881 		signaled = false;
882 		k_poll_signal_check(&spair->readable, &signaled, &result);
883 		if (signaled) {
884 			/* Cannot be SPAIR_SIG_DATA, because
885 			 * spair_read_avail() would have
886 			 * returned 0
887 			 */
888 			__ASSERT(result == SPAIR_SIG_CANCEL,
889 					 "invalid result %d", result);
890 			pfd->revents |= ZSOCK_POLLIN;
891 		}
892 	}
893 
894 pollin_done:
895 	res = 0;
896 
897 	(*pev)++;
898 
899 	if (remote != NULL && have_remote_sem) {
900 		k_sem_give(&remote->sem);
901 	}
902 
903 	return res;
904 }
905 
spair_ioctl(void * obj,unsigned int request,va_list args)906 static int spair_ioctl(void *obj, unsigned int request, va_list args)
907 {
908 	int res;
909 	struct zsock_pollfd *pfd;
910 	struct k_poll_event **pev;
911 	struct k_poll_event *pev_end;
912 	int flags = 0;
913 	bool have_local_sem = false;
914 	struct spair *const spair = (struct spair *)obj;
915 
916 	if (spair == NULL) {
917 		errno = EINVAL;
918 		res = -1;
919 		goto out;
920 	}
921 
922 	/* The local sem is always taken in this function. If a subsequent
923 	 * function call requires the remote sem, it must acquire and free the
924 	 * remote sem.
925 	 */
926 	res = k_sem_take(&spair->sem, K_FOREVER);
927 	__ASSERT(res == 0, "failed to take local sem: %d", res);
928 
929 	have_local_sem = true;
930 
931 	switch (request) {
932 		case F_GETFL: {
933 			if (sock_is_nonblock(spair)) {
934 				flags |= O_NONBLOCK;
935 			}
936 
937 			res = flags;
938 			goto out;
939 		}
940 
941 		case F_SETFL: {
942 			flags = va_arg(args, int);
943 
944 			if (flags & O_NONBLOCK) {
945 				spair->flags |= SPAIR_FLAG_NONBLOCK;
946 			} else {
947 				spair->flags &= ~SPAIR_FLAG_NONBLOCK;
948 			}
949 
950 			res = 0;
951 			goto out;
952 		}
953 
954 		case ZFD_IOCTL_FIONBIO: {
955 			spair->flags |= SPAIR_FLAG_NONBLOCK;
956 			res = 0;
957 			goto out;
958 		}
959 
960 		case ZFD_IOCTL_FIONREAD: {
961 			int *nbytes;
962 
963 			nbytes = va_arg(args, int *);
964 			*nbytes = spair_read_avail(spair);
965 
966 			res = 0;
967 			goto out;
968 		}
969 
970 		case ZFD_IOCTL_POLL_PREPARE: {
971 			pfd = va_arg(args, struct zsock_pollfd *);
972 			pev = va_arg(args, struct k_poll_event **);
973 			pev_end = va_arg(args, struct k_poll_event *);
974 
975 			res = zsock_poll_prepare_ctx(obj, pfd, pev, pev_end);
976 			goto out;
977 		}
978 
979 		case ZFD_IOCTL_POLL_UPDATE: {
980 			pfd = va_arg(args, struct zsock_pollfd *);
981 			pev = va_arg(args, struct k_poll_event **);
982 
983 			res = zsock_poll_update_ctx(obj, pfd, pev);
984 			goto out;
985 		}
986 
987 		default: {
988 			errno = EOPNOTSUPP;
989 			res = -1;
990 			goto out;
991 		}
992 	}
993 
994 out:
995 	if (spair != NULL && have_local_sem) {
996 		k_sem_give(&spair->sem);
997 	}
998 
999 	return res;
1000 }
1001 
spair_bind(void * obj,const struct sockaddr * addr,socklen_t addrlen)1002 static int spair_bind(void *obj, const struct sockaddr *addr,
1003 		      socklen_t addrlen)
1004 {
1005 	ARG_UNUSED(obj);
1006 	ARG_UNUSED(addr);
1007 	ARG_UNUSED(addrlen);
1008 
1009 	errno = EISCONN;
1010 	return -1;
1011 }
1012 
spair_connect(void * obj,const struct sockaddr * addr,socklen_t addrlen)1013 static int spair_connect(void *obj, const struct sockaddr *addr,
1014 			 socklen_t addrlen)
1015 {
1016 	ARG_UNUSED(obj);
1017 	ARG_UNUSED(addr);
1018 	ARG_UNUSED(addrlen);
1019 
1020 	errno = EISCONN;
1021 	return -1;
1022 }
1023 
spair_listen(void * obj,int backlog)1024 static int spair_listen(void *obj, int backlog)
1025 {
1026 	ARG_UNUSED(obj);
1027 	ARG_UNUSED(backlog);
1028 
1029 	errno = EINVAL;
1030 	return -1;
1031 }
1032 
spair_accept(void * obj,struct sockaddr * addr,socklen_t * addrlen)1033 static int spair_accept(void *obj, struct sockaddr *addr,
1034 			socklen_t *addrlen)
1035 {
1036 	ARG_UNUSED(obj);
1037 	ARG_UNUSED(addr);
1038 	ARG_UNUSED(addrlen);
1039 
1040 	errno = EOPNOTSUPP;
1041 	return -1;
1042 }
1043 
spair_sendto(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)1044 static ssize_t spair_sendto(void *obj, const void *buf, size_t len,
1045 			    int flags, const struct sockaddr *dest_addr,
1046 				 socklen_t addrlen)
1047 {
1048 	ARG_UNUSED(flags);
1049 	ARG_UNUSED(dest_addr);
1050 	ARG_UNUSED(addrlen);
1051 
1052 	return spair_write(obj, buf, len);
1053 }
1054 
spair_sendmsg(void * obj,const struct msghdr * msg,int flags)1055 static ssize_t spair_sendmsg(void *obj, const struct msghdr *msg,
1056 			     int flags)
1057 {
1058 	ARG_UNUSED(flags);
1059 
1060 	int res;
1061 	size_t len = 0;
1062 	bool is_connected;
1063 	size_t avail;
1064 	bool is_nonblock;
1065 	struct spair *const spair = (struct spair *)obj;
1066 
1067 	if (spair == NULL || msg == NULL) {
1068 		errno = EINVAL;
1069 		res = -1;
1070 		goto out;
1071 	}
1072 
1073 	is_connected = sock_is_connected(spair);
1074 	avail = is_connected ? spair_write_avail(spair) : 0;
1075 	is_nonblock = sock_is_nonblock(spair);
1076 
1077 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1078 		/* check & msg->msg_iov[i]? */
1079 		/* check & msg->msg_iov[i].iov_base? */
1080 		len += msg->msg_iov[i].iov_len;
1081 	}
1082 
1083 	if (!is_connected) {
1084 		errno = EPIPE;
1085 		res = -1;
1086 		goto out;
1087 	}
1088 
1089 	if (len == 0) {
1090 		res = 0;
1091 		goto out;
1092 	}
1093 
1094 	if (len > avail && is_nonblock) {
1095 		errno = EMSGSIZE;
1096 		res = -1;
1097 		goto out;
1098 	}
1099 
1100 	for (size_t i = 0; i < msg->msg_iovlen; ++i) {
1101 		res = spair_write(spair, msg->msg_iov[i].iov_base,
1102 			msg->msg_iov[i].iov_len);
1103 		if (res == -1) {
1104 			goto out;
1105 		}
1106 	}
1107 
1108 	res = len;
1109 
1110 out:
1111 	return res;
1112 }
1113 
spair_recvfrom(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)1114 static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len,
1115 			      int flags, struct sockaddr *src_addr,
1116 				   socklen_t *addrlen)
1117 {
1118 	(void)flags;
1119 	(void)src_addr;
1120 	(void)addrlen;
1121 
1122 	if (addrlen != NULL) {
1123 		/* Protocol (PF_UNIX) does not support addressing with connected
1124 		 * sockets and, therefore, it is unspecified behaviour to modify
1125 		 * src_addr. However, it would be ambiguous to leave addrlen
1126 		 * untouched if the user expects it to be updated. It is not
1127 		 * mentioned that modifying addrlen is unspecified. Therefore
1128 		 * we choose to eliminate ambiguity.
1129 		 *
1130 		 * Setting it to zero mimics Linux's behaviour.
1131 		 */
1132 		*addrlen = 0;
1133 	}
1134 
1135 	return spair_read(obj, buf, max_len);
1136 }
1137 
spair_getsockopt(void * obj,int level,int optname,void * optval,socklen_t * optlen)1138 static int spair_getsockopt(void *obj, int level, int optname,
1139 			    void *optval, socklen_t *optlen)
1140 {
1141 	ARG_UNUSED(obj);
1142 	ARG_UNUSED(level);
1143 	ARG_UNUSED(optname);
1144 	ARG_UNUSED(optval);
1145 	ARG_UNUSED(optlen);
1146 
1147 	errno = ENOPROTOOPT;
1148 	return -1;
1149 }
1150 
spair_setsockopt(void * obj,int level,int optname,const void * optval,socklen_t optlen)1151 static int spair_setsockopt(void *obj, int level, int optname,
1152 			    const void *optval, socklen_t optlen)
1153 {
1154 	ARG_UNUSED(obj);
1155 	ARG_UNUSED(level);
1156 	ARG_UNUSED(optname);
1157 	ARG_UNUSED(optval);
1158 	ARG_UNUSED(optlen);
1159 
1160 	errno = ENOPROTOOPT;
1161 	return -1;
1162 }
1163 
spair_close(void * obj)1164 static int spair_close(void *obj)
1165 {
1166 	struct spair *const spair = (struct spair *)obj;
1167 	int res;
1168 
1169 	res = k_sem_take(&spair->sem, K_FOREVER);
1170 	__ASSERT(res == 0, "failed to take local sem: %d", res);
1171 
1172 	/* disconnect the remote endpoint */
1173 	spair_delete(spair);
1174 
1175 	/* Note that the semaphore released already so need to do it here */
1176 
1177 	return 0;
1178 }
1179 
/* Operations table for socketpair endpoints. Its address also identifies
 * socketpair descriptors: the z_get_fd_obj() lookups in this file pass it
 * to check that a descriptor refers to a struct spair.
 */
static const struct socket_op_vtable spair_fd_op_vtable = {
	.fd_vtable = {
		.read = spair_read,
		.write = spair_write,
		.close = spair_close,
		.ioctl = spair_ioctl,
	},
	.bind = spair_bind,
	.connect = spair_connect,
	.listen = spair_listen,
	.accept = spair_accept,
	.sendto = spair_sendto,
	.sendmsg = spair_sendmsg,
	.recvfrom = spair_recvfrom,
	.getsockopt = spair_getsockopt,
	.setsockopt = spair_setsockopt,
};
1197