1 /*
2  * Copyright (c) 2017 Linaro Limited
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <stdio.h>
8 #include <stdbool.h>
9 #include <errno.h>
10 #include <stdlib.h>
11 
12 #if !defined(__ZEPHYR__)
13 
14 #include <netinet/in.h>
15 #include <sys/socket.h>
16 #include <arpa/inet.h>
17 #include <unistd.h>
18 #include <fcntl.h>
19 #include <poll.h>
20 
21 #define USE_IPV6
22 
23 #else
24 
25 #include <zephyr/posix/fcntl.h>
26 #include <zephyr/net/socket.h>
27 #include <zephyr/kernel.h>
28 
29 #include "net_sample_common.h"
30 
31 #ifdef CONFIG_NET_IPV6
32 #define USE_IPV6
33 #endif
34 
35 #endif
36 
/* For Zephyr, keep max number of fd's in sync with max poll() capacity */
#ifdef CONFIG_ZVFS_POLL_MAX
#define NUM_FDS CONFIG_ZVFS_POLL_MAX
#else
#define NUM_FDS 5
#endif

#define BIND_PORT 4242

/* Poll set shared by the listening sockets and the client connections.
 * Listening sockets occupy the first slots, so the number of simultaneous
 * client connections is NUM_FDS minus the number of listening sockets
 * (1 or 2, depending on whether IPv4 and/or IPv6 servers are built in).
 */
struct pollfd pollfds[NUM_FDS];
/* Number of pollfds[] slots used so far; it only grows. Slots released
 * by pollfds_del() stay counted but have fd == -1.
 */
int pollnum;

/* Print "Error: <msg>" and terminate the program with exit status 1.
 * Wrapped in do { } while (0) so that it expands to a single statement
 * and is safe inside an unbraced if/else.
 */
#define fatal(msg, ...) do { \
		printf("Error: " msg "\n", ##__VA_ARGS__); \
		exit(1); \
	} while (0)
54 
55 
setblocking(int fd,bool val)56 static void setblocking(int fd, bool val)
57 {
58 	int fl, res;
59 
60 	fl = fcntl(fd, F_GETFL, 0);
61 	if (fl == -1) {
62 		fatal("fcntl(F_GETFL): %d", errno);
63 	}
64 
65 	if (val) {
66 		fl &= ~O_NONBLOCK;
67 	} else {
68 		fl |= O_NONBLOCK;
69 	}
70 
71 	res = fcntl(fd, F_SETFL, fl);
72 	if (fl == -1) {
73 		fatal("fcntl(F_SETFL): %d", errno);
74 	}
75 }
76 
pollfds_add(int fd)77 int pollfds_add(int fd)
78 {
79 	int i;
80 	if (pollnum < NUM_FDS) {
81 		i = pollnum++;
82 	} else {
83 		for (i = 0; i < NUM_FDS; i++) {
84 			if (pollfds[i].fd < 0) {
85 				goto found;
86 			}
87 		}
88 
89 		return -1;
90 	}
91 
92 found:
93 	pollfds[i].fd = fd;
94 	pollfds[i].events = POLLIN;
95 
96 	return 0;
97 }
98 
pollfds_del(int fd)99 void pollfds_del(int fd)
100 {
101 	for (int i = 0; i < pollnum; i++) {
102 		if (pollfds[i].fd == fd) {
103 			pollfds[i].fd = -1;
104 			break;
105 		}
106 	}
107 }
108 
main(void)109 int main(void)
110 {
111 	int res;
112 	static int counter;
113 	int num_servs = 0;
114 #if !defined(USE_IPV6) || !(CONFIG_SOC_SERIES_CC32XX)
115 	int serv4;
116 	struct sockaddr_in bind_addr4 = {
117 		.sin_family = AF_INET,
118 		.sin_port = htons(BIND_PORT),
119 		.sin_addr = {
120 			.s_addr = htonl(INADDR_ANY),
121 		},
122 	};
123 #endif
124 #ifdef USE_IPV6
125 	int serv6;
126 	struct sockaddr_in6 bind_addr6 = {
127 		.sin6_family = AF_INET6,
128 		.sin6_port = htons(BIND_PORT),
129 		.sin6_addr = IN6ADDR_ANY_INIT,
130 	};
131 #endif
132 
133 	wait_for_network();
134 
135 #if !defined(USE_IPV6) || !(CONFIG_SOC_SERIES_CC32XX)
136 	serv4 = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
137 	if (serv4 < 0) {
138 		printf("error: socket: %d\n", errno);
139 		exit(1);
140 	}
141 
142 	res = bind(serv4, (struct sockaddr *)&bind_addr4, sizeof(bind_addr4));
143 	if (res == -1) {
144 		printf("Cannot bind IPv4, errno: %d\n", errno);
145 	}
146 	num_servs++;
147 
148 	setblocking(serv4, false);
149 	listen(serv4, 5);
150 	pollfds_add(serv4);
151 #endif
152 
153 #ifdef USE_IPV6
154 	serv6 = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
155 	if (serv6 < 0) {
156 		printf("error: socket(AF_INET6): %d\n", errno);
157 		exit(1);
158 	}
159 	#ifdef IPV6_V6ONLY
160 	/* For Linux, we need to make socket IPv6-only to bind it to the
161 	 * same port as IPv4 socket above.
162 	 */
163 	int TRUE = 1;
164 	res = setsockopt(serv6, IPPROTO_IPV6, IPV6_V6ONLY, &TRUE, sizeof(TRUE));
165 	if (res < 0) {
166 		printf("error: setsockopt: %d\n", errno);
167 		exit(1);
168 	}
169 	#endif
170 	res = bind(serv6, (struct sockaddr *)&bind_addr6, sizeof(bind_addr6));
171 	if (res == -1) {
172 		printf("Cannot bind IPv6, errno: %d\n", errno);
173 	}
174 	num_servs++;
175 
176 	setblocking(serv6, false);
177 	listen(serv6, 5);
178 	pollfds_add(serv6);
179 #endif
180 
181 	printf("Asynchronous TCP echo server waits for connections on "
182 	       "port %d...\n", BIND_PORT);
183 
184 	while (1) {
185 		struct sockaddr_storage client_addr;
186 		socklen_t client_addr_len = sizeof(client_addr);
187 		char addr_str[32];
188 
189 		res = poll(pollfds, pollnum, -1);
190 		if (res == -1) {
191 			printf("poll error: %d\n", errno);
192 			continue;
193 		}
194 
195 		for (int i = 0; i < pollnum; i++) {
196 			if (!(pollfds[i].revents & POLLIN)) {
197 				continue;
198 			}
199 			int fd = pollfds[i].fd;
200 			if (i < num_servs) {
201 				/* If server socket */
202 				int client = accept(fd, (struct sockaddr *)&client_addr,
203 						    &client_addr_len);
204 				void *addr = &((struct sockaddr_in *)&client_addr)->sin_addr;
205 
206 				if (client < 0) {
207 					printf("error: accept: %d\n", errno);
208 					continue;
209 				}
210 				inet_ntop(client_addr.ss_family, addr,
211 					  addr_str, sizeof(addr_str));
212 				printf("Connection #%d from %s fd=%d\n", counter++,
213 				       addr_str, client);
214 				if (pollfds_add(client) < 0) {
215 					static char msg[] = "Too many connections\n";
216 
217 					res = send(client, msg, sizeof(msg) - 1, 0);
218 					if (res < 0) {
219 						printf("error: send: %d\n", errno);
220 					}
221 					close(client);
222 				} else {
223 					setblocking(client, false);
224 				}
225 			} else {
226 				char buf[128];
227 				int len = recv(fd, buf, sizeof(buf), 0);
228 				if (len <= 0) {
229 					if (len < 0) {
230 						printf("error: recv: %d\n", errno);
231 					}
232 error:
233 					pollfds_del(fd);
234 					close(fd);
235 					printf("Connection fd=%d closed\n", fd);
236 				} else {
237 					int out_len;
238 					const char *p;
239 					/* We implement semi-async server,
240 					 * where reads are async, but writes
241 					 * *can* be sync (blocking). Note that
242 					 * in majority of cases they expected
243 					 * to not block, but to be robust, we
244 					 * handle all possibilities.
245 					 */
246 					setblocking(fd, true);
247 
248 					for (p = buf; len; len -= out_len) {
249 						out_len = send(fd, p, len, 0);
250 						if (out_len < 0) {
251 							printf("error: "
252 							       "send: %d\n",
253 							       errno);
254 							goto error;
255 						}
256 						p += out_len;
257 					}
258 
259 					setblocking(fd, false);
260 				}
261 			}
262 		}
263 	}
264 	return 0;
265 }
266