1 /** @file
2  * @brief Websocket client API
3  *
 * An API for applications to set up websocket connections.
5  */
6 
7 /*
8  * Copyright (c) 2019 Intel Corporation
9  *
10  * SPDX-License-Identifier: Apache-2.0
11  */
12 
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15 
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21 
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33 
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38 
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42 
/* If you want to see the data that is being sent or received,
 * then you can enable debugging and set the following macros to 1.
 * This will print a lot of data so is not enabled by default.
 */
#define HEXDUMP_SENT_PACKETS 0
#define HEXDUMP_RECV_PACKETS 0

/* Fixed pool of websocket contexts. A slot counts as in use while its
 * refcount is non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Taken around the slot scans in websocket_get() and websocket_find(). */
static struct k_sem contexts_lock;

/* Forward declaration so that websocket_connect() and the poll offload code
 * can reference the vtable before its definition (presumably further down
 * in this file).
 */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Provided by the websocket unit test. With CONFIG_NET_TEST,
 * websocket_prepare_and_send() hands outgoing messages to this hook instead
 * of sending them on a real socket.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59 
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 	switch (opcode) {
63 	case WEBSOCKET_OPCODE_DATA_TEXT:
64 		return "TEXT";
65 	case WEBSOCKET_OPCODE_DATA_BINARY:
66 		return "BIN";
67 	case WEBSOCKET_OPCODE_CONTINUE:
68 		return "CONT";
69 	case WEBSOCKET_OPCODE_CLOSE:
70 		return "CLOSE";
71 	case WEBSOCKET_OPCODE_PING:
72 		return "PING";
73 	case WEBSOCKET_OPCODE_PONG:
74 		return "PONG";
75 	default:
76 		break;
77 	}
78 
79 	return NULL;
80 }
81 
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 	int old_rc = atomic_inc(&ctx->refcount);
85 
86 	return old_rc + 1;
87 }
88 
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 	int old_rc = atomic_dec(&ctx->refcount);
92 
93 	if (old_rc != 1) {
94 		return old_rc - 1;
95 	}
96 
97 	return 0;
98 }
99 
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 	return !!atomic_get(&ctx->refcount);
103 }
104 
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 	struct websocket_context *ctx = NULL;
108 	int i;
109 
110 	k_sem_take(&contexts_lock, K_FOREVER);
111 
112 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 		if (websocket_context_is_used(&contexts[i])) {
114 			continue;
115 		}
116 
117 		websocket_context_ref(&contexts[i]);
118 		ctx = &contexts[i];
119 		break;
120 	}
121 
122 	k_sem_give(&contexts_lock);
123 
124 	return ctx;
125 }
126 
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 	struct websocket_context *ctx = NULL;
130 	int i;
131 
132 	k_sem_take(&contexts_lock, K_FOREVER);
133 
134 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 		if (!websocket_context_is_used(&contexts[i])) {
136 			continue;
137 		}
138 
139 		if (contexts[i].real_sock != real_sock) {
140 			continue;
141 		}
142 
143 		ctx = &contexts[i];
144 		break;
145 	}
146 
147 	k_sem_give(&contexts_lock);
148 
149 	return ctx;
150 }
151 
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static void response_cb(struct http_response *rsp,
153 			enum http_final_call final_data,
154 			void *user_data)
155 {
156 	struct websocket_context *ctx = user_data;
157 
158 	if (final_data == HTTP_DATA_MORE) {
159 		NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 			rsp->data_len);
161 		ctx->all_received = false;
162 	} else if (final_data == HTTP_DATA_FINAL) {
163 		NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 			rsp->data_len);
165 		ctx->all_received = true;
166 	}
167 }
168 
on_header_field(struct http_parser * parser,const char * at,size_t length)169 static int on_header_field(struct http_parser *parser, const char *at,
170 			   size_t length)
171 {
172 	struct http_request *req = CONTAINER_OF(parser,
173 						struct http_request,
174 						internal.parser);
175 	struct websocket_context *ctx = req->internal.user_data;
176 	const char *ws_accept_str = "Sec-WebSocket-Accept";
177 	uint16_t len;
178 
179 	len = strlen(ws_accept_str);
180 	if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
181 		ctx->sec_accept_present = true;
182 	}
183 
184 	if (ctx->http_cb && ctx->http_cb->on_header_field) {
185 		ctx->http_cb->on_header_field(parser, at, length);
186 	}
187 
188 	return 0;
189 }
190 
191 #define MAX_SEC_ACCEPT_LEN 32
192 
on_header_value(struct http_parser * parser,const char * at,size_t length)193 static int on_header_value(struct http_parser *parser, const char *at,
194 			   size_t length)
195 {
196 	struct http_request *req = CONTAINER_OF(parser,
197 						struct http_request,
198 						internal.parser);
199 	struct websocket_context *ctx = req->internal.user_data;
200 	char str[MAX_SEC_ACCEPT_LEN];
201 
202 	if (ctx->sec_accept_present) {
203 		int ret;
204 		size_t olen;
205 
206 		ctx->sec_accept_ok = false;
207 		ctx->sec_accept_present = false;
208 
209 		ret = base64_encode(str, sizeof(str) - 1, &olen,
210 				    ctx->sec_accept_key,
211 				    WS_SHA1_OUTPUT_LEN);
212 		if (ret == 0) {
213 			if (strncmp(at, str, length)) {
214 				NET_DBG("[%p] Security keys do not match "
215 					"%s vs %s", ctx, str, at);
216 			} else {
217 				ctx->sec_accept_ok = true;
218 			}
219 		}
220 	}
221 
222 	if (ctx->http_cb && ctx->http_cb->on_header_value) {
223 		ctx->http_cb->on_header_value(parser, at, length);
224 	}
225 
226 	return 0;
227 }
228 
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)229 int websocket_connect(int sock, struct websocket_request *wreq,
230 		      int32_t timeout, void *user_data)
231 {
232 	/* This is the expected Sec-WebSocket-Accept key. We are storing a
233 	 * pointer to this in ctx but the value is only used for the duration
234 	 * of this function call so there is no issue even if this variable
235 	 * is allocated from stack.
236 	 */
237 	uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
238 	struct http_parser_settings http_parser_settings;
239 	struct websocket_context *ctx;
240 	struct http_request req;
241 	int ret, fd, key_len;
242 	size_t olen;
243 	char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
244 	uint32_t rnd_value = sys_rand32_get();
245 	char sec_ws_key[] =
246 		"Sec-WebSocket-Key: 0123456789012345678901==\r\n";
247 	char *headers[] = {
248 		sec_ws_key,
249 		"Upgrade: websocket\r\n",
250 		"Connection: Upgrade\r\n",
251 		"Sec-WebSocket-Version: 13\r\n",
252 		NULL
253 	};
254 
255 	fd = -1;
256 
257 	if (sock < 0 || wreq == NULL || wreq->host == NULL ||
258 	    wreq->url == NULL) {
259 		return -EINVAL;
260 	}
261 
262 	ctx = websocket_find(sock);
263 	if (ctx) {
264 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
265 			sock);
266 		return -EEXIST;
267 	}
268 
269 	ctx = websocket_get();
270 	if (!ctx) {
271 		return -ENOENT;
272 	}
273 
274 	ctx->real_sock = sock;
275 	ctx->recv_buf.buf = wreq->tmp_buf;
276 	ctx->recv_buf.size = wreq->tmp_buf_len;
277 	ctx->sec_accept_key = sec_accept_key;
278 	ctx->http_cb = wreq->http_cb;
279 	ctx->is_client = 1;
280 
281 	mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
282 			 sec_accept_key);
283 
284 	ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
285 			    sizeof(sec_ws_key) -
286 					sizeof("Sec-Websocket-Key: "),
287 			    &olen, sec_accept_key,
288 			    /* We are only interested in 16 first bytes so
289 			     * subtract 4 from the SHA-1 length
290 			     */
291 			    sizeof(sec_accept_key) - 4);
292 	if (ret) {
293 		NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
294 		goto out;
295 	}
296 
297 	if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
298 		NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
299 			olen + sizeof("Sec-Websocket-Key: ") + 2,
300 			sizeof(sec_ws_key));
301 		ret = -EMSGSIZE;
302 		goto out;
303 	}
304 
305 	memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
306 	       HTTP_CRLF, sizeof(HTTP_CRLF));
307 
308 	memset(&req, 0, sizeof(req));
309 
310 	req.method = HTTP_GET;
311 	req.url = wreq->url;
312 	req.host = wreq->host;
313 	req.protocol = "HTTP/1.1";
314 	req.header_fields = (const char **)headers;
315 	req.optional_headers_cb = wreq->optional_headers_cb;
316 	req.optional_headers = wreq->optional_headers;
317 	req.response = response_cb;
318 	req.http_cb = &http_parser_settings;
319 	req.recv_buf = wreq->tmp_buf;
320 	req.recv_buf_len = wreq->tmp_buf_len;
321 
322 	/* We need to catch the Sec-WebSocket-Accept field in order to verify
323 	 * that it contains the stuff that we sent in Sec-WebSocket-Key field
324 	 * so setup HTTP callbacks so that we will get the needed fields.
325 	 */
326 	if (ctx->http_cb) {
327 		memcpy(&http_parser_settings, ctx->http_cb,
328 		       sizeof(http_parser_settings));
329 	} else {
330 		memset(&http_parser_settings, 0, sizeof(http_parser_settings));
331 	}
332 
333 	http_parser_settings.on_header_field = on_header_field;
334 	http_parser_settings.on_header_value = on_header_value;
335 
336 	/* Pre-calculate the expected Sec-Websocket-Accept field */
337 	key_len = MIN(sizeof(key_accept) - 1, olen);
338 	strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
339 		key_len);
340 
341 	olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
342 	strncpy(key_accept + key_len, WS_MAGIC, olen);
343 
344 	/* This SHA-1 value is then checked when we receive the response */
345 	mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
346 
347 	ret = http_client_req(sock, &req, timeout, ctx);
348 	if (ret < 0) {
349 		NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
350 			wreq->host);
351 		ret = -ECONNABORTED;
352 		goto out;
353 	}
354 
355 	if (!(ctx->all_received && ctx->sec_accept_ok)) {
356 		NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
357 			ctx->all_received, ctx->sec_accept_ok);
358 		ret = -ECONNABORTED;
359 		goto out;
360 	}
361 
362 	ctx->user_data = user_data;
363 
364 	fd = zvfs_reserve_fd();
365 	if (fd < 0) {
366 		ret = -ENOSPC;
367 		goto out;
368 	}
369 
370 	ctx->sock = fd;
371 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
372 			    ZVFS_MODE_IFSOCK);
373 
374 	/* Call the user specified callback and if it accepts the connection
375 	 * then continue.
376 	 */
377 	if (wreq->cb) {
378 		ret = wreq->cb(fd, &req, user_data);
379 		if (ret < 0) {
380 			NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
381 			goto out;
382 		}
383 	}
384 
385 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
386 
387 	/* We will re-use the temp buffer in receive function if needed but
388 	 * in order that to work the amount of data in buffer must be set to 0
389 	 */
390 	ctx->recv_buf.count = 0;
391 
392 	/* Init parser FSM */
393 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
394 
395 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
396 
397 	return fd;
398 
399 out:
400 	if (fd >= 0) {
401 		(void)zsock_close(fd);
402 	}
403 
404 	websocket_context_unref(ctx);
405 	return ret;
406 }
407 
/* Close a websocket descriptor. This simply forwards to zsock_close(),
 * which dispatches to the descriptor's close method.
 */
int websocket_disconnect(int ws_sock)
{
	int ret = zsock_close(ws_sock);

	return ret;
}
412 
websocket_interal_disconnect(struct websocket_context * ctx)413 static int websocket_interal_disconnect(struct websocket_context *ctx)
414 {
415 	int ret;
416 
417 	if (ctx == NULL) {
418 		return -ENOENT;
419 	}
420 
421 	NET_DBG("[%p] Disconnecting", ctx);
422 
423 	ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
424 				 true, true, SYS_FOREVER_MS);
425 	if (ret < 0) {
426 		NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
427 	}
428 
429 	(void)sock_obj_core_dealloc(ctx->sock);
430 
431 	websocket_context_unref(ctx);
432 
433 	return ret;
434 }
435 
/* fd vtable close method: disconnect the websocket context behind obj.
 * -ENOTCONN from the disconnect is treated as success. Any other failure is
 * reported POSIX-style, with errno set and -1 returned.
 */
static int websocket_close_vmeth(void *obj)
{
	int ret = websocket_interal_disconnect(obj);

	if (ret >= 0) {
		return ret;
	}

	if (ret == -ENOTCONN) {
		/* Ignore error if we are not connected */
		return 0;
	}

	NET_DBG("[%p] Cannot close (%d)", obj, ret);
	errno = -ret;

	return -1;
}
456 
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)457 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
458 					 int timeout)
459 {
460 	int fd_backup[CONFIG_ZVFS_POLL_MAX];
461 	const struct fd_op_vtable *vtable;
462 	void *ctx;
463 	int ret = 0;
464 	int i;
465 
466 	/* Overwrite websocket file descriptors with underlying ones. */
467 	for (i = 0; i < nfds; i++) {
468 		fd_backup[i] = fds[i].fd;
469 
470 		ctx = zvfs_get_fd_obj(fds[i].fd,
471 				   (const struct fd_op_vtable *)
472 						     &websocket_fd_op_vtable,
473 				   0);
474 		if (ctx == NULL) {
475 			continue;
476 		}
477 
478 		fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
479 	}
480 
481 	/* Get offloaded sockets vtable. */
482 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
483 				      (const struct fd_op_vtable **)&vtable,
484 				      NULL);
485 	if (ctx == NULL) {
486 		errno = EINVAL;
487 		ret = -1;
488 		goto exit;
489 	}
490 
491 	ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
492 				   fds, nfds, timeout);
493 
494 exit:
495 	/* Restore original fds. */
496 	for (i = 0; i < nfds; i++) {
497 		fds[i].fd = fd_backup[i];
498 	}
499 
500 	return ret;
501 }
502 
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)503 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
504 {
505 	struct websocket_context *ctx = obj;
506 
507 	switch (request) {
508 	case ZFD_IOCTL_POLL_OFFLOAD: {
509 		struct zsock_pollfd *fds;
510 		int nfds;
511 		int timeout;
512 
513 		fds = va_arg(args, struct zsock_pollfd *);
514 		nfds = va_arg(args, int);
515 		timeout = va_arg(args, int);
516 
517 		return websocket_poll_offload(fds, nfds, timeout);
518 	}
519 
520 	case ZFD_IOCTL_SET_LOCK:
521 		/* Ignore, don't want to overwrite underlying socket lock. */
522 		return 0;
523 
524 	default: {
525 		const struct fd_op_vtable *vtable;
526 		void *core_obj;
527 
528 		core_obj = zvfs_get_fd_obj_and_vtable(
529 				ctx->real_sock,
530 				(const struct fd_op_vtable **)&vtable,
531 				NULL);
532 		if (core_obj == NULL) {
533 			errno = EBADF;
534 			return -1;
535 		}
536 
537 		/* Pass the call to the core socket implementation. */
538 		return vtable->ioctl(core_obj, request, args);
539 	}
540 	}
541 
542 	return 0;
543 }
544 
545 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags,const k_timepoint_t req_end_timepoint)546 static int sendmsg_all(int sock, const struct msghdr *message, int flags,
547 			const k_timepoint_t req_end_timepoint)
548 {
549 	int ret, i;
550 	size_t offset = 0;
551 	size_t total_len = 0;
552 
553 	for (i = 0; i < message->msg_iovlen; i++) {
554 		total_len += message->msg_iov[i].iov_len;
555 	}
556 
557 	while (offset < total_len) {
558 		ret = zsock_sendmsg(sock, message, flags);
559 
560 		if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
561 			struct zsock_pollfd pfd;
562 			int pollres;
563 			k_ticks_t req_timeout_ticks =
564 				sys_timepoint_timeout(req_end_timepoint).ticks;
565 			int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
566 
567 			pfd.fd = sock;
568 			pfd.events = ZSOCK_POLLOUT;
569 			pollres = zsock_poll(&pfd, 1, req_timeout_ms);
570 			if (pollres == 0) {
571 				return -ETIMEDOUT;
572 			} else if (pollres > 0) {
573 				continue;
574 			} else {
575 				return -errno;
576 			}
577 		} else if (ret < 0) {
578 			return -errno;
579 		}
580 
581 		offset += ret;
582 		if (offset >= total_len) {
583 			break;
584 		}
585 
586 		/* Update msghdr for the next iteration. */
587 		for (i = 0; i < message->msg_iovlen; i++) {
588 			if (ret < message->msg_iov[i].iov_len) {
589 				message->msg_iov[i].iov_len -= ret;
590 				message->msg_iov[i].iov_base =
591 					(uint8_t *)message->msg_iov[i].iov_base + ret;
592 				break;
593 			}
594 
595 			ret -= message->msg_iov[i].iov_len;
596 			message->msg_iov[i].iov_len = 0;
597 		}
598 	}
599 
600 	return total_len;
601 }
602 #endif /* !defined(CONFIG_NET_TEST) */
603 
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)604 static int websocket_prepare_and_send(struct websocket_context *ctx,
605 				      uint8_t *header, size_t header_len,
606 				      uint8_t *payload, size_t payload_len,
607 				      int32_t timeout)
608 {
609 	struct iovec io_vector[2];
610 	struct msghdr msg;
611 
612 	io_vector[0].iov_base = header;
613 	io_vector[0].iov_len = header_len;
614 	io_vector[1].iov_base = payload;
615 	io_vector[1].iov_len = payload_len;
616 
617 	memset(&msg, 0, sizeof(msg));
618 
619 	msg.msg_iov = io_vector;
620 	msg.msg_iovlen = ARRAY_SIZE(io_vector);
621 
622 	if (HEXDUMP_SENT_PACKETS) {
623 		LOG_HEXDUMP_DBG(header, header_len, "Header");
624 		if ((payload != NULL) && (payload_len > 0)) {
625 			LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
626 		} else {
627 			LOG_DBG("No payload");
628 		}
629 	}
630 
631 #if defined(CONFIG_NET_TEST)
632 	/* Simulate a case where the payload is split to two. The unit test
633 	 * does not set mask bit in this case.
634 	 */
635 	return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
636 #else
637 	k_timeout_t tout = K_FOREVER;
638 
639 	if (timeout != SYS_FOREVER_MS) {
640 		tout = K_MSEC(timeout);
641 	}
642 
643 	k_timeout_t req_timeout = K_MSEC(timeout);
644 	k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
645 
646 	return sendmsg_all(ctx->real_sock, &msg,
647 			   K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
648 			   req_end_timepoint);
649 #endif /* CONFIG_NET_TEST */
650 }
651 
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)652 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
653 		       enum websocket_opcode opcode, bool mask, bool final,
654 		       int32_t timeout)
655 {
656 	struct websocket_context *ctx;
657 	uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
658 	uint8_t *data_to_send = (uint8_t *)payload;
659 	int ret;
660 
661 	if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
662 	    opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
663 	    opcode != WEBSOCKET_OPCODE_CONTINUE &&
664 	    opcode != WEBSOCKET_OPCODE_CLOSE &&
665 	    opcode != WEBSOCKET_OPCODE_PING &&
666 	    opcode != WEBSOCKET_OPCODE_PONG) {
667 		return -EINVAL;
668 	}
669 
670 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
671 	if (ctx == NULL) {
672 		return -EBADF;
673 	}
674 
675 #if !defined(CONFIG_NET_TEST)
676 	/* Websocket unit test does not use context from pool but allocates
677 	 * its own, hence skip the check.
678 	 */
679 
680 	if (!PART_OF_ARRAY(contexts, ctx)) {
681 		return -ENOENT;
682 	}
683 #endif /* !defined(CONFIG_NET_TEST) */
684 
685 	NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
686 		mask, final ? "final" : "more");
687 
688 	memset(header, 0, sizeof(header));
689 
690 	/* Is this the last packet? */
691 	header[0] = final ? BIT(7) : 0;
692 
693 	/* Text, binary, ping, pong or close ? */
694 	header[0] |= opcode;
695 
696 	/* Masking */
697 	header[1] = mask ? BIT(7) : 0;
698 
699 	if (payload_len < 126) {
700 		header[1] |= payload_len;
701 	} else if (payload_len < 65536) {
702 		header[1] |= 126;
703 		header[2] = payload_len >> 8;
704 		header[3] = payload_len;
705 		hdr_len += 2;
706 	} else {
707 		header[1] |= 127;
708 		header[2] = 0;
709 		header[3] = 0;
710 		header[4] = 0;
711 		header[5] = 0;
712 		header[6] = payload_len >> 24;
713 		header[7] = payload_len >> 16;
714 		header[8] = payload_len >> 8;
715 		header[9] = payload_len;
716 		hdr_len += 8;
717 	}
718 
719 	/* Add masking value if needed */
720 	if (mask) {
721 		int i;
722 
723 		ctx->masking_value = sys_rand32_get();
724 
725 		header[hdr_len++] |= ctx->masking_value >> 24;
726 		header[hdr_len++] |= ctx->masking_value >> 16;
727 		header[hdr_len++] |= ctx->masking_value >> 8;
728 		header[hdr_len++] |= ctx->masking_value;
729 
730 		if ((payload != NULL) && (payload_len > 0)) {
731 			data_to_send = k_malloc(payload_len);
732 			if (!data_to_send) {
733 				return -ENOMEM;
734 			}
735 
736 			memcpy(data_to_send, payload, payload_len);
737 
738 			for (i = 0; i < payload_len; i++) {
739 				data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
740 			}
741 		}
742 	}
743 
744 	ret = websocket_prepare_and_send(ctx, header, hdr_len,
745 					 data_to_send, payload_len, timeout);
746 	if (ret < 0) {
747 		NET_DBG("Cannot send ws msg (%d)", -errno);
748 		goto quit;
749 	}
750 
751 quit:
752 	if (data_to_send != payload) {
753 		k_free(data_to_send);
754 	}
755 
756 	/* Do no math with 0 and error codes */
757 	if (ret <= 0) {
758 		return ret;
759 	}
760 
761 	return ret - hdr_len;
762 }
763 
websocket_opcode2flag(uint8_t data)764 static uint32_t websocket_opcode2flag(uint8_t data)
765 {
766 	switch (data & 0x0f) {
767 	case WEBSOCKET_OPCODE_DATA_TEXT:
768 		return WEBSOCKET_FLAG_TEXT;
769 	case WEBSOCKET_OPCODE_DATA_BINARY:
770 		return WEBSOCKET_FLAG_BINARY;
771 	case WEBSOCKET_OPCODE_CLOSE:
772 		return WEBSOCKET_FLAG_CLOSE;
773 	case WEBSOCKET_OPCODE_PING:
774 		return WEBSOCKET_FLAG_PING;
775 	case WEBSOCKET_OPCODE_PONG:
776 		return WEBSOCKET_FLAG_PONG;
777 	default:
778 		break;
779 	}
780 	return 0;
781 }
782 
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)783 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
784 {
785 	int len;
786 	uint8_t data;
787 	size_t parsed_count = 0;
788 
789 	do {
790 		if (parsed_count >= ctx->recv_buf.count) {
791 			return parsed_count;
792 		}
793 		if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
794 			data = ctx->recv_buf.buf[parsed_count++];
795 
796 			switch (ctx->parser_state) {
797 			case WEBSOCKET_PARSER_STATE_OPCODE:
798 				ctx->message_type = websocket_opcode2flag(data);
799 				if ((data & 0x80) != 0) {
800 					ctx->message_type |= WEBSOCKET_FLAG_FINAL;
801 				}
802 				ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
803 				break;
804 			case WEBSOCKET_PARSER_STATE_LENGTH:
805 				ctx->masked = (data & 0x80) != 0;
806 				len = data & 0x7f;
807 				if (len < 126) {
808 					ctx->message_len = len;
809 					if (ctx->masked) {
810 						ctx->masking_value = 0;
811 						ctx->parser_remaining = 4;
812 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
813 					} else {
814 						ctx->parser_remaining = ctx->message_len;
815 						ctx->parser_state =
816 							(ctx->parser_remaining == 0)
817 								? WEBSOCKET_PARSER_STATE_OPCODE
818 								: WEBSOCKET_PARSER_STATE_PAYLOAD;
819 					}
820 				} else {
821 					ctx->message_len = 0;
822 					ctx->parser_remaining = (len < 127) ? 2 : 8;
823 					ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
824 				}
825 				break;
826 			case WEBSOCKET_PARSER_STATE_EXT_LEN:
827 				ctx->parser_remaining--;
828 				ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
829 				if (ctx->parser_remaining == 0) {
830 					if (ctx->masked) {
831 						ctx->masking_value = 0;
832 						ctx->parser_remaining = 4;
833 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
834 					} else {
835 						ctx->parser_remaining = ctx->message_len;
836 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
837 					}
838 				}
839 				break;
840 			case WEBSOCKET_PARSER_STATE_MASK:
841 				ctx->parser_remaining--;
842 				ctx->masking_value |= (data << (ctx->parser_remaining * 8));
843 				if (ctx->parser_remaining == 0) {
844 					if (ctx->message_len == 0) {
845 						ctx->parser_remaining = 0;
846 						ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
847 					} else {
848 						ctx->parser_remaining = ctx->message_len;
849 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
850 					}
851 				}
852 				break;
853 			default:
854 				return -EFAULT;
855 			}
856 
857 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
858 			if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
859 			    ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
860 			     (ctx->message_len == 0))) {
861 				NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
862 					ctx->masked ? "" : "un",
863 					ctx->masked ? ctx->masking_value : 0, ctx->message_type,
864 					(size_t)ctx->message_len);
865 			}
866 #endif
867 		} else {
868 			size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
869 			size_t payload_in_recv_buf =
870 				MIN(remaining_in_recv_buf, ctx->parser_remaining);
871 			size_t free_in_payload_buf = payload->size - payload->count;
872 			size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
873 
874 			if (free_in_payload_buf == 0) {
875 				break;
876 			}
877 
878 			memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
879 			       ready_to_copy);
880 			parsed_count += ready_to_copy;
881 			payload->count += ready_to_copy;
882 			ctx->parser_remaining -= ready_to_copy;
883 			if (ctx->parser_remaining == 0) {
884 				ctx->parser_remaining = 0;
885 				ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
886 			}
887 		}
888 
889 	} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
890 
891 	return parsed_count;
892 }
893 
894 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)895 static int wait_rx(int sock, int timeout)
896 {
897 	struct zsock_pollfd fds = {
898 		.fd = sock,
899 		.events = ZSOCK_POLLIN,
900 	};
901 	int ret;
902 
903 	ret = zsock_poll(&fds, 1, timeout);
904 	if (ret < 0) {
905 		return ret;
906 	}
907 
908 	if (ret == 0) {
909 		/* Timeout */
910 		return -EAGAIN;
911 	}
912 
913 	if (fds.revents & ZSOCK_POLLNVAL) {
914 		return -EBADF;
915 	}
916 
917 	if (fds.revents & ZSOCK_POLLERR) {
918 		return -EIO;
919 	}
920 
921 	return 0;
922 }
923 
timeout_to_ms(k_timeout_t * timeout)924 static int timeout_to_ms(k_timeout_t *timeout)
925 {
926 	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
927 		return 0;
928 	} else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
929 		return SYS_FOREVER_MS;
930 	} else {
931 		return k_ticks_to_ms_floor32(timeout->ticks);
932 	}
933 }
934 
935 #endif /* !defined(CONFIG_NET_TEST) */
936 
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)937 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
938 		       uint32_t *message_type, uint64_t *remaining, int32_t timeout)
939 {
940 	struct websocket_context *ctx;
941 	int ret;
942 	k_timepoint_t end;
943 	k_timeout_t tout = K_FOREVER;
944 	struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
945 
946 	if (timeout != SYS_FOREVER_MS) {
947 		tout = K_MSEC(timeout);
948 	}
949 
950 	if ((buf == NULL) || (buf_len == 0)) {
951 		return -EINVAL;
952 	}
953 
954 	end = sys_timepoint_calc(tout);
955 
956 #if defined(CONFIG_NET_TEST)
957 	struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
958 
959 	if (test_data == NULL) {
960 		return -EBADF;
961 	}
962 
963 	ctx = test_data->ctx;
964 #else
965 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
966 	if (ctx == NULL) {
967 		return -EBADF;
968 	}
969 
970 	if (!PART_OF_ARRAY(contexts, ctx)) {
971 		return -ENOENT;
972 	}
973 #endif /* CONFIG_NET_TEST */
974 
975 	do {
976 		size_t parsed_count;
977 
978 		if (ctx->recv_buf.count == 0) {
979 #if defined(CONFIG_NET_TEST)
980 			size_t input_len = MIN(ctx->recv_buf.size,
981 					       test_data->input_len - test_data->input_pos);
982 
983 			if (input_len > 0) {
984 				memcpy(ctx->recv_buf.buf,
985 				       &test_data->input_buf[test_data->input_pos], input_len);
986 				test_data->input_pos += input_len;
987 				ret = input_len;
988 			} else {
989 				/* emulate timeout */
990 				ret = -EAGAIN;
991 			}
992 #else
993 			tout = sys_timepoint_timeout(end);
994 
995 			ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
996 			if (ret == 0) {
997 				ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
998 						 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
999 				if (ret < 0) {
1000 					ret = -errno;
1001 				}
1002 			}
1003 #endif /* CONFIG_NET_TEST */
1004 
1005 			if (ret < 0) {
1006 				if ((ret == -EAGAIN) && (payload.count > 0)) {
1007 					/* go to unmasking */
1008 					break;
1009 				}
1010 				return ret;
1011 			}
1012 
1013 			if (ret == 0) {
1014 				/* Socket closed */
1015 				return -ENOTCONN;
1016 			}
1017 
1018 			ctx->recv_buf.count = ret;
1019 
1020 			NET_DBG("[%p] Received %d bytes", ctx, ret);
1021 		}
1022 
1023 		ret = websocket_parse(ctx, &payload);
1024 		if (ret < 0) {
1025 			return ret;
1026 		}
1027 		parsed_count = ret;
1028 
1029 		if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1030 		    (payload.count >= payload.size)) {
1031 			if (remaining != NULL) {
1032 				*remaining = ctx->parser_remaining;
1033 			}
1034 			if (message_type != NULL) {
1035 				*message_type = ctx->message_type;
1036 			}
1037 
1038 			size_t left = ctx->recv_buf.count - parsed_count;
1039 
1040 			if (left > 0) {
1041 				memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1042 			}
1043 			ctx->recv_buf.count = left;
1044 			break;
1045 		}
1046 
1047 		ctx->recv_buf.count -= parsed_count;
1048 
1049 	} while (true);
1050 
1051 	/* Unmask the data */
1052 	if (ctx->masked) {
1053 		uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1054 		size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1055 
1056 		for (size_t i = 0; i < payload.count; i++) {
1057 			size_t m = data_buf_offset % 4;
1058 
1059 			payload.buf[i] ^= mask_as_bytes[3 - m];
1060 			data_buf_offset++;
1061 		}
1062 	}
1063 
1064 	return payload.count;
1065 }
1066 
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1067 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1068 			  size_t buf_len, int32_t timeout)
1069 {
1070 	int ret;
1071 
1072 	NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1073 
1074 	ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1075 				 ctx->is_client, true, timeout);
1076 	if (ret < 0) {
1077 		errno = -ret;
1078 		return -1;
1079 	}
1080 
1081 	NET_DBG("[%p] Sent %d bytes", ctx, ret);
1082 
1083 	sock_obj_core_update_send_stats(ctx->sock, ret);
1084 
1085 	return ret;
1086 }
1087 
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1088 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1089 			  size_t buf_len, int32_t timeout)
1090 {
1091 	uint32_t message_type;
1092 	uint64_t remaining;
1093 	int ret;
1094 
1095 	NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1096 
1097 	/* TODO: add support for recvmsg() so that we could return the
1098 	 *       websocket specific information in ancillary data.
1099 	 */
1100 	ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1101 				 &remaining, timeout);
1102 	if (ret < 0) {
1103 		if (ret == -ENOTCONN) {
1104 			ret = 0;
1105 		} else {
1106 			errno = -ret;
1107 			return -1;
1108 		}
1109 	}
1110 
1111 	NET_DBG("[%p] Received %d bytes", ctx, ret);
1112 
1113 	sock_obj_core_update_recv_stats(ctx->sock, ret);
1114 
1115 	return ret;
1116 }
1117 
websocket_read_vmeth(void * obj,void * buffer,size_t count)1118 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1119 {
1120 	return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1121 }
1122 
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1123 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1124 				     size_t count)
1125 {
1126 	return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1127 }
1128 
/* socket_op_vtable sendto() hook. The websocket is already connected, so
 * the destination address is ignored. ZSOCK_MSG_DONTWAIT is the only flag
 * honoured; it turns the send into a zero-timeout attempt.
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct sockaddr *dest_addr,
				    socklen_t addrlen)
{
	const bool nonblocking = (flags & ZSOCK_MSG_DONTWAIT) != 0;
	const int32_t wait_ms = nonblocking ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send((struct websocket_context *)obj, buf, len,
				       wait_ms);
}
1146 
/* socket_op_vtable recvfrom() hook. The peer address is never filled in
 * (src_addr/addrlen are ignored). ZSOCK_MSG_DONTWAIT is the only flag
 * honoured; it turns the receive into a zero-timeout attempt.
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct sockaddr *src_addr,
				      socklen_t *addrlen)
{
	const bool nonblocking = (flags & ZSOCK_MSG_DONTWAIT) != 0;
	const int32_t wait_ms = nonblocking ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv((struct websocket_context *)obj, buf,
				       max_len, wait_ms);
}
1163 
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1164 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1165 {
1166 	struct websocket_context *ctx;
1167 	int ret, fd;
1168 
1169 	if (sock < 0) {
1170 		return -EINVAL;
1171 	}
1172 
1173 	ctx = websocket_find(sock);
1174 	if (ctx) {
1175 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1176 		return -EEXIST;
1177 	}
1178 
1179 	ctx = websocket_get();
1180 	if (!ctx) {
1181 		return -ENOENT;
1182 	}
1183 
1184 	ctx->real_sock = sock;
1185 	ctx->recv_buf.buf = recv_buf;
1186 	ctx->recv_buf.size = recv_buf_len;
1187 	ctx->is_client = 0;
1188 
1189 	fd = zvfs_reserve_fd();
1190 	if (fd < 0) {
1191 		ret = -ENOSPC;
1192 		goto out;
1193 	}
1194 
1195 	ctx->sock = fd;
1196 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1197 			    ZVFS_MODE_IFSOCK);
1198 
1199 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1200 
1201 	ctx->recv_buf.count = 0;
1202 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1203 
1204 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
1205 
1206 	return fd;
1207 
1208 out:
1209 	websocket_context_unref(ctx);
1210 
1211 	return ret;
1212 }
1213 
websocket_search(int sock)1214 static struct websocket_context *websocket_search(int sock)
1215 {
1216 	struct websocket_context *ctx = NULL;
1217 	int i;
1218 
1219 	k_sem_take(&contexts_lock, K_FOREVER);
1220 
1221 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1222 		if (!websocket_context_is_used(&contexts[i])) {
1223 			continue;
1224 		}
1225 
1226 		if (contexts[i].sock != sock) {
1227 			continue;
1228 		}
1229 
1230 		ctx = &contexts[i];
1231 		break;
1232 	}
1233 
1234 	k_sem_give(&contexts_lock);
1235 
1236 	return ctx;
1237 }
1238 
websocket_unregister(int sock)1239 int websocket_unregister(int sock)
1240 {
1241 	struct websocket_context *ctx;
1242 
1243 	if (sock < 0) {
1244 		return -EINVAL;
1245 	}
1246 
1247 	ctx = websocket_search(sock);
1248 	if (ctx == NULL) {
1249 		NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1250 		return -ENOENT;
1251 	}
1252 
1253 	if (ctx->real_sock < 0) {
1254 		return -EALREADY;
1255 	}
1256 
1257 	(void)zsock_close(sock);
1258 	(void)zsock_close(ctx->real_sock);
1259 
1260 	ctx->real_sock = -1;
1261 	ctx->sock = -1;
1262 
1263 	return 0;
1264 }
1265 
1266 static const struct socket_op_vtable websocket_fd_op_vtable = {
1267 	.fd_vtable = {
1268 		.read = websocket_read_vmeth,
1269 		.write = websocket_write_vmeth,
1270 		.close = websocket_close_vmeth,
1271 		.ioctl = websocket_ioctl_vmeth,
1272 	},
1273 	.sendto = websocket_sendto_ctx,
1274 	.recvfrom = websocket_recvfrom_ctx,
1275 };
1276 
/* Invoke @p cb on every in-use websocket context.
 *
 * The context table is held under contexts_lock for the whole walk, and
 * each context's own mutex is held while @p cb runs on it.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	k_sem_take(&contexts_lock, K_FOREVER);

	for (size_t idx = 0; idx < ARRAY_SIZE(contexts); idx++) {
		struct websocket_context *entry = &contexts[idx];

		if (websocket_context_is_used(entry)) {
			k_mutex_lock(&entry->lock, K_FOREVER);
			cb(entry, user_data);
			k_mutex_unlock(&entry->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1297 
websocket_init(void)1298 void websocket_init(void)
1299 {
1300 	k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1301 }
1302