1 /*
2  * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 #include <errno.h>
7 #include <poll.h>
8 #include <pthread.h>
9 #include <stdbool.h>
10 #include <stdio.h>
11 #include <string.h>
12 #include <sys/socket.h>
13 #include <unistd.h>
14 
15 #ifdef __ZEPHYR__
16 #include <zephyr/kernel.h>
17 #endif
18 
/* Number of socketpairs (one worker thread each) that main() sets up and polls. */
#define NUM_SOCKETPAIRS 3
/* Echo round-trips each worker thread performs (identifier spelling kept as-is). */
#define NUM_REPITITIONS 3

/* Per-socketpair state shared between main() and one worker thread. */
struct context {
	/* [0]: polled, read and echoed by main(); [1]: used by the worker thread */
	int spair[2];
	/* worker thread running fun() on this context */
	pthread_t thread;
	/* label the worker sends and expects echoed back; assigned from names[] by setup() */
	const char *name;
};

/*
 * One label per context, indexed by setup().
 * NOTE(review): must provide at least NUM_SOCKETPAIRS entries.
 */
static const char *const names[] = {
	"Alpha",
	"Bravo",
	"Charlie",
};

#ifdef __ZEPHYR__
/* Stack size handed to pthread_attr_setstack() for each worker thread */
#define STACK_SIZE (1024)
/* One statically allocated stack per worker; setup() passes stack[i] to thread i */
K_THREAD_STACK_ARRAY_DEFINE(stack, NUM_SOCKETPAIRS, STACK_SIZE);
#endif
38 
hello(int fd,const char * name)39 static int hello(int fd, const char *name)
40 {
41 	int res;
42 	char buf[32] = {0};
43 
44 	/* check for an echo of what is written */
45 	res = write(fd, name, strlen(name));
46 	if (res < 0) {
47 		perror("write");
48 	} else if (res != strlen(name)) {
49 		printf("only wrote %d/%d bytes", res, (int)strlen(name));
50 		return -EIO;
51 	}
52 
53 	res = read(fd, buf, sizeof(buf) - 1);
54 	if (res < 0) {
55 		perror("read");
56 	} else if (res != strlen(name)) {
57 		printf("only read %d/%d bytes", res, (int)strlen(name));
58 		return -EIO;
59 	}
60 
61 	if (strncmp(buf, name, sizeof(buf)) != 0) {
62 		printf("expected %s\n", name);
63 		return -EINVAL;
64 	}
65 
66 	return 0;
67 }
68 
fun(void * arg)69 static void *fun(void *arg)
70 {
71 	struct context *ctx = (struct context *)arg;
72 	int fd = ctx->spair[1];
73 	const char *name = ctx->name;
74 
75 	for (size_t i = 0; i < NUM_REPITITIONS; ++i) {
76 		if (hello(fd, name) < 0) {
77 			break;
78 		}
79 	}
80 
81 	return NULL;
82 }
83 
fd_to_idx(int fd,const struct context * ctx,size_t n)84 static int fd_to_idx(int fd, const struct context *ctx, size_t n)
85 {
86 	int res = -1;
87 	size_t i;
88 
89 	for (i = 0; i < n; ++i) {
90 		if (ctx[i].spair[0] == fd) {
91 			res = i;
92 			break;
93 		}
94 	}
95 
96 	return res;
97 }
98 
setup(struct context * ctx,size_t n)99 static int setup(struct context *ctx, size_t n)
100 {
101 	int res;
102 #ifdef __ZEPHYR__
103 	pthread_attr_t attr;
104 	pthread_attr_t *attrp = &attr;
105 #else
106 	pthread_attr_t *attrp = NULL;
107 #endif
108 
109 	for (size_t i = 0; i < n; ++i) {
110 		ctx[i].name = (char *)names[i];
111 		res = socketpair(AF_UNIX, SOCK_STREAM, 0, ctx[i].spair);
112 		if (res < 0) {
113 			perror("socketpair");
114 			return res;
115 		}
116 
117 #ifdef __ZEPHYR__
118 		/* Zephyr requires a non-NULL attribute for pthread_create */
119 		res = pthread_attr_init(attrp);
120 		if (res != 0) {
121 			errno = res;
122 			perror("pthread_attr_init");
123 			return -res;
124 		}
125 
126 		res = pthread_attr_setstack(&attr, &stack[i], STACK_SIZE);
127 		if (res != 0) {
128 			errno = res;
129 			perror("pthread_attr_setstack");
130 			return -res;
131 		}
132 #endif
133 
134 		res = pthread_create(&ctx[i].thread, attrp, fun, &ctx[i]);
135 		if (res != 0) {
136 			errno = res;
137 			perror("pthread_create");
138 			return -res;
139 		}
140 
141 		printf("%s: socketpair: %d <=> %d\n",
142 			ctx[i].name, ctx[i].spair[0], ctx[i].spair[1]);
143 	}
144 
145 	return 0;
146 }
147 
teardown(struct context * ctx,size_t n)148 static void teardown(struct context *ctx, size_t n)
149 {
150 	void *unused;
151 
152 	for (size_t i = 0; i < n; ++i) {
153 		pthread_join(ctx[i].thread, &unused);
154 
155 		close(ctx[i].spair[0]);
156 		ctx[i].spair[0] = -1;
157 
158 		close(ctx[i].spair[1]);
159 		ctx[i].spair[1] = -1;
160 	}
161 }
162 
setup_poll(const struct context * ctx,struct pollfd * fds,size_t n)163 static void setup_poll(const struct context *ctx, struct pollfd *fds, size_t n)
164 {
165 	for (size_t i = 0; i < n; ++i) {
166 		fds[i].fd = ctx[i].spair[0];
167 		fds[i].events = POLLIN;
168 		fds[i].revents = 0;
169 	}
170 }
171 
handle_poll_events(const struct context * ctx,struct pollfd * fds,size_t n,size_t n_events)172 static int handle_poll_events(const struct context *ctx, struct pollfd *fds, size_t n,
173 			      size_t n_events)
174 {
175 	int res;
176 	int idx;
177 	char buf[32];
178 	size_t event = 0;
179 
180 	for (size_t i = 0; event < n_events && i < n; ++i) {
181 		idx = fd_to_idx(fds[i].fd, ctx, n);
182 		if (idx < 0) {
183 			printf("failed to find fd %d in any active context\n", fds[i].fd);
184 			continue;
185 		}
186 
187 		if ((fds[i].revents & POLLERR) != 0) {
188 			printf("fd: %d: error\n", fds[i].fd);
189 			return -EIO;
190 		} else if ((fds[i].revents & POLLIN) != 0) {
191 			memset(buf, '\0', sizeof(buf));
192 
193 			/* echo back the same thing that was read */
194 			res = read(fds[i].fd, buf, sizeof(buf));
195 			if (res < 0) {
196 				perror("read");
197 				return -errno;
198 			}
199 
200 			printf("main: read '%s' on fd %d\n", buf, fds[i].fd);
201 			if (strncmp(ctx[idx].name, buf, sizeof(buf)) != 0) {
202 				printf("main: expected: '%s' actual: '%s'\n", ctx[idx].name, buf);
203 				return -EINVAL;
204 			}
205 
206 			res = write(fds[i].fd, buf, res);
207 			if (res < 0) {
208 				perror("write");
209 				return -errno;
210 			}
211 
212 			++event;
213 		}
214 	}
215 
216 	if (event != n_events) {
217 		printf("main: unhandled events remaining\n");
218 		return -EINVAL;
219 	}
220 
221 	return n_events;
222 }
223 
main(void)224 int main(void)
225 {
226 	int res;
227 	struct context ctx[NUM_SOCKETPAIRS] = {};
228 	struct pollfd fds[NUM_SOCKETPAIRS] = {};
229 
230 	printf("setting-up\n");
231 	res = setup(ctx, NUM_SOCKETPAIRS);
232 	if (res < 0) {
233 		goto out;
234 	}
235 
236 	for (size_t n_events = NUM_SOCKETPAIRS * NUM_REPITITIONS; n_events > 0; n_events -= res) {
237 
238 		setup_poll(ctx, fds, NUM_SOCKETPAIRS);
239 		res = poll(fds, NUM_SOCKETPAIRS, -1);
240 		if (res < 0) {
241 			perror("poll");
242 			goto out;
243 		}
244 
245 		res = handle_poll_events(ctx, fds, NUM_SOCKETPAIRS, res);
246 		if (res < 0) {
247 			goto out;
248 		}
249 	}
250 
251 	res = 0;
252 
253 out:
254 	printf("tearing-down\n");
255 	teardown(ctx, NUM_SOCKETPAIRS);
256 
257 	printf("%s\n", res == 0 ? "SUCCESS" : "FAILURE");
258 
259 	return res;
260 }
261