1 /*
2  * Copyright (c) 2019 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_websocket_client_sample, LOG_LEVEL_DBG);
9 
10 #include <zephyr/posix/sys/socket.h>
11 #include <zephyr/posix/arpa/inet.h>
12 #include <zephyr/posix/unistd.h>
13 
14 #include <zephyr/misc/lorem_ipsum.h>
15 #include <zephyr/net/net_ip.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/tls_credentials.h>
18 #include <zephyr/net/websocket.h>
19 #include <zephyr/random/random.h>
20 #include <zephyr/shell/shell.h>
21 
22 #include "ca_certificate.h"
23 
24 #define SERVER_PORT 9001
25 
26 #if defined(CONFIG_NET_CONFIG_PEER_IPV6_ADDR)
27 #define SERVER_ADDR6  CONFIG_NET_CONFIG_PEER_IPV6_ADDR
28 #else
29 #define SERVER_ADDR6 ""
30 #endif
31 
32 #if defined(CONFIG_NET_CONFIG_PEER_IPV4_ADDR)
33 #define SERVER_ADDR4  CONFIG_NET_CONFIG_PEER_IPV4_ADDR
34 #else
35 #define SERVER_ADDR4 ""
36 #endif
37 
38 static const char lorem_ipsum[] = LOREM_IPSUM;
39 
40 #define MAX_RECV_BUF_LEN (sizeof(lorem_ipsum) - 1)
41 
42 const int ipsum_len = MAX_RECV_BUF_LEN;
43 
44 static uint8_t recv_buf_ipv4[MAX_RECV_BUF_LEN];
45 static uint8_t recv_buf_ipv6[MAX_RECV_BUF_LEN];
46 
47 /* We need to allocate bigger buffer for the websocket data we receive so that
48  * the websocket header fits into it.
49  */
50 #define EXTRA_BUF_SPACE 30
51 
52 static uint8_t temp_recv_buf_ipv4[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
53 static uint8_t temp_recv_buf_ipv6[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
54 
/*
 * Fill in @addr with @server:@port and create a TCP socket for @family.
 *
 * When CONFIG_NET_SOCKETS_SOCKOPT_TLS is enabled the socket is a TLS 1.2
 * socket configured with CA_CERTIFICATE_TAG and TLS_PEER_HOSTNAME; otherwise
 * it is a plain TCP socket. The socket is not connected here.
 *
 * @param family    AF_INET, or anything else which is treated as AF_INET6.
 * @param server    Textual server address, parsed with inet_pton().
 * @param port      Server port in host byte order.
 * @param sock      Out: the new socket, or -1 on failure.
 * @param addr      Out: server address (cleared before being filled in).
 * @param addr_len  Size of the storage behind @addr.
 *
 * @return 0 on success, a negative errno value on failure.
 */
static int setup_socket(sa_family_t family, const char *server, int port,
			int *sock, struct sockaddr *addr, socklen_t addr_len)
{
	const char *family_str = family == AF_INET ? "IPv4" : "IPv6";
	int ret;

	*sock = -1;
	memset(addr, 0, addr_len);

	if (family == AF_INET) {
		net_sin(addr)->sin_family = AF_INET;
		net_sin(addr)->sin_port = htons(port);
		ret = inet_pton(family, server, &net_sin(addr)->sin_addr);
	} else {
		net_sin6(addr)->sin6_family = AF_INET6;
		net_sin6(addr)->sin6_port = htons(port);
		ret = inet_pton(family, server, &net_sin6(addr)->sin6_addr);
	}

	/* inet_pton() returns 1 only on success. An empty or malformed
	 * address would otherwise silently leave the unspecified address.
	 */
	if (ret != 1) {
		LOG_ERR("Invalid %s server address \"%s\"", family_str, server);
		return -EINVAL;
	}

	if (IS_ENABLED(CONFIG_NET_SOCKETS_SOCKOPT_TLS)) {
		sec_tag_t sec_tag_list[] = {
			CA_CERTIFICATE_TAG,
		};

		*sock = socket(family, SOCK_STREAM, IPPROTO_TLS_1_2);
		if (*sock >= 0) {
			ret = setsockopt(*sock, SOL_TLS, TLS_SEC_TAG_LIST,
					 sec_tag_list, sizeof(sec_tag_list));
			if (ret < 0) {
				/* Capture errno before logging can touch it. */
				ret = -errno;
				LOG_ERR("Failed to set %s secure option (%d)",
					family_str, ret);
				goto fail;
			}

			ret = setsockopt(*sock, SOL_TLS, TLS_HOSTNAME,
					 TLS_PEER_HOSTNAME,
					 sizeof(TLS_PEER_HOSTNAME));
			if (ret < 0) {
				ret = -errno;
				LOG_ERR("Failed to set %s TLS_HOSTNAME "
					"option (%d)", family_str, ret);
				goto fail;
			}
		}
	} else {
		*sock = socket(family, SOCK_STREAM, IPPROTO_TCP);
	}

	if (*sock < 0) {
		ret = -errno;
		LOG_ERR("Failed to create %s HTTP socket (%d)", family_str,
			ret);
		*sock = -1;
		return ret;
	}

	return 0;

fail:
	if (*sock >= 0) {
		close(*sock);
		*sock = -1;
	}

	return ret;
}
118 
/*
 * Create a socket with setup_socket() and connect it to @server:@port.
 *
 * On any failure the socket is closed and *sock is set to -1, so callers
 * never end up with an open but unconnected socket.
 *
 * @return 0 on success; -1 if the socket could not be set up; a negative
 *         errno value if connect() failed.
 */
static int connect_socket(sa_family_t family, const char *server, int port,
			  int *sock, struct sockaddr *addr, socklen_t addr_len)
{
	int ret;

	ret = setup_socket(family, server, port, sock, addr, addr_len);
	if (ret < 0 || *sock < 0) {
		return -1;
	}

	ret = connect(*sock, addr, addr_len);
	if (ret < 0) {
		/* Capture errno before logging can touch it. */
		ret = -errno;
		LOG_ERR("Cannot connect to %s remote (%d)",
			family == AF_INET ? "IPv4" : "IPv6",
			ret);
		close(*sock);
		*sock = -1;
	}

	return ret;
}
139 
/*
 * Connection callback, registered as req.cb in main().
 *
 * @user_data is presumably the label string ("IPv4"/"IPv6") passed as the
 * last argument of websocket_connect() in main(); it is only logged.
 *
 * @return Always 0.
 */
static int connect_cb(int sock, struct http_request *req, void *user_data)
{
	const char *label = user_data;

	(void)req;

	LOG_INF("Websocket %d for %s connected.", sock, label);

	return 0;
}
146 
/*
 * Pick a random message length in the range [1, max_len - 1].
 *
 * The result is strictly below @max_len, so the caller can append one extra
 * byte ('\n') and still fit in a buffer of @max_len bytes.
 *
 * When max_len < 2 that range is empty, and 0 is returned. This avoids the
 * modulo-by-zero (max_len == 0) and the endless retry loop (max_len == 1)
 * of a reject-zero approach.
 */
static size_t how_much_to_send(size_t max_len)
{
	if (max_len < 2U) {
		return 0U;
	}

	return 1U + (size_t)(sys_rand32_get() % (max_len - 1U));
}
157 
sendall_with_ws_api(int sock,const void * buf,size_t len)158 static ssize_t sendall_with_ws_api(int sock, const void *buf, size_t len)
159 {
160 	return websocket_send_msg(sock, buf, len, WEBSOCKET_OPCODE_DATA_TEXT,
161 				  true, true, SYS_FOREVER_MS);
162 }
163 
/*
 * Send all @len bytes of @buf on @sock with plain send(), retrying after
 * short writes.
 *
 * @return @len on success (0 when @len is 0). If send() returns 0 or -1,
 *         that value is passed through (errno set by send() for -1). Partial
 *         progress before such a failure is not reported; the caller treats
 *         any return <= 0 as a failed send.
 */
static ssize_t sendall_with_bsd_api(int sock, const void *buf, size_t len)
{
	const uint8_t *data = buf;
	size_t sent = 0;

	while (sent < len) {
		ssize_t ret = send(sock, data + sent, len - sent, 0);

		if (ret <= 0) {
			return ret;
		}

		sent += (size_t)ret;
	}

	return (ssize_t)sent;
}
168 
recv_data_wso_api(int sock,size_t amount,uint8_t * buf,size_t buf_len,const char * proto)169 static void recv_data_wso_api(int sock, size_t amount, uint8_t *buf,
170 			      size_t buf_len, const char *proto)
171 {
172 	uint64_t remaining = ULLONG_MAX;
173 	int total_read;
174 	uint32_t message_type;
175 	int ret, read_pos;
176 
177 	read_pos = 0;
178 	total_read = 0;
179 
180 	while (remaining > 0) {
181 		ret = websocket_recv_msg(sock, buf + read_pos,
182 					 buf_len - read_pos,
183 					 &message_type,
184 					 &remaining,
185 					 0);
186 		if (ret < 0) {
187 			if (ret == -EAGAIN) {
188 				k_sleep(K_MSEC(50));
189 				continue;
190 			}
191 
192 			LOG_DBG("%s connection closed while "
193 				"waiting (%d/%d)", proto, ret, errno);
194 			break;
195 		}
196 
197 		read_pos += ret;
198 		total_read += ret;
199 	}
200 
201 	if (remaining != 0 || total_read != amount ||
202 	    /* Do not check the final \n at the end of the msg */
203 	    memcmp(lorem_ipsum, buf, amount - 1) != 0) {
204 		LOG_ERR("%s data recv failure %zd/%d bytes (remaining %" PRId64 ")",
205 			proto, amount, total_read, remaining);
206 		LOG_HEXDUMP_DBG(buf, total_read, "received ws buf");
207 		LOG_HEXDUMP_DBG(lorem_ipsum, total_read, "sent ws buf");
208 	} else {
209 		LOG_DBG("%s recv %d bytes", proto, total_read);
210 	}
211 }
212 
recv_data_bsd_api(int sock,size_t amount,uint8_t * buf,size_t buf_len,const char * proto)213 static void recv_data_bsd_api(int sock, size_t amount, uint8_t *buf,
214 			      size_t buf_len, const char *proto)
215 {
216 	int remaining;
217 	int ret, read_pos;
218 
219 	remaining = amount;
220 	read_pos = 0;
221 
222 	while (remaining > 0) {
223 		ret = recv(sock, buf + read_pos, buf_len - read_pos, 0);
224 		if (ret <= 0) {
225 			if (errno == EAGAIN || errno == ETIMEDOUT) {
226 				k_sleep(K_MSEC(50));
227 				continue;
228 			}
229 
230 			LOG_DBG("%s connection closed while "
231 				"waiting (%d/%d)", proto, ret, errno);
232 			break;
233 		}
234 
235 		read_pos += ret;
236 		remaining -= ret;
237 	}
238 
239 	if (remaining != 0 ||
240 	    /* Do not check the final \n at the end of the msg */
241 	    memcmp(lorem_ipsum, buf, amount - 1) != 0) {
242 		LOG_ERR("%s data recv failure %zd/%d bytes (remaining %d)",
243 			proto, amount, read_pos, remaining);
244 		LOG_HEXDUMP_DBG(buf, read_pos, "received bsd buf");
245 		LOG_HEXDUMP_DBG(lorem_ipsum, read_pos, "sent bsd buf");
246 	} else {
247 		LOG_DBG("%s recv %d bytes", proto, read_pos);
248 	}
249 }
250 
send_and_wait_msg(int sock,size_t amount,const char * proto,uint8_t * buf,size_t buf_len)251 static bool send_and_wait_msg(int sock, size_t amount, const char *proto,
252 			      uint8_t *buf, size_t buf_len)
253 {
254 	static int count;
255 	int ret;
256 
257 	if (sock < 0) {
258 		return true;
259 	}
260 
261 	/* Terminate the sent data with \n so that we can use the
262 	 *      websocketd --port=9001 cat
263 	 * command in server side.
264 	 */
265 	memcpy(buf, lorem_ipsum, amount);
266 	buf[amount] = '\n';
267 
268 	/* Send every 2nd message using dedicated websocket API and generic
269 	 * BSD socket API. Real applications would not work like this but here
270 	 * we want to test both APIs. We also need to send the \n so add it
271 	 * here to amount variable.
272 	 */
273 	if (count % 2) {
274 		ret = sendall_with_ws_api(sock, buf, amount + 1);
275 	} else {
276 		ret = sendall_with_bsd_api(sock, buf, amount + 1);
277 	}
278 
279 	if (ret <= 0) {
280 		if (ret < 0) {
281 			LOG_ERR("%s failed to send data using %s (%d)", proto,
282 				(count % 2) ? "ws API" : "socket API", ret);
283 		} else {
284 			LOG_DBG("%s connection closed", proto);
285 		}
286 
287 		return false;
288 	} else {
289 		LOG_DBG("%s sent %d bytes", proto, ret);
290 	}
291 
292 	if (count % 2) {
293 		recv_data_wso_api(sock, amount + 1, buf, buf_len, proto);
294 	} else {
295 		recv_data_bsd_api(sock, amount + 1, buf, buf_len, proto);
296 	}
297 
298 	count++;
299 
300 	return true;
301 }
302 
main(void)303 int main(void)
304 {
305 	/* Just an example how to set extra headers */
306 	const char *extra_headers[] = {
307 		"Origin: http://foobar\r\n",
308 		NULL
309 	};
310 	int sock4 = -1, sock6 = -1;
311 	int websock4 = -1, websock6 = -1;
312 	int32_t timeout = 3 * MSEC_PER_SEC;
313 	struct sockaddr_in6 addr6;
314 	struct sockaddr_in addr4;
315 	size_t amount;
316 	int ret;
317 
318 	if (IS_ENABLED(CONFIG_NET_SOCKETS_SOCKOPT_TLS)) {
319 		ret = tls_credential_add(CA_CERTIFICATE_TAG,
320 					 TLS_CREDENTIAL_CA_CERTIFICATE,
321 					 ca_certificate,
322 					 sizeof(ca_certificate));
323 		if (ret < 0) {
324 			LOG_ERR("Failed to register public certificate: %d",
325 				ret);
326 			k_sleep(K_FOREVER);
327 		}
328 	}
329 
330 	if (IS_ENABLED(CONFIG_NET_IPV4)) {
331 		(void)connect_socket(AF_INET, SERVER_ADDR4, SERVER_PORT,
332 				     &sock4, (struct sockaddr *)&addr4,
333 				     sizeof(addr4));
334 	}
335 
336 	if (IS_ENABLED(CONFIG_NET_IPV6)) {
337 		(void)connect_socket(AF_INET6, SERVER_ADDR6, SERVER_PORT,
338 				     &sock6, (struct sockaddr *)&addr6,
339 				     sizeof(addr6));
340 	}
341 
342 	if (sock4 < 0 && sock6 < 0) {
343 		LOG_ERR("Cannot create HTTP connection.");
344 		k_sleep(K_FOREVER);
345 	}
346 
347 	if (sock4 >= 0 && IS_ENABLED(CONFIG_NET_IPV4)) {
348 		struct websocket_request req;
349 
350 		memset(&req, 0, sizeof(req));
351 
352 		req.host = SERVER_ADDR4;
353 		req.url = "/";
354 		req.optional_headers = extra_headers;
355 		req.cb = connect_cb;
356 		req.tmp_buf = temp_recv_buf_ipv4;
357 		req.tmp_buf_len = sizeof(temp_recv_buf_ipv4);
358 
359 		websock4 = websocket_connect(sock4, &req, timeout, "IPv4");
360 		if (websock4 < 0) {
361 			LOG_ERR("Cannot connect to %s:%d", SERVER_ADDR4,
362 				SERVER_PORT);
363 			close(sock4);
364 		}
365 	}
366 
367 	if (sock6 >= 0 && IS_ENABLED(CONFIG_NET_IPV6)) {
368 		struct websocket_request req;
369 
370 		memset(&req, 0, sizeof(req));
371 
372 		req.host = SERVER_ADDR6;
373 		req.url = "/";
374 		req.optional_headers = extra_headers;
375 		req.cb = connect_cb;
376 		req.tmp_buf = temp_recv_buf_ipv6;
377 		req.tmp_buf_len = sizeof(temp_recv_buf_ipv6);
378 
379 		websock6 = websocket_connect(sock6, &req, timeout, "IPv6");
380 		if (websock6 < 0) {
381 			LOG_ERR("Cannot connect to [%s]:%d", SERVER_ADDR6,
382 				SERVER_PORT);
383 			close(sock6);
384 		}
385 	}
386 
387 	if (websock4 < 0 && websock6 < 0) {
388 		LOG_ERR("No IPv4 or IPv6 connectivity");
389 		k_sleep(K_FOREVER);
390 	}
391 
392 	LOG_INF("Websocket IPv4 %d IPv6 %d", websock4, websock6);
393 
394 	while (1) {
395 		amount = how_much_to_send(ipsum_len);
396 
397 		if (websock4 >= 0 &&
398 		    !send_and_wait_msg(websock4, amount, "IPv4",
399 				       recv_buf_ipv4, sizeof(recv_buf_ipv4))) {
400 			break;
401 		}
402 
403 		if (websock6 >= 0 &&
404 		    !send_and_wait_msg(websock6, amount, "IPv6",
405 				       recv_buf_ipv6, sizeof(recv_buf_ipv6))) {
406 			break;
407 		}
408 
409 		k_sleep(K_MSEC(250));
410 	}
411 
412 	if (websock4 >= 0) {
413 		close(websock4);
414 	}
415 
416 	if (sock4 >= 0) {
417 		close(sock4);
418 	}
419 
420 	if (websock6 >= 0) {
421 		close(websock6);
422 	}
423 
424 	if (sock6 >= 0) {
425 		close(sock6);
426 	}
427 
428 	k_sleep(K_FOREVER);
429 	return 0;
430 }
431