/** @file
 * @brief Websocket client API
 *
 * An API for applications to set up websocket connections.
 */
6 
7 /*
8  * Copyright (c) 2019 Intel Corporation
9  *
10  * SPDX-License-Identifier: Apache-2.0
11  */
12 
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15 
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21 
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #include <zephyr/net/socket.h>
26 #include <zephyr/net/http/client.h>
27 #include <zephyr/net/websocket.h>
28 
29 #include <zephyr/random/random.h>
30 #include <zephyr/sys/byteorder.h>
31 #include <zephyr/sys/base64.h>
32 
33 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
34 #include <psa/crypto.h>
35 #else
36 #include <mbedtls/sha1.h>
37 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
38 
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42 
/* To see the data that is being sent or received, enable debug logging
 * and set the following macros to 1.
 * This prints a lot of data, so it is disabled by default.
 */
/* Hex-dump every outgoing frame header/payload in websocket_prepare_and_send(). */
#define HEXDUMP_SENT_PACKETS 0
/* Receive-side counterpart.
 * NOTE(review): not referenced by the receive code visible in this file
 * section — confirm it is still used.
 */
#define HEXDUMP_RECV_PACKETS 0

/* Static pool of websocket contexts. A slot is "in use" while its refcount
 * is non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Serializes slot allocation and lookup in websocket_get() and
 * websocket_find().
 */
static struct k_sem contexts_lock;

/* Forward declaration: websocket_connect() tags new fds with this vtable and
 * websocket_poll_offload() uses it to recognize websocket fds.
 */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Test hook that replaces the real socket send in websocket_prepare_and_send(). */
int verify_sent_and_received_msg(struct net_msghdr *msg, bool split_msg);
#endif
59 
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 	switch (opcode) {
63 	case WEBSOCKET_OPCODE_DATA_TEXT:
64 		return "TEXT";
65 	case WEBSOCKET_OPCODE_DATA_BINARY:
66 		return "BIN";
67 	case WEBSOCKET_OPCODE_CONTINUE:
68 		return "CONT";
69 	case WEBSOCKET_OPCODE_CLOSE:
70 		return "CLOSE";
71 	case WEBSOCKET_OPCODE_PING:
72 		return "PING";
73 	case WEBSOCKET_OPCODE_PONG:
74 		return "PONG";
75 	default:
76 		break;
77 	}
78 
79 	return NULL;
80 }
81 
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 	int old_rc = atomic_inc(&ctx->refcount);
85 
86 	return old_rc + 1;
87 }
88 
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 	int old_rc = atomic_dec(&ctx->refcount);
92 
93 	if (old_rc != 1) {
94 		return old_rc - 1;
95 	}
96 
97 	return 0;
98 }
99 
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 	return !!atomic_get(&ctx->refcount);
103 }
104 
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 	struct websocket_context *ctx = NULL;
108 	int i;
109 
110 	k_sem_take(&contexts_lock, K_FOREVER);
111 
112 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 		if (websocket_context_is_used(&contexts[i])) {
114 			continue;
115 		}
116 
117 		websocket_context_ref(&contexts[i]);
118 		ctx = &contexts[i];
119 		break;
120 	}
121 
122 	k_sem_give(&contexts_lock);
123 
124 	return ctx;
125 }
126 
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 	struct websocket_context *ctx = NULL;
130 	int i;
131 
132 	k_sem_take(&contexts_lock, K_FOREVER);
133 
134 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 		if (!websocket_context_is_used(&contexts[i])) {
136 			continue;
137 		}
138 
139 		if (contexts[i].real_sock != real_sock) {
140 			continue;
141 		}
142 
143 		ctx = &contexts[i];
144 		break;
145 	}
146 
147 	k_sem_give(&contexts_lock);
148 
149 	return ctx;
150 }
151 
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static int response_cb(struct http_response *rsp,
153 		       enum http_final_call final_data,
154 		       void *user_data)
155 {
156 	struct websocket_context *ctx = user_data;
157 
158 	if (final_data == HTTP_DATA_MORE) {
159 		NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 			rsp->data_len);
161 		ctx->all_received = false;
162 	} else if (final_data == HTTP_DATA_FINAL) {
163 		NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 			rsp->data_len);
165 		ctx->all_received = true;
166 	}
167 
168 	return 0;
169 }
170 
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 			   size_t length)
173 {
174 	struct http_request *req = CONTAINER_OF(parser,
175 						struct http_request,
176 						internal.parser);
177 	struct websocket_context *ctx = req->internal.user_data;
178 	const char *ws_accept_str = "Sec-WebSocket-Accept";
179 	uint16_t len;
180 
181 	len = strlen(ws_accept_str);
182 	if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 		ctx->sec_accept_present = true;
184 	}
185 
186 	if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 		ctx->http_cb->on_header_field(parser, at, length);
188 	}
189 
190 	return 0;
191 }
192 
193 #define MAX_SEC_ACCEPT_LEN 32
194 
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 			   size_t length)
197 {
198 	struct http_request *req = CONTAINER_OF(parser,
199 						struct http_request,
200 						internal.parser);
201 	struct websocket_context *ctx = req->internal.user_data;
202 	char str[MAX_SEC_ACCEPT_LEN];
203 
204 	if (ctx->sec_accept_present) {
205 		int ret;
206 		size_t olen;
207 
208 		ctx->sec_accept_ok = false;
209 		ctx->sec_accept_present = false;
210 
211 		ret = base64_encode(str, sizeof(str) - 1, &olen,
212 				    ctx->sec_accept_key,
213 				    WS_SHA1_OUTPUT_LEN);
214 		if (ret == 0) {
215 			if (strncmp(at, str, length)) {
216 				NET_DBG("[%p] Security keys do not match "
217 					"%s vs %s", ctx, str, at);
218 			} else {
219 				ctx->sec_accept_ok = true;
220 			}
221 		}
222 	}
223 
224 	if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 		ctx->http_cb->on_header_value(parser, at, length);
226 	}
227 
228 	return 0;
229 }
230 
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 		      int32_t timeout, void *user_data)
233 {
234 	/* This is the expected Sec-WebSocket-Accept key. We are storing a
235 	 * pointer to this in ctx but the value is only used for the duration
236 	 * of this function call so there is no issue even if this variable
237 	 * is allocated from stack.
238 	 */
239 	uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 	struct http_parser_settings http_parser_settings;
241 	struct websocket_context *ctx;
242 	struct http_request req;
243 	int ret, fd, key_len;
244 	size_t olen;
245 	char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 	uint32_t rnd_value = sys_rand32_get();
247 	char sec_ws_key[] =
248 		"Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 	char *headers[] = {
250 		sec_ws_key,
251 		"Upgrade: websocket\r\n",
252 		"Connection: Upgrade\r\n",
253 		"Sec-WebSocket-Version: 13\r\n",
254 		NULL
255 	};
256 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
257 	psa_status_t psa_status;
258 	size_t hash_length;
259 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
260 
261 	fd = -1;
262 
263 	if (sock < 0 || wreq == NULL || wreq->host == NULL ||
264 	    wreq->url == NULL) {
265 		return -EINVAL;
266 	}
267 
268 	ctx = websocket_find(sock);
269 	if (ctx) {
270 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
271 			sock);
272 		return -EEXIST;
273 	}
274 
275 	ctx = websocket_get();
276 	if (!ctx) {
277 		return -ENOENT;
278 	}
279 
280 	ctx->real_sock = sock;
281 	ctx->recv_buf.buf = wreq->tmp_buf;
282 	ctx->recv_buf.size = wreq->tmp_buf_len;
283 	ctx->sec_accept_key = sec_accept_key;
284 	ctx->http_cb = wreq->http_cb;
285 	ctx->is_client = 1;
286 
287 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
288 	psa_status = psa_hash_compute(PSA_ALG_SHA_1, (const uint8_t *)&rnd_value, sizeof(rnd_value),
289 				      sec_accept_key, sizeof(sec_accept_key), &hash_length);
290 	if (psa_status != PSA_SUCCESS) {
291 		NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, psa_status);
292 		ret = -EPROTO;
293 		goto out;
294 	}
295 #else
296 	ret = mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value), sec_accept_key);
297 	if (ret != 0) {
298 		NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, ret);
299 		ret = -EPROTO;
300 		goto out;
301 	}
302 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
303 
304 
305 	ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
306 			    sizeof(sec_ws_key) -
307 					sizeof("Sec-Websocket-Key: "),
308 			    &olen, sec_accept_key,
309 			    /* We are only interested in 16 first bytes so
310 			     * subtract 4 from the SHA-1 length
311 			     */
312 			    sizeof(sec_accept_key) - 4);
313 	if (ret) {
314 		NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
315 		goto out;
316 	}
317 
318 	if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
319 		NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
320 			olen + sizeof("Sec-Websocket-Key: ") + 2,
321 			sizeof(sec_ws_key));
322 		ret = -EMSGSIZE;
323 		goto out;
324 	}
325 
326 	memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
327 	       HTTP_CRLF, sizeof(HTTP_CRLF));
328 
329 	memset(&req, 0, sizeof(req));
330 
331 	req.method = HTTP_GET;
332 	req.url = wreq->url;
333 	req.host = wreq->host;
334 	req.protocol = "HTTP/1.1";
335 	req.header_fields = (const char **)headers;
336 	req.optional_headers_cb = wreq->optional_headers_cb;
337 	req.optional_headers = wreq->optional_headers;
338 	req.response = response_cb;
339 	req.http_cb = &http_parser_settings;
340 	req.recv_buf = wreq->tmp_buf;
341 	req.recv_buf_len = wreq->tmp_buf_len;
342 
343 	/* We need to catch the Sec-WebSocket-Accept field in order to verify
344 	 * that it contains the stuff that we sent in Sec-WebSocket-Key field
345 	 * so setup HTTP callbacks so that we will get the needed fields.
346 	 */
347 	if (ctx->http_cb) {
348 		memcpy(&http_parser_settings, ctx->http_cb,
349 		       sizeof(http_parser_settings));
350 	} else {
351 		memset(&http_parser_settings, 0, sizeof(http_parser_settings));
352 	}
353 
354 	http_parser_settings.on_header_field = on_header_field;
355 	http_parser_settings.on_header_value = on_header_value;
356 
357 	/* Pre-calculate the expected Sec-Websocket-Accept field */
358 	key_len = MIN(sizeof(key_accept) - 1, olen);
359 	strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
360 		key_len);
361 
362 	olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
363 	memcpy(key_accept + key_len, WS_MAGIC, olen);
364 
365 	/* This SHA-1 value is then checked when we receive the response */
366 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
367 	psa_status = psa_hash_compute(PSA_ALG_SHA_1, (const uint8_t *)key_accept, olen + key_len,
368 				      sec_accept_key, sizeof(sec_accept_key), &hash_length);
369 	if (psa_status != PSA_SUCCESS) {
370 		NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, psa_status);
371 		ret = -EPROTO;
372 		goto out;
373 	}
374 #else
375 	ret = mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
376 	if (ret != 0) {
377 		NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, ret);
378 		ret = -EPROTO;
379 		goto out;
380 	}
381 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
382 
383 	ret = http_client_req(sock, &req, timeout, ctx);
384 	if (ret < 0) {
385 		NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
386 			wreq->host);
387 		ret = -ECONNABORTED;
388 		goto out;
389 	}
390 
391 	if (!(ctx->all_received && ctx->sec_accept_ok)) {
392 		NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
393 			ctx->all_received, ctx->sec_accept_ok);
394 		ret = -ECONNABORTED;
395 		goto out;
396 	}
397 
398 	ctx->user_data = user_data;
399 
400 	fd = zvfs_reserve_fd();
401 	if (fd < 0) {
402 		ret = -ENOSPC;
403 		goto out;
404 	}
405 
406 	ctx->sock = fd;
407 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
408 			    ZVFS_MODE_IFSOCK);
409 
410 	/* Call the user specified callback and if it accepts the connection
411 	 * then continue.
412 	 */
413 	if (wreq->cb) {
414 		ret = wreq->cb(fd, &req, user_data);
415 		if (ret < 0) {
416 			NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
417 			goto out;
418 		}
419 	}
420 
421 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
422 
423 	/* We will re-use the temp buffer in receive function. If there were
424 	 * any leftover data from HTTP headers processing, we need to reflect
425 	 * this in the count variable.
426 	 */
427 	ctx->recv_buf.count = req.data_len;
428 
429 	/* Init parser FSM */
430 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
431 
432 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, NET_SOCK_STREAM);
433 
434 	return fd;
435 
436 out:
437 	if (fd >= 0) {
438 		(void)zsock_close(fd);
439 	}
440 
441 	websocket_context_unref(ctx);
442 	return ret;
443 }
444 
/* Close a websocket created by websocket_connect().
 *
 * This only closes the websocket fd; the fd layer takes care of the
 * remaining teardown. Returns whatever zsock_close() returns.
 */
int websocket_disconnect(int ws_sock)
{
	int rc = zsock_close(ws_sock);

	return rc;
}
449 
websocket_internal_disconnect(struct websocket_context * ctx)450 static int websocket_internal_disconnect(struct websocket_context *ctx)
451 {
452 	int ret;
453 
454 	if (ctx == NULL) {
455 		return -ENOENT;
456 	}
457 
458 	NET_DBG("[%p] Disconnecting", ctx);
459 
460 	ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
461 				 ctx->is_client, true, SYS_FOREVER_MS);
462 	if (ret < 0) {
463 		NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
464 	}
465 
466 	(void)sock_obj_core_dealloc(ctx->sock);
467 
468 	websocket_context_unref(ctx);
469 
470 	return ret;
471 }
472 
/* fd-vtable close handler.
 *
 * -ENOTCONN from the disconnect is treated as success (returns 0). Any
 * other negative result sets errno and returns -1. Otherwise the
 * disconnect result is passed through.
 */
static int websocket_close_vmeth(void *obj)
{
	int ret = websocket_internal_disconnect((struct websocket_context *)obj);

	if (ret == -ENOTCONN) {
		/* Ignore error if we are not connected */
		return 0;
	}

	if (ret < 0) {
		NET_DBG("[%p] Cannot close (%d)", obj, ret);
		errno = -ret;
		return -1;
	}

	return ret;
}
493 
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)494 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
495 					 int timeout)
496 {
497 	int fd_backup[CONFIG_ZVFS_POLL_MAX];
498 	const struct fd_op_vtable *vtable;
499 	void *ctx;
500 	int ret = 0;
501 	int i;
502 
503 	/* Overwrite websocket file descriptors with underlying ones. */
504 	for (i = 0; i < nfds; i++) {
505 		fd_backup[i] = fds[i].fd;
506 
507 		ctx = zvfs_get_fd_obj(fds[i].fd,
508 				   (const struct fd_op_vtable *)
509 						     &websocket_fd_op_vtable,
510 				   0);
511 		if (ctx == NULL) {
512 			continue;
513 		}
514 
515 		fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
516 	}
517 
518 	/* Get offloaded sockets vtable. */
519 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
520 				      (const struct fd_op_vtable **)&vtable,
521 				      NULL);
522 	if (ctx == NULL) {
523 		errno = EINVAL;
524 		ret = -1;
525 		goto exit;
526 	}
527 
528 	ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
529 				   fds, nfds, timeout);
530 
531 exit:
532 	/* Restore original fds. */
533 	for (i = 0; i < nfds; i++) {
534 		fds[i].fd = fd_backup[i];
535 	}
536 
537 	return ret;
538 }
539 
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)540 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
541 {
542 	struct websocket_context *ctx = obj;
543 
544 	switch (request) {
545 	case ZFD_IOCTL_POLL_OFFLOAD: {
546 		struct zsock_pollfd *fds;
547 		int nfds;
548 		int timeout;
549 
550 		fds = va_arg(args, struct zsock_pollfd *);
551 		nfds = va_arg(args, int);
552 		timeout = va_arg(args, int);
553 
554 		return websocket_poll_offload(fds, nfds, timeout);
555 	}
556 
557 	case ZFD_IOCTL_SET_LOCK:
558 		/* Ignore, don't want to overwrite underlying socket lock. */
559 		return 0;
560 
561 	default: {
562 		const struct fd_op_vtable *vtable;
563 		void *core_obj;
564 
565 		core_obj = zvfs_get_fd_obj_and_vtable(
566 				ctx->real_sock,
567 				(const struct fd_op_vtable **)&vtable,
568 				NULL);
569 		if (core_obj == NULL) {
570 			errno = EBADF;
571 			return -1;
572 		}
573 
574 		/* Pass the call to the core socket implementation. */
575 		return vtable->ioctl(core_obj, request, args);
576 	}
577 	}
578 
579 	return 0;
580 }
581 
582 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct net_msghdr * message,int flags,const k_timepoint_t req_end_timepoint)583 static int sendmsg_all(int sock, const struct net_msghdr *message, int flags,
584 			const k_timepoint_t req_end_timepoint)
585 {
586 	int ret, i;
587 	size_t offset = 0;
588 	size_t total_len = 0;
589 
590 	for (i = 0; i < message->msg_iovlen; i++) {
591 		total_len += message->msg_iov[i].iov_len;
592 	}
593 
594 	while (offset < total_len) {
595 		ret = zsock_sendmsg(sock, message, flags);
596 
597 		if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
598 			struct zsock_pollfd pfd;
599 			int pollres;
600 			k_ticks_t req_timeout_ticks =
601 				sys_timepoint_timeout(req_end_timepoint).ticks;
602 			int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
603 
604 			pfd.fd = sock;
605 			pfd.events = ZSOCK_POLLOUT;
606 			pollres = zsock_poll(&pfd, 1, req_timeout_ms);
607 			if (pollres == 0) {
608 				return -ETIMEDOUT;
609 			} else if (pollres > 0) {
610 				continue;
611 			} else {
612 				return -errno;
613 			}
614 		} else if (ret < 0) {
615 			return -errno;
616 		}
617 
618 		offset += ret;
619 		if (offset >= total_len) {
620 			break;
621 		}
622 
623 		/* Update msghdr for the next iteration. */
624 		for (i = 0; i < message->msg_iovlen; i++) {
625 			if (ret < message->msg_iov[i].iov_len) {
626 				message->msg_iov[i].iov_len -= ret;
627 				message->msg_iov[i].iov_base =
628 					(uint8_t *)message->msg_iov[i].iov_base + ret;
629 				break;
630 			}
631 
632 			ret -= message->msg_iov[i].iov_len;
633 			message->msg_iov[i].iov_len = 0;
634 		}
635 	}
636 
637 	return total_len;
638 }
639 #endif /* !defined(CONFIG_NET_TEST) */
640 
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)641 static int websocket_prepare_and_send(struct websocket_context *ctx,
642 				      uint8_t *header, size_t header_len,
643 				      uint8_t *payload, size_t payload_len,
644 				      int32_t timeout)
645 {
646 	struct net_iovec io_vector[2];
647 	struct net_msghdr msg;
648 
649 	io_vector[0].iov_base = header;
650 	io_vector[0].iov_len = header_len;
651 	io_vector[1].iov_base = payload;
652 	io_vector[1].iov_len = payload_len;
653 
654 	memset(&msg, 0, sizeof(msg));
655 
656 	msg.msg_iov = io_vector;
657 	msg.msg_iovlen = ARRAY_SIZE(io_vector);
658 
659 	if (HEXDUMP_SENT_PACKETS) {
660 		LOG_HEXDUMP_DBG(header, header_len, "Header");
661 		if ((payload != NULL) && (payload_len > 0)) {
662 			LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
663 		} else {
664 			LOG_DBG("No payload");
665 		}
666 	}
667 
668 #if defined(CONFIG_NET_TEST)
669 	/* Simulate a case where the payload is split to two. The unit test
670 	 * does not set mask bit in this case.
671 	 */
672 	return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
673 #else
674 	k_timeout_t tout = K_FOREVER;
675 
676 	if (timeout != SYS_FOREVER_MS) {
677 		tout = K_MSEC(timeout);
678 	}
679 
680 	k_timeout_t req_timeout = K_MSEC(timeout);
681 	k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
682 
683 	return sendmsg_all(ctx->real_sock, &msg,
684 			   K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
685 			   req_end_timepoint);
686 #endif /* CONFIG_NET_TEST */
687 }
688 
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)689 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
690 		       enum websocket_opcode opcode, bool mask, bool final,
691 		       int32_t timeout)
692 {
693 	struct websocket_context *ctx;
694 	uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
695 	uint8_t *data_to_send = (uint8_t *)payload;
696 	int ret;
697 
698 	if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
699 	    opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
700 	    opcode != WEBSOCKET_OPCODE_CONTINUE &&
701 	    opcode != WEBSOCKET_OPCODE_CLOSE &&
702 	    opcode != WEBSOCKET_OPCODE_PING &&
703 	    opcode != WEBSOCKET_OPCODE_PONG) {
704 		return -EINVAL;
705 	}
706 
707 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
708 	if (ctx == NULL) {
709 		return -EBADF;
710 	}
711 
712 #if !defined(CONFIG_NET_TEST)
713 	/* Websocket unit test does not use context from pool but allocates
714 	 * its own, hence skip the check.
715 	 */
716 
717 	if (!PART_OF_ARRAY(contexts, ctx)) {
718 		return -ENOENT;
719 	}
720 #endif /* !defined(CONFIG_NET_TEST) */
721 
722 	NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
723 		mask, final ? "final" : "more");
724 
725 	memset(header, 0, sizeof(header));
726 
727 	/* Is this the last packet? */
728 	header[0] = final ? BIT(7) : 0;
729 
730 	/* Text, binary, ping, pong or close ? */
731 	header[0] |= opcode;
732 
733 	/* Masking */
734 	header[1] = mask ? BIT(7) : 0;
735 
736 	if (payload_len < 126) {
737 		header[1] |= payload_len;
738 	} else if (payload_len < 65536) {
739 		header[1] |= 126;
740 		header[2] = payload_len >> 8;
741 		header[3] = payload_len;
742 		hdr_len += 2;
743 	} else {
744 		header[1] |= 127;
745 		header[2] = 0;
746 		header[3] = 0;
747 		header[4] = 0;
748 		header[5] = 0;
749 		header[6] = payload_len >> 24;
750 		header[7] = payload_len >> 16;
751 		header[8] = payload_len >> 8;
752 		header[9] = payload_len;
753 		hdr_len += 8;
754 	}
755 
756 	/* Add masking value if needed */
757 	if (mask) {
758 		int i;
759 
760 		ctx->masking_value = sys_rand32_get();
761 
762 		header[hdr_len++] |= ctx->masking_value >> 24;
763 		header[hdr_len++] |= ctx->masking_value >> 16;
764 		header[hdr_len++] |= ctx->masking_value >> 8;
765 		header[hdr_len++] |= ctx->masking_value;
766 
767 		if ((payload != NULL) && (payload_len > 0)) {
768 			data_to_send = k_malloc(payload_len);
769 			if (!data_to_send) {
770 				return -ENOMEM;
771 			}
772 
773 			memcpy(data_to_send, payload, payload_len);
774 
775 			for (i = 0; i < payload_len; i++) {
776 				data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
777 			}
778 		}
779 	}
780 
781 	ret = websocket_prepare_and_send(ctx, header, hdr_len,
782 					 data_to_send, payload_len, timeout);
783 	if (ret < 0) {
784 		NET_DBG("Cannot send ws msg (%d)", -errno);
785 		goto quit;
786 	}
787 
788 quit:
789 	if (data_to_send != payload) {
790 		k_free(data_to_send);
791 	}
792 
793 	/* Do no math with 0 and error codes */
794 	if (ret <= 0) {
795 		return ret;
796 	}
797 
798 	return ret - hdr_len;
799 }
800 
websocket_opcode2flag(uint8_t data)801 static uint32_t websocket_opcode2flag(uint8_t data)
802 {
803 	switch (data & 0x0f) {
804 	case WEBSOCKET_OPCODE_DATA_TEXT:
805 		return WEBSOCKET_FLAG_TEXT;
806 	case WEBSOCKET_OPCODE_DATA_BINARY:
807 		return WEBSOCKET_FLAG_BINARY;
808 	case WEBSOCKET_OPCODE_CLOSE:
809 		return WEBSOCKET_FLAG_CLOSE;
810 	case WEBSOCKET_OPCODE_PING:
811 		return WEBSOCKET_FLAG_PING;
812 	case WEBSOCKET_OPCODE_PONG:
813 		return WEBSOCKET_FLAG_PONG;
814 	default:
815 		break;
816 	}
817 	return 0;
818 }
819 
/* Incremental websocket frame parser.
 *
 * Consumes bytes from ctx->recv_buf.buf[0 .. ctx->recv_buf.count) and
 * advances ctx->parser_state:
 *   OPCODE -> LENGTH -> [EXT_LEN] -> [MASK] -> [PAYLOAD] -> OPCODE
 * Payload bytes are appended to @p payload exactly as received. They are
 * NOT unmasked here; websocket_recv_msg() does that afterwards.
 *
 * It stops when one of these happens first:
 * - the frame has been fully parsed (state is back to OPCODE);
 * - the receive buffer is exhausted;
 * - @p payload is full.
 *
 * Returns the number of bytes consumed from ctx->recv_buf, or -EFAULT if
 * ctx->parser_state holds an unknown value.
 */
static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
{
	int len;
	uint8_t data;
	size_t parsed_count = 0;

	do {
		/* Out of input: report what was consumed; parser state is kept
		 * in ctx so the next call resumes where this one stopped.
		 */
		if (parsed_count >= ctx->recv_buf.count) {
			return parsed_count;
		}
		if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
			/* Header states consume one byte at a time. */
			data = ctx->recv_buf.buf[parsed_count++];

			switch (ctx->parser_state) {
			case WEBSOCKET_PARSER_STATE_OPCODE:
				/* Low nibble: opcode; bit 7: FIN. */
				ctx->message_type = websocket_opcode2flag(data);
				if ((data & 0x80) != 0) {
					ctx->message_type |= WEBSOCKET_FLAG_FINAL;
				}
				ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
				break;
			case WEBSOCKET_PARSER_STATE_LENGTH:
				/* Bit 7: mask present; low 7 bits: length, or
				 * 126/127 to select a 2/8-byte extended length.
				 */
				ctx->masked = (data & 0x80) != 0;
				len = data & 0x7f;
				if (len < 126) {
					ctx->message_len = len;
					if (ctx->masked) {
						ctx->masking_value = 0;
						ctx->parser_remaining = 4;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
					} else {
						ctx->parser_remaining = ctx->message_len;
						ctx->parser_state =
							(ctx->parser_remaining == 0)
								? WEBSOCKET_PARSER_STATE_OPCODE
								: WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				} else {
					ctx->message_len = 0;
					ctx->parser_remaining = (len < 127) ? 2 : 8;
					ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
				}
				break;
			case WEBSOCKET_PARSER_STATE_EXT_LEN:
				/* Extended length arrives most-significant byte
				 * first; parser_remaining counts bytes left.
				 */
				ctx->parser_remaining--;
				ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
				if (ctx->parser_remaining == 0) {
					if (ctx->masked) {
						ctx->masking_value = 0;
						ctx->parser_remaining = 4;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
					} else {
						/* A zero extended length is handled by the
						 * PAYLOAD branch (copies 0 bytes, returns to
						 * OPCODE).
						 */
						ctx->parser_remaining = ctx->message_len;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				}
				break;
			case WEBSOCKET_PARSER_STATE_MASK:
				/* 4-byte masking key, also assembled MSB first
				 * into a 32-bit value.
				 */
				ctx->parser_remaining--;
				ctx->masking_value |=
					(uint32_t)((uint64_t)data << (ctx->parser_remaining * 8));
				if (ctx->parser_remaining == 0) {
					if (ctx->message_len == 0) {
						ctx->parser_remaining = 0;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
					} else {
						ctx->parser_remaining = ctx->message_len;
						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
					}
				}
				break;
			default:
				return -EFAULT;
			}

#if (LOG_LEVEL >= LOG_LEVEL_DBG)
			/* Log once per frame, when its header is complete. */
			if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
			    ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
			     (ctx->message_len == 0))) {
				NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
					ctx->masked ? "" : "un",
					ctx->masked ? ctx->masking_value : 0, ctx->message_type,
					(size_t)ctx->message_len);
			}
#endif
		} else {
			/* Copy as much payload as both the input and the
			 * caller's buffer allow.
			 */
			size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
			size_t payload_in_recv_buf =
				MIN(remaining_in_recv_buf, ctx->parser_remaining);
			size_t free_in_payload_buf = payload->size - payload->count;
			size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);

			/* Caller's buffer is full: stop mid-frame; the caller
			 * notices payload->count >= payload->size.
			 */
			if (free_in_payload_buf == 0) {
				break;
			}

			memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
			       ready_to_copy);
			parsed_count += ready_to_copy;
			payload->count += ready_to_copy;
			ctx->parser_remaining -= ready_to_copy;
			if (ctx->parser_remaining == 0) {
				ctx->parser_remaining = 0;
				ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
			}
		}

	/* A frame is complete once the state returns to OPCODE. */
	} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);

	return parsed_count;
}
931 
932 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)933 static int wait_rx(int sock, int timeout)
934 {
935 	struct zsock_pollfd fds = {
936 		.fd = sock,
937 		.events = ZSOCK_POLLIN,
938 	};
939 	int ret;
940 
941 	ret = zsock_poll(&fds, 1, timeout);
942 	if (ret < 0) {
943 		return ret;
944 	}
945 
946 	if (ret == 0) {
947 		/* Timeout */
948 		return -EAGAIN;
949 	}
950 
951 	if (fds.revents & ZSOCK_POLLNVAL) {
952 		return -EBADF;
953 	}
954 
955 	if (fds.revents & ZSOCK_POLLERR) {
956 		return -EIO;
957 	}
958 
959 	return 0;
960 }
961 
timeout_to_ms(k_timeout_t * timeout)962 static int timeout_to_ms(k_timeout_t *timeout)
963 {
964 	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
965 		return 0;
966 	} else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
967 		return SYS_FOREVER_MS;
968 	} else {
969 		return k_ticks_to_ms_floor32(timeout->ticks);
970 	}
971 }
972 
973 #endif /* !defined(CONFIG_NET_TEST) */
974 
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)975 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
976 		       uint32_t *message_type, uint64_t *remaining, int32_t timeout)
977 {
978 	struct websocket_context *ctx;
979 	int ret;
980 	k_timepoint_t end;
981 	k_timeout_t tout = K_FOREVER;
982 	struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
983 
984 	if (timeout != SYS_FOREVER_MS) {
985 		tout = K_MSEC(timeout);
986 	}
987 
988 	if ((buf == NULL) || (buf_len == 0)) {
989 		return -EINVAL;
990 	}
991 
992 	end = sys_timepoint_calc(tout);
993 
994 #if defined(CONFIG_NET_TEST)
995 	struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
996 
997 	if (test_data == NULL) {
998 		return -EBADF;
999 	}
1000 
1001 	ctx = test_data->ctx;
1002 #else
1003 	ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
1004 	if (ctx == NULL) {
1005 		return -EBADF;
1006 	}
1007 
1008 	if (!PART_OF_ARRAY(contexts, ctx)) {
1009 		return -ENOENT;
1010 	}
1011 #endif /* CONFIG_NET_TEST */
1012 
1013 	do {
1014 		size_t parsed_count;
1015 
1016 		if (ctx->recv_buf.count == 0) {
1017 #if defined(CONFIG_NET_TEST)
1018 			size_t input_len = MIN(ctx->recv_buf.size,
1019 					       test_data->input_len - test_data->input_pos);
1020 
1021 			if (input_len > 0) {
1022 				memcpy(ctx->recv_buf.buf,
1023 				       &test_data->input_buf[test_data->input_pos], input_len);
1024 				test_data->input_pos += input_len;
1025 				ret = input_len;
1026 			} else {
1027 				/* emulate timeout */
1028 				ret = -EAGAIN;
1029 			}
1030 #else
1031 			tout = sys_timepoint_timeout(end);
1032 
1033 			ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
1034 			if (ret == 0) {
1035 				ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
1036 						 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
1037 				if (ret < 0) {
1038 					ret = -errno;
1039 				}
1040 			}
1041 #endif /* CONFIG_NET_TEST */
1042 
1043 			if (ret < 0) {
1044 				if ((ret == -EAGAIN) && (payload.count > 0)) {
1045 					/* go to unmasking */
1046 					break;
1047 				}
1048 				return ret;
1049 			}
1050 
1051 			if (ret == 0) {
1052 				/* Socket closed */
1053 				return -ENOTCONN;
1054 			}
1055 
1056 			ctx->recv_buf.count = ret;
1057 
1058 			NET_DBG("[%p] Received %d bytes", ctx, ret);
1059 		}
1060 
1061 		ret = websocket_parse(ctx, &payload);
1062 		if (ret < 0) {
1063 			return ret;
1064 		}
1065 		parsed_count = ret;
1066 
1067 		if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1068 		    (payload.count >= payload.size)) {
1069 			if (remaining != NULL) {
1070 				*remaining = ctx->parser_remaining;
1071 			}
1072 			if (message_type != NULL) {
1073 				*message_type = ctx->message_type;
1074 			}
1075 
1076 			size_t left = ctx->recv_buf.count - parsed_count;
1077 
1078 			if (left > 0) {
1079 				memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1080 			}
1081 			ctx->recv_buf.count = left;
1082 			break;
1083 		}
1084 
1085 		ctx->recv_buf.count -= parsed_count;
1086 
1087 	} while (true);
1088 
1089 	/* Unmask the data */
1090 	if (ctx->masked) {
1091 		uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1092 		size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1093 
1094 		for (size_t i = 0; i < payload.count; i++) {
1095 			size_t m = data_buf_offset % 4;
1096 
1097 			payload.buf[i] ^= mask_as_bytes[3 - m];
1098 			data_buf_offset++;
1099 		}
1100 	}
1101 
1102 	if (ctx->message_type == WEBSOCKET_FLAG_CLOSE) {
1103 		return websocket_internal_disconnect(ctx);
1104 	}
1105 
1106 	return payload.count;
1107 }
1108 
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1109 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1110 			  size_t buf_len, int32_t timeout)
1111 {
1112 	int ret;
1113 
1114 	NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1115 
1116 	ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1117 				 ctx->is_client, true, timeout);
1118 	if (ret < 0) {
1119 		errno = -ret;
1120 		return -1;
1121 	}
1122 
1123 	NET_DBG("[%p] Sent %d bytes", ctx, ret);
1124 
1125 	sock_obj_core_update_send_stats(ctx->sock, ret);
1126 
1127 	return ret;
1128 }
1129 
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1130 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1131 			  size_t buf_len, int32_t timeout)
1132 {
1133 	uint32_t message_type;
1134 	uint64_t remaining;
1135 	int ret;
1136 
1137 	NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1138 
1139 	/* TODO: add support for recvmsg() so that we could return the
1140 	 *       websocket specific information in ancillary data.
1141 	 */
1142 	ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1143 				 &remaining, timeout);
1144 	if (ret < 0) {
1145 		if (ret == -ENOTCONN) {
1146 			ret = 0;
1147 		} else {
1148 			errno = -ret;
1149 			return -1;
1150 		}
1151 	}
1152 
1153 	NET_DBG("[%p] Received %d bytes", ctx, ret);
1154 
1155 	sock_obj_core_update_recv_stats(ctx->sock, ret);
1156 
1157 	return ret;
1158 }
1159 
websocket_read_vmeth(void * obj,void * buffer,size_t count)1160 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1161 {
1162 	return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1163 }
1164 
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1165 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1166 				     size_t count)
1167 {
1168 	return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1169 }
1170 
/*
 * socket_op_vtable sendto() handler. An established websocket has a fixed
 * peer, so the destination address is ignored. ZSOCK_MSG_DONTWAIT selects a
 * zero timeout; without it the send blocks (SYS_FOREVER_MS).
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct net_sockaddr *dest_addr,
				    net_socklen_t addrlen)
{
	struct websocket_context *ws = obj;
	const bool nonblocking = (flags & ZSOCK_MSG_DONTWAIT) != 0;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send(ws, buf, len,
				       nonblocking ? 0 : SYS_FOREVER_MS);
}
1188 
/*
 * socket_op_vtable recvfrom() handler. The source address is not reported
 * (src_addr/addrlen are left untouched). ZSOCK_MSG_DONTWAIT selects a zero
 * timeout; without it the receive blocks (SYS_FOREVER_MS).
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct net_sockaddr *src_addr,
				      net_socklen_t *addrlen)
{
	struct websocket_context *ws = obj;
	const bool nonblocking = (flags & ZSOCK_MSG_DONTWAIT) != 0;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv(ws, buf, max_len,
				       nonblocking ? 0 : SYS_FOREVER_MS);
}
1205 
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1206 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1207 {
1208 	struct websocket_context *ctx;
1209 	int ret, fd;
1210 
1211 	if (sock < 0) {
1212 		return -EINVAL;
1213 	}
1214 
1215 	ctx = websocket_find(sock);
1216 	if (ctx) {
1217 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1218 		return -EEXIST;
1219 	}
1220 
1221 	ctx = websocket_get();
1222 	if (!ctx) {
1223 		return -ENOENT;
1224 	}
1225 
1226 	ctx->real_sock = sock;
1227 	ctx->recv_buf.buf = recv_buf;
1228 	ctx->recv_buf.size = recv_buf_len;
1229 	ctx->is_client = 0;
1230 
1231 	fd = zvfs_reserve_fd();
1232 	if (fd < 0) {
1233 		ret = -ENOSPC;
1234 		goto out;
1235 	}
1236 
1237 	ctx->sock = fd;
1238 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1239 			    ZVFS_MODE_IFSOCK);
1240 
1241 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1242 
1243 	ctx->recv_buf.count = 0;
1244 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1245 
1246 	(void)sock_obj_core_alloc_find(ctx->real_sock, fd, NET_SOCK_STREAM);
1247 
1248 	return fd;
1249 
1250 out:
1251 	websocket_context_unref(ctx);
1252 
1253 	return ret;
1254 }
1255 
websocket_search(int sock)1256 static struct websocket_context *websocket_search(int sock)
1257 {
1258 	struct websocket_context *ctx = NULL;
1259 	int i;
1260 
1261 	k_sem_take(&contexts_lock, K_FOREVER);
1262 
1263 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1264 		if (!websocket_context_is_used(&contexts[i])) {
1265 			continue;
1266 		}
1267 
1268 		if (contexts[i].sock != sock) {
1269 			continue;
1270 		}
1271 
1272 		ctx = &contexts[i];
1273 		break;
1274 	}
1275 
1276 	k_sem_give(&contexts_lock);
1277 
1278 	return ctx;
1279 }
1280 
websocket_unregister(int sock)1281 int websocket_unregister(int sock)
1282 {
1283 	struct websocket_context *ctx;
1284 
1285 	if (sock < 0) {
1286 		return -EINVAL;
1287 	}
1288 
1289 	ctx = websocket_search(sock);
1290 	if (ctx == NULL) {
1291 		NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1292 		return -ENOENT;
1293 	}
1294 
1295 	if (ctx->real_sock < 0) {
1296 		return -EALREADY;
1297 	}
1298 
1299 	(void)zsock_close(sock);
1300 	(void)zsock_close(ctx->real_sock);
1301 
1302 	ctx->real_sock = -1;
1303 	ctx->sock = -1;
1304 
1305 	return 0;
1306 }
1307 
1308 static const struct socket_op_vtable websocket_fd_op_vtable = {
1309 	.fd_vtable = {
1310 		.read = websocket_read_vmeth,
1311 		.write = websocket_write_vmeth,
1312 		.close = websocket_close_vmeth,
1313 		.ioctl = websocket_ioctl_vmeth,
1314 	},
1315 	.sendto = websocket_sendto_ctx,
1316 	.recvfrom = websocket_recvfrom_ctx,
1317 };
1318 
/*
 * Invoke @p cb on every in-use websocket context. The whole walk happens
 * under contexts_lock, and each callback runs with that context's own
 * mutex held.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	struct websocket_context *const end = contexts + ARRAY_SIZE(contexts);

	k_sem_take(&contexts_lock, K_FOREVER);

	for (struct websocket_context *it = contexts; it < end; ++it) {
		if (websocket_context_is_used(it)) {
			k_mutex_lock(&it->lock, K_FOREVER);
			cb(it, user_data);
			k_mutex_unlock(&it->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1339 
websocket_init(void)1340 void websocket_init(void)
1341 {
1342 	k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1343 }
1344