1 /*
2  * Copyright (c) 2017 Linaro Limited
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <stdio.h>
8 #include <stdbool.h>
9 #include <errno.h>
10 #include <stdlib.h>
11 
12 #if !defined(__ZEPHYR__)
13 
14 #include <netinet/in.h>
15 #include <sys/socket.h>
16 #include <arpa/inet.h>
17 #include <unistd.h>
18 #include <fcntl.h>
19 #include <poll.h>
20 
21 #define USE_IPV6
22 
23 #else
24 
25 #include <zephyr/posix/fcntl.h>
26 #include <zephyr/posix/netinet/in.h>
27 #include <zephyr/posix/sys/socket.h>
28 #include <zephyr/posix/arpa/inet.h>
29 #include <zephyr/posix/unistd.h>
30 #include <zephyr/posix/poll.h>
31 
32 #include <zephyr/net/socket.h>
33 #include <zephyr/kernel.h>
34 
35 #include "net_sample_common.h"
36 
37 #ifdef CONFIG_NET_IPV6
38 #define USE_IPV6
39 #endif
40 
41 #endif
42 
43 /* For Zephyr, keep max number of fd's in sync with max poll() capacity */
44 #ifdef CONFIG_ZVFS_POLL_MAX
45 #define NUM_FDS CONFIG_ZVFS_POLL_MAX
46 #else
47 #define NUM_FDS 5
48 #endif
49 
50 #define BIND_PORT 4242
51 
/* Number of simultaneous client connections will be NUM_FDS minus the
 * number of listening sockets (at most 2).
 */
53 struct pollfd pollfds[NUM_FDS];
54 int pollnum;
55 
/* Print a printf-style error message prefixed with "Error: " and followed
 * by a newline, then terminate the program with exit status 1.
 *
 * The body is wrapped in do { } while (0) so that an invocation followed by
 * a semicolon is exactly one statement; this keeps the macro usable in an
 * unbraced if/else.
 */
#define fatal(msg, ...) do { \
		printf("Error: " msg "\n", ##__VA_ARGS__); \
		exit(1); \
	} while (0)
60 
61 
/* Switch a file descriptor between blocking and non-blocking mode.
 *
 * fd:  descriptor whose file status flags are modified.
 * val: true to make fd blocking (clears O_NONBLOCK), false to make it
 *      non-blocking (sets O_NONBLOCK).
 *
 * Any fcntl() failure is reported through fatal(), which terminates the
 * program.
 */
static void setblocking(int fd, bool val)
{
	int fl, res;

	fl = fcntl(fd, F_GETFL, 0);
	if (fl == -1) {
		fatal("fcntl(F_GETFL): %d", errno);
	}

	if (val) {
		fl &= ~O_NONBLOCK;
	} else {
		fl |= O_NONBLOCK;
	}

	res = fcntl(fd, F_SETFL, fl);
	/* Check the result of F_SETFL itself; fl is already known to be valid */
	if (res == -1) {
		fatal("fcntl(F_SETFL): %d", errno);
	}
}
82 
pollfds_add(int fd)83 int pollfds_add(int fd)
84 {
85 	int i;
86 	if (pollnum < NUM_FDS) {
87 		i = pollnum++;
88 	} else {
89 		for (i = 0; i < NUM_FDS; i++) {
90 			if (pollfds[i].fd < 0) {
91 				goto found;
92 			}
93 		}
94 
95 		return -1;
96 	}
97 
98 found:
99 	pollfds[i].fd = fd;
100 	pollfds[i].events = POLLIN;
101 
102 	return 0;
103 }
104 
pollfds_del(int fd)105 void pollfds_del(int fd)
106 {
107 	for (int i = 0; i < pollnum; i++) {
108 		if (pollfds[i].fd == fd) {
109 			pollfds[i].fd = -1;
110 			break;
111 		}
112 	}
113 }
114 
main(void)115 int main(void)
116 {
117 	int res;
118 	static int counter;
119 	int num_servs = 0;
120 #if !defined(USE_IPV6) || !(CONFIG_SOC_SERIES_CC32XX)
121 	int serv4;
122 	struct sockaddr_in bind_addr4 = {
123 		.sin_family = AF_INET,
124 		.sin_port = htons(BIND_PORT),
125 		.sin_addr = {
126 			.s_addr = htonl(INADDR_ANY),
127 		},
128 	};
129 #endif
130 #ifdef USE_IPV6
131 	int serv6;
132 	struct sockaddr_in6 bind_addr6 = {
133 		.sin6_family = AF_INET6,
134 		.sin6_port = htons(BIND_PORT),
135 		.sin6_addr = IN6ADDR_ANY_INIT,
136 	};
137 #endif
138 
139 	wait_for_network();
140 
141 #if !defined(USE_IPV6) || !(CONFIG_SOC_SERIES_CC32XX)
142 	serv4 = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
143 	if (serv4 < 0) {
144 		printf("error: socket: %d\n", errno);
145 		exit(1);
146 	}
147 
148 	res = bind(serv4, (struct sockaddr *)&bind_addr4, sizeof(bind_addr4));
149 	if (res == -1) {
150 		printf("Cannot bind IPv4, errno: %d\n", errno);
151 	}
152 	num_servs++;
153 
154 	setblocking(serv4, false);
155 	listen(serv4, 5);
156 	pollfds_add(serv4);
157 #endif
158 
159 #ifdef USE_IPV6
160 	serv6 = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
161 	if (serv6 < 0) {
162 		printf("error: socket(AF_INET6): %d\n", errno);
163 		exit(1);
164 	}
165 	#ifdef IPV6_V6ONLY
166 	/* For Linux, we need to make socket IPv6-only to bind it to the
167 	 * same port as IPv4 socket above.
168 	 */
169 	int TRUE = 1;
170 	res = setsockopt(serv6, IPPROTO_IPV6, IPV6_V6ONLY, &TRUE, sizeof(TRUE));
171 	if (res < 0) {
172 		printf("error: setsockopt: %d\n", errno);
173 		exit(1);
174 	}
175 	#endif
176 	res = bind(serv6, (struct sockaddr *)&bind_addr6, sizeof(bind_addr6));
177 	if (res == -1) {
178 		printf("Cannot bind IPv6, errno: %d\n", errno);
179 	}
180 	num_servs++;
181 
182 	setblocking(serv6, false);
183 	listen(serv6, 5);
184 	pollfds_add(serv6);
185 #endif
186 
187 	printf("Asynchronous TCP echo server waits for connections on "
188 	       "port %d...\n", BIND_PORT);
189 
190 	while (1) {
191 		struct sockaddr_storage client_addr;
192 		socklen_t client_addr_len = sizeof(client_addr);
193 		char addr_str[32];
194 
195 		res = poll(pollfds, pollnum, -1);
196 		if (res == -1) {
197 			printf("poll error: %d\n", errno);
198 			continue;
199 		}
200 
201 		for (int i = 0; i < pollnum; i++) {
202 			if (!(pollfds[i].revents & POLLIN)) {
203 				continue;
204 			}
205 			int fd = pollfds[i].fd;
206 			if (i < num_servs) {
207 				/* If server socket */
208 				int client = accept(fd, (struct sockaddr *)&client_addr,
209 						    &client_addr_len);
210 				void *addr = &((struct sockaddr_in *)&client_addr)->sin_addr;
211 
212 				if (client < 0) {
213 					printf("error: accept: %d\n", errno);
214 					continue;
215 				}
216 				inet_ntop(client_addr.ss_family, addr,
217 					  addr_str, sizeof(addr_str));
218 				printf("Connection #%d from %s fd=%d\n", counter++,
219 				       addr_str, client);
220 				if (pollfds_add(client) < 0) {
221 					static char msg[] = "Too many connections\n";
222 
223 					res = send(client, msg, sizeof(msg) - 1, 0);
224 					if (res < 0) {
225 						printf("error: send: %d\n", errno);
226 					}
227 					close(client);
228 				} else {
229 					setblocking(client, false);
230 				}
231 			} else {
232 				char buf[128];
233 				int len = recv(fd, buf, sizeof(buf), 0);
234 				if (len <= 0) {
235 					if (len < 0) {
236 						printf("error: recv: %d\n", errno);
237 					}
238 error:
239 					pollfds_del(fd);
240 					close(fd);
241 					printf("Connection fd=%d closed\n", fd);
242 				} else {
243 					int out_len;
244 					const char *p;
245 					/* We implement semi-async server,
246 					 * where reads are async, but writes
247 					 * *can* be sync (blocking). Note that
248 					 * in majority of cases they expected
249 					 * to not block, but to be robust, we
250 					 * handle all possibilities.
251 					 */
252 					setblocking(fd, true);
253 
254 					for (p = buf; len; len -= out_len) {
255 						out_len = send(fd, p, len, 0);
256 						if (out_len < 0) {
257 							printf("error: "
258 							       "send: %d\n",
259 							       errno);
260 							goto error;
261 						}
262 						p += out_len;
263 					}
264 
265 					setblocking(fd, false);
266 				}
267 			}
268 		}
269 	}
270 	return 0;
271 }
272