1 /** @file
2  * @brief Websocket client API
3  *
 * An API for applications to set up websocket connections.
5  */
6 
7 /*
8  * Copyright (c) 2019 Intel Corporation
9  *
10  * SPDX-License-Identifier: Apache-2.0
11  */
12 
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15 
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21 
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33 
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38 
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42 
/* If you want to see the data that is being sent or received,
 * then you can enable debugging and set the following macros to 1.
 * This will print a lot of data so is not enabled by default.
 */
/* Hexdump outgoing header/payload in websocket_prepare_and_send(). */
#define HEXDUMP_SENT_PACKETS 0
/* NOTE(review): not referenced by the code visible in this file. */
#define HEXDUMP_RECV_PACKETS 0

/* Fixed pool of websocket contexts; a slot is in use while its refcount
 * is non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Serializes slot lookup/allocation over the contexts[] pool. */
static struct k_sem contexts_lock;

/* Forward declaration; the table is defined near the end of the file. */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Test hook that replaces the real socket send in unit-test builds. */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59 
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 	switch (opcode) {
63 	case WEBSOCKET_OPCODE_DATA_TEXT:
64 		return "TEXT";
65 	case WEBSOCKET_OPCODE_DATA_BINARY:
66 		return "BIN";
67 	case WEBSOCKET_OPCODE_CONTINUE:
68 		return "CONT";
69 	case WEBSOCKET_OPCODE_CLOSE:
70 		return "CLOSE";
71 	case WEBSOCKET_OPCODE_PING:
72 		return "PING";
73 	case WEBSOCKET_OPCODE_PONG:
74 		return "PONG";
75 	default:
76 		break;
77 	}
78 
79 	return NULL;
80 }
81 
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 	int old_rc = atomic_inc(&ctx->refcount);
85 
86 	return old_rc + 1;
87 }
88 
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 	int old_rc = atomic_dec(&ctx->refcount);
92 
93 	if (old_rc != 1) {
94 		return old_rc - 1;
95 	}
96 
97 	return 0;
98 }
99 
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 	NET_ASSERT(ctx);
103 
104 	return !!atomic_get(&ctx->refcount);
105 }
106 
websocket_get(void)107 static struct websocket_context *websocket_get(void)
108 {
109 	struct websocket_context *ctx = NULL;
110 	int i;
111 
112 	k_sem_take(&contexts_lock, K_FOREVER);
113 
114 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
115 		if (websocket_context_is_used(&contexts[i])) {
116 			continue;
117 		}
118 
119 		websocket_context_ref(&contexts[i]);
120 		ctx = &contexts[i];
121 		break;
122 	}
123 
124 	k_sem_give(&contexts_lock);
125 
126 	return ctx;
127 }
128 
websocket_find(int real_sock)129 static struct websocket_context *websocket_find(int real_sock)
130 {
131 	struct websocket_context *ctx = NULL;
132 	int i;
133 
134 	k_sem_take(&contexts_lock, K_FOREVER);
135 
136 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
137 		if (!websocket_context_is_used(&contexts[i])) {
138 			continue;
139 		}
140 
141 		if (contexts[i].real_sock != real_sock) {
142 			continue;
143 		}
144 
145 		ctx = &contexts[i];
146 		break;
147 	}
148 
149 	k_sem_give(&contexts_lock);
150 
151 	return ctx;
152 }
153 
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)154 static void response_cb(struct http_response *rsp,
155 			enum http_final_call final_data,
156 			void *user_data)
157 {
158 	struct websocket_context *ctx = user_data;
159 
160 	if (final_data == HTTP_DATA_MORE) {
161 		NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
162 			rsp->data_len);
163 		ctx->all_received = false;
164 	} else if (final_data == HTTP_DATA_FINAL) {
165 		NET_DBG("[%p] All the data received (%zd bytes)", ctx,
166 			rsp->data_len);
167 		ctx->all_received = true;
168 	}
169 }
170 
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 			   size_t length)
173 {
174 	struct http_request *req = CONTAINER_OF(parser,
175 						struct http_request,
176 						internal.parser);
177 	struct websocket_context *ctx = req->internal.user_data;
178 	const char *ws_accept_str = "Sec-WebSocket-Accept";
179 	uint16_t len;
180 
181 	len = strlen(ws_accept_str);
182 	if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 		ctx->sec_accept_present = true;
184 	}
185 
186 	if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 		ctx->http_cb->on_header_field(parser, at, length);
188 	}
189 
190 	return 0;
191 }
192 
/* Scratch size for the base64 form of a SHA-1 digest (a 20-byte digest
 * encodes to 28 characters plus NUL). Presumably WS_SHA1_OUTPUT_LEN is
 * 20 — confirm against websocket_internal.h.
 */
#define MAX_SEC_ACCEPT_LEN 32
194 
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 			   size_t length)
197 {
198 	struct http_request *req = CONTAINER_OF(parser,
199 						struct http_request,
200 						internal.parser);
201 	struct websocket_context *ctx = req->internal.user_data;
202 	char str[MAX_SEC_ACCEPT_LEN];
203 
204 	if (ctx->sec_accept_present) {
205 		int ret;
206 		size_t olen;
207 
208 		ctx->sec_accept_ok = false;
209 		ctx->sec_accept_present = false;
210 
211 		ret = base64_encode(str, sizeof(str) - 1, &olen,
212 				    ctx->sec_accept_key,
213 				    WS_SHA1_OUTPUT_LEN);
214 		if (ret == 0) {
215 			if (strncmp(at, str, length)) {
216 				NET_DBG("[%p] Security keys do not match "
217 					"%s vs %s", ctx, str, at);
218 			} else {
219 				ctx->sec_accept_ok = true;
220 			}
221 		}
222 	}
223 
224 	if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 		ctx->http_cb->on_header_value(parser, at, length);
226 	}
227 
228 	return 0;
229 }
230 
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 		      int32_t timeout, void *user_data)
233 {
234 	/* This is the expected Sec-WebSocket-Accept key. We are storing a
235 	 * pointer to this in ctx but the value is only used for the duration
236 	 * of this function call so there is no issue even if this variable
237 	 * is allocated from stack.
238 	 */
239 	uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 	struct http_parser_settings http_parser_settings;
241 	struct websocket_context *ctx;
242 	struct http_request req;
243 	int ret, fd, key_len;
244 	size_t olen;
245 	char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 	uint32_t rnd_value = sys_rand32_get();
247 	char sec_ws_key[] =
248 		"Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 	char *headers[] = {
250 		sec_ws_key,
251 		"Upgrade: websocket\r\n",
252 		"Connection: Upgrade\r\n",
253 		"Sec-WebSocket-Version: 13\r\n",
254 		NULL
255 	};
256 
257 	fd = -1;
258 
259 	if (sock < 0 || wreq == NULL || wreq->host == NULL ||
260 	    wreq->url == NULL) {
261 		return -EINVAL;
262 	}
263 
264 	ctx = websocket_find(sock);
265 	if (ctx) {
266 		NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
267 			sock);
268 		return -EEXIST;
269 	}
270 
271 	ctx = websocket_get();
272 	if (!ctx) {
273 		return -ENOENT;
274 	}
275 
276 	ctx->real_sock = sock;
277 	ctx->recv_buf.buf = wreq->tmp_buf;
278 	ctx->recv_buf.size = wreq->tmp_buf_len;
279 	ctx->sec_accept_key = sec_accept_key;
280 	ctx->http_cb = wreq->http_cb;
281 
282 	mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
283 			 sec_accept_key);
284 
285 	ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
286 			    sizeof(sec_ws_key) -
287 					sizeof("Sec-Websocket-Key: "),
288 			    &olen, sec_accept_key,
289 			    /* We are only interested in 16 first bytes so
290 			     * subtract 4 from the SHA-1 length
291 			     */
292 			    sizeof(sec_accept_key) - 4);
293 	if (ret) {
294 		NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
295 		goto out;
296 	}
297 
298 	if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
299 		NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
300 			olen + sizeof("Sec-Websocket-Key: ") + 2,
301 			sizeof(sec_ws_key));
302 		ret = -EMSGSIZE;
303 		goto out;
304 	}
305 
306 	memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
307 	       HTTP_CRLF, sizeof(HTTP_CRLF));
308 
309 	memset(&req, 0, sizeof(req));
310 
311 	req.method = HTTP_GET;
312 	req.url = wreq->url;
313 	req.host = wreq->host;
314 	req.protocol = "HTTP/1.1";
315 	req.header_fields = (const char **)headers;
316 	req.optional_headers_cb = wreq->optional_headers_cb;
317 	req.optional_headers = wreq->optional_headers;
318 	req.response = response_cb;
319 	req.http_cb = &http_parser_settings;
320 	req.recv_buf = wreq->tmp_buf;
321 	req.recv_buf_len = wreq->tmp_buf_len;
322 
323 	/* We need to catch the Sec-WebSocket-Accept field in order to verify
324 	 * that it contains the stuff that we sent in Sec-WebSocket-Key field
325 	 * so setup HTTP callbacks so that we will get the needed fields.
326 	 */
327 	if (ctx->http_cb) {
328 		memcpy(&http_parser_settings, ctx->http_cb,
329 		       sizeof(http_parser_settings));
330 	} else {
331 		memset(&http_parser_settings, 0, sizeof(http_parser_settings));
332 	}
333 
334 	http_parser_settings.on_header_field = on_header_field;
335 	http_parser_settings.on_header_value = on_header_value;
336 
337 	/* Pre-calculate the expected Sec-Websocket-Accept field */
338 	key_len = MIN(sizeof(key_accept) - 1, olen);
339 	strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
340 		key_len);
341 
342 	olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
343 	strncpy(key_accept + key_len, WS_MAGIC, olen);
344 
345 	/* This SHA-1 value is then checked when we receive the response */
346 	mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
347 
348 	ret = http_client_req(sock, &req, timeout, ctx);
349 	if (ret < 0) {
350 		NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
351 			wreq->host);
352 		ret = -ECONNABORTED;
353 		goto out;
354 	}
355 
356 	if (!(ctx->all_received && ctx->sec_accept_ok)) {
357 		NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
358 			ctx->all_received, ctx->sec_accept_ok);
359 		ret = -ECONNABORTED;
360 		goto out;
361 	}
362 
363 	ctx->user_data = user_data;
364 
365 	fd = z_reserve_fd();
366 	if (fd < 0) {
367 		ret = -ENOSPC;
368 		goto out;
369 	}
370 
371 	ctx->sock = fd;
372 	z_finalize_fd(fd, ctx,
373 		      (const struct fd_op_vtable *)&websocket_fd_op_vtable);
374 
375 	/* Call the user specified callback and if it accepts the connection
376 	 * then continue.
377 	 */
378 	if (wreq->cb) {
379 		ret = wreq->cb(fd, &req, user_data);
380 		if (ret < 0) {
381 			NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
382 			goto out;
383 		}
384 	}
385 
386 	NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
387 
388 	/* We will re-use the temp buffer in receive function if needed but
389 	 * in order that to work the amount of data in buffer must be set to 0
390 	 */
391 	ctx->recv_buf.count = 0;
392 
393 	/* Init parser FSM */
394 	ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
395 
396 	return fd;
397 
398 out:
399 	if (fd >= 0) {
400 		(void)close(fd);
401 	}
402 
403 	websocket_context_unref(ctx);
404 	return ret;
405 }
406 
websocket_disconnect(int ws_sock)407 int websocket_disconnect(int ws_sock)
408 {
409 	return close(ws_sock);
410 }
411 
websocket_interal_disconnect(struct websocket_context * ctx)412 static int websocket_interal_disconnect(struct websocket_context *ctx)
413 {
414 	int ret;
415 
416 	if (ctx == NULL) {
417 		return -ENOENT;
418 	}
419 
420 	NET_DBG("[%p] Disconnecting", ctx);
421 
422 	ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
423 				 true, true, SYS_FOREVER_MS);
424 	if (ret < 0) {
425 		NET_ERR("[%p] Failed to send close message (err %d).", ctx, ret);
426 	}
427 
428 	websocket_context_unref(ctx);
429 
430 	return ret;
431 }
432 
websocket_close_vmeth(void * obj)433 static int websocket_close_vmeth(void *obj)
434 {
435 	struct websocket_context *ctx = obj;
436 	int ret;
437 
438 	ret = websocket_interal_disconnect(ctx);
439 	if (ret < 0) {
440 		NET_DBG("[%p] Cannot close (%d)", obj, ret);
441 
442 		errno = -ret;
443 		return -1;
444 	}
445 
446 	return ret;
447 }
448 
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)449 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
450 					 int timeout)
451 {
452 	int fd_backup[CONFIG_NET_SOCKETS_POLL_MAX];
453 	const struct fd_op_vtable *vtable;
454 	void *ctx;
455 	int ret = 0;
456 	int i;
457 
458 	/* Overwrite websocket file descriptors with underlying ones. */
459 	for (i = 0; i < nfds; i++) {
460 		fd_backup[i] = fds[i].fd;
461 
462 		ctx = z_get_fd_obj(fds[i].fd,
463 				   (const struct fd_op_vtable *)
464 						     &websocket_fd_op_vtable,
465 				   0);
466 		if (ctx == NULL) {
467 			continue;
468 		}
469 
470 		fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
471 	}
472 
473 	/* Get offloaded sockets vtable. */
474 	ctx = z_get_fd_obj_and_vtable(fds[0].fd,
475 				      (const struct fd_op_vtable **)&vtable,
476 				      NULL);
477 	if (ctx == NULL) {
478 		errno = EINVAL;
479 		ret = -1;
480 		goto exit;
481 	}
482 
483 	ret = z_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
484 				   fds, nfds, timeout);
485 
486 exit:
487 	/* Restore original fds. */
488 	for (i = 0; i < nfds; i++) {
489 		fds[i].fd = fd_backup[i];
490 	}
491 
492 	return ret;
493 }
494 
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)495 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
496 {
497 	struct websocket_context *ctx = obj;
498 
499 	switch (request) {
500 	case ZFD_IOCTL_POLL_OFFLOAD: {
501 		struct zsock_pollfd *fds;
502 		int nfds;
503 		int timeout;
504 
505 		fds = va_arg(args, struct zsock_pollfd *);
506 		nfds = va_arg(args, int);
507 		timeout = va_arg(args, int);
508 
509 		return websocket_poll_offload(fds, nfds, timeout);
510 	}
511 
512 	case ZFD_IOCTL_SET_LOCK:
513 		/* Ignore, don't want to overwrite underlying socket lock. */
514 		return 0;
515 
516 	default: {
517 		const struct fd_op_vtable *vtable;
518 		void *core_obj;
519 
520 		core_obj = z_get_fd_obj_and_vtable(
521 				ctx->real_sock,
522 				(const struct fd_op_vtable **)&vtable,
523 				NULL);
524 		if (core_obj == NULL) {
525 			errno = EBADF;
526 			return -1;
527 		}
528 
529 		/* Pass the call to the core socket implementation. */
530 		return vtable->ioctl(core_obj, request, args);
531 	}
532 	}
533 
534 	return 0;
535 }
536 
537 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags)538 static int sendmsg_all(int sock, const struct msghdr *message, int flags)
539 {
540 	int ret, i;
541 	size_t offset = 0;
542 	size_t total_len = 0;
543 
544 	for (i = 0; i < message->msg_iovlen; i++) {
545 		total_len += message->msg_iov[i].iov_len;
546 	}
547 
548 	while (offset < total_len) {
549 		ret = zsock_sendmsg(sock, message, flags);
550 		if (ret < 0) {
551 			return -errno;
552 		}
553 
554 		offset += ret;
555 		if (offset >= total_len) {
556 			break;
557 		}
558 
559 		/* Update msghdr for the next iteration. */
560 		for (i = 0; i < message->msg_iovlen; i++) {
561 			if (ret < message->msg_iov[i].iov_len) {
562 				message->msg_iov[i].iov_len -= ret;
563 				message->msg_iov[i].iov_base =
564 					(uint8_t *)message->msg_iov[i].iov_base + ret;
565 				break;
566 			}
567 
568 			ret -= message->msg_iov[i].iov_len;
569 			message->msg_iov[i].iov_len = 0;
570 		}
571 	}
572 
573 	return total_len;
574 }
575 #endif /* !defined(CONFIG_NET_TEST) */
576 
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)577 static int websocket_prepare_and_send(struct websocket_context *ctx,
578 				      uint8_t *header, size_t header_len,
579 				      uint8_t *payload, size_t payload_len,
580 				      int32_t timeout)
581 {
582 	struct iovec io_vector[2];
583 	struct msghdr msg;
584 
585 	io_vector[0].iov_base = header;
586 	io_vector[0].iov_len = header_len;
587 	io_vector[1].iov_base = payload;
588 	io_vector[1].iov_len = payload_len;
589 
590 	memset(&msg, 0, sizeof(msg));
591 
592 	msg.msg_iov = io_vector;
593 	msg.msg_iovlen = ARRAY_SIZE(io_vector);
594 
595 	if (HEXDUMP_SENT_PACKETS) {
596 		LOG_HEXDUMP_DBG(header, header_len, "Header");
597 		if ((payload != NULL) && (payload_len > 0)) {
598 			LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
599 		} else {
600 			LOG_DBG("No payload");
601 		}
602 	}
603 
604 #if defined(CONFIG_NET_TEST)
605 	/* Simulate a case where the payload is split to two. The unit test
606 	 * does not set mask bit in this case.
607 	 */
608 	return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
609 #else
610 	k_timeout_t tout = K_FOREVER;
611 
612 	if (timeout != SYS_FOREVER_MS) {
613 		tout = K_MSEC(timeout);
614 	}
615 
616 	return sendmsg_all(ctx->real_sock, &msg,
617 			   K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
618 #endif /* CONFIG_NET_TEST */
619 }
620 
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)621 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
622 		       enum websocket_opcode opcode, bool mask, bool final,
623 		       int32_t timeout)
624 {
625 	struct websocket_context *ctx;
626 	uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
627 	uint8_t *data_to_send = (uint8_t *)payload;
628 	int ret;
629 
630 	if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
631 	    opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
632 	    opcode != WEBSOCKET_OPCODE_CONTINUE &&
633 	    opcode != WEBSOCKET_OPCODE_CLOSE &&
634 	    opcode != WEBSOCKET_OPCODE_PING &&
635 	    opcode != WEBSOCKET_OPCODE_PONG) {
636 		return -EINVAL;
637 	}
638 
639 	ctx = z_get_fd_obj(ws_sock, NULL, 0);
640 	if (ctx == NULL) {
641 		return -EBADF;
642 	}
643 
644 #if !defined(CONFIG_NET_TEST)
645 	/* Websocket unit test does not use context from pool but allocates
646 	 * its own, hence skip the check.
647 	 */
648 
649 	if (!PART_OF_ARRAY(contexts, ctx)) {
650 		return -ENOENT;
651 	}
652 #endif /* !defined(CONFIG_NET_TEST) */
653 
654 	NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
655 		mask, final ? "final" : "more");
656 
657 	memset(header, 0, sizeof(header));
658 
659 	/* Is this the last packet? */
660 	header[0] = final ? BIT(7) : 0;
661 
662 	/* Text, binary, ping, pong or close ? */
663 	header[0] |= opcode;
664 
665 	/* Masking */
666 	header[1] = mask ? BIT(7) : 0;
667 
668 	if (payload_len < 126) {
669 		header[1] |= payload_len;
670 	} else if (payload_len < 65536) {
671 		header[1] |= 126;
672 		header[2] = payload_len >> 8;
673 		header[3] = payload_len;
674 		hdr_len += 2;
675 	} else {
676 		header[1] |= 127;
677 		header[2] = 0;
678 		header[3] = 0;
679 		header[4] = 0;
680 		header[5] = 0;
681 		header[6] = payload_len >> 24;
682 		header[7] = payload_len >> 16;
683 		header[8] = payload_len >> 8;
684 		header[9] = payload_len;
685 		hdr_len += 8;
686 	}
687 
688 	/* Add masking value if needed */
689 	if (mask) {
690 		int i;
691 
692 		ctx->masking_value = sys_rand32_get();
693 
694 		header[hdr_len++] |= ctx->masking_value >> 24;
695 		header[hdr_len++] |= ctx->masking_value >> 16;
696 		header[hdr_len++] |= ctx->masking_value >> 8;
697 		header[hdr_len++] |= ctx->masking_value;
698 
699 		if ((payload != NULL) && (payload_len > 0)) {
700 			data_to_send = k_malloc(payload_len);
701 			if (!data_to_send) {
702 				return -ENOMEM;
703 			}
704 
705 			memcpy(data_to_send, payload, payload_len);
706 
707 			for (i = 0; i < payload_len; i++) {
708 				data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
709 			}
710 		}
711 	}
712 
713 	ret = websocket_prepare_and_send(ctx, header, hdr_len,
714 					 data_to_send, payload_len, timeout);
715 	if (ret < 0) {
716 		NET_DBG("Cannot send ws msg (%d)", -errno);
717 		goto quit;
718 	}
719 
720 quit:
721 	if (data_to_send != payload) {
722 		k_free(data_to_send);
723 	}
724 
725 	/* Do no math with 0 and error codes */
726 	if (ret <= 0) {
727 		return ret;
728 	}
729 
730 	return ret - hdr_len;
731 }
732 
websocket_opcode2flag(uint8_t data)733 static uint32_t websocket_opcode2flag(uint8_t data)
734 {
735 	switch (data & 0x0f) {
736 	case WEBSOCKET_OPCODE_DATA_TEXT:
737 		return WEBSOCKET_FLAG_TEXT;
738 	case WEBSOCKET_OPCODE_DATA_BINARY:
739 		return WEBSOCKET_FLAG_BINARY;
740 	case WEBSOCKET_OPCODE_CLOSE:
741 		return WEBSOCKET_FLAG_CLOSE;
742 	case WEBSOCKET_OPCODE_PING:
743 		return WEBSOCKET_FLAG_PING;
744 	case WEBSOCKET_OPCODE_PONG:
745 		return WEBSOCKET_FLAG_PONG;
746 	default:
747 		break;
748 	}
749 	return 0;
750 }
751 
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)752 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
753 {
754 	int len;
755 	uint8_t data;
756 	size_t parsed_count = 0;
757 
758 	do {
759 		if (parsed_count >= ctx->recv_buf.count) {
760 			return parsed_count;
761 		}
762 		if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
763 			data = ctx->recv_buf.buf[parsed_count++];
764 
765 			switch (ctx->parser_state) {
766 			case WEBSOCKET_PARSER_STATE_OPCODE:
767 				ctx->message_type = websocket_opcode2flag(data);
768 				if ((data & 0x80) != 0) {
769 					ctx->message_type |= WEBSOCKET_FLAG_FINAL;
770 				}
771 				ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
772 				break;
773 			case WEBSOCKET_PARSER_STATE_LENGTH:
774 				ctx->masked = (data & 0x80) != 0;
775 				len = data & 0x7f;
776 				if (len < 126) {
777 					ctx->message_len = len;
778 					if (ctx->masked) {
779 						ctx->masking_value = 0;
780 						ctx->parser_remaining = 4;
781 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
782 					} else {
783 						ctx->parser_remaining = ctx->message_len;
784 						ctx->parser_state =
785 							(ctx->parser_remaining == 0)
786 								? WEBSOCKET_PARSER_STATE_OPCODE
787 								: WEBSOCKET_PARSER_STATE_PAYLOAD;
788 					}
789 				} else {
790 					ctx->message_len = 0;
791 					ctx->parser_remaining = (len < 127) ? 2 : 8;
792 					ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
793 				}
794 				break;
795 			case WEBSOCKET_PARSER_STATE_EXT_LEN:
796 				ctx->parser_remaining--;
797 				ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
798 				if (ctx->parser_remaining == 0) {
799 					if (ctx->masked) {
800 						ctx->masking_value = 0;
801 						ctx->parser_remaining = 4;
802 						ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
803 					} else {
804 						ctx->parser_remaining = ctx->message_len;
805 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
806 					}
807 				}
808 				break;
809 			case WEBSOCKET_PARSER_STATE_MASK:
810 				ctx->parser_remaining--;
811 				ctx->masking_value |= (data << (ctx->parser_remaining * 8));
812 				if (ctx->parser_remaining == 0) {
813 					if (ctx->message_len == 0) {
814 						ctx->parser_remaining = 0;
815 						ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
816 					} else {
817 						ctx->parser_remaining = ctx->message_len;
818 						ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
819 					}
820 				}
821 				break;
822 			default:
823 				return -EFAULT;
824 			}
825 
826 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
827 			if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
828 			    ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
829 			     (ctx->message_len == 0))) {
830 				NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
831 					ctx->masked ? "" : "un",
832 					ctx->masked ? ctx->masking_value : 0, ctx->message_type,
833 					(size_t)ctx->message_len);
834 			}
835 #endif
836 		} else {
837 			size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
838 			size_t payload_in_recv_buf =
839 				MIN(remaining_in_recv_buf, ctx->parser_remaining);
840 			size_t free_in_payload_buf = payload->size - payload->count;
841 			size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
842 
843 			if (free_in_payload_buf == 0) {
844 				break;
845 			}
846 
847 			memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
848 			       ready_to_copy);
849 			parsed_count += ready_to_copy;
850 			payload->count += ready_to_copy;
851 			ctx->parser_remaining -= ready_to_copy;
852 			if (ctx->parser_remaining == 0) {
853 				ctx->parser_remaining = 0;
854 				ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
855 			}
856 		}
857 
858 	} while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
859 
860 	return parsed_count;
861 }
862 
863 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)864 static int wait_rx(int sock, int timeout)
865 {
866 	struct zsock_pollfd fds = {
867 		.fd = sock,
868 		.events = ZSOCK_POLLIN,
869 	};
870 	int ret;
871 
872 	ret = zsock_poll(&fds, 1, timeout);
873 	if (ret < 0) {
874 		return ret;
875 	}
876 
877 	if (ret == 0) {
878 		/* Timeout */
879 		return -EAGAIN;
880 	}
881 
882 	if (fds.revents & ZSOCK_POLLNVAL) {
883 		return -EBADF;
884 	}
885 
886 	if (fds.revents & ZSOCK_POLLERR) {
887 		return -EIO;
888 	}
889 
890 	return 0;
891 }
892 
timeout_to_ms(k_timeout_t * timeout)893 static int timeout_to_ms(k_timeout_t *timeout)
894 {
895 	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
896 		return 0;
897 	} else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
898 		return SYS_FOREVER_MS;
899 	} else {
900 		return k_ticks_to_ms_floor32(timeout->ticks);
901 	}
902 }
903 
904 #endif /* !defined(CONFIG_NET_TEST) */
905 
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)906 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
907 		       uint32_t *message_type, uint64_t *remaining, int32_t timeout)
908 {
909 	struct websocket_context *ctx;
910 	int ret;
911 	k_timepoint_t end;
912 	k_timeout_t tout = K_FOREVER;
913 	struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
914 
915 	if (timeout != SYS_FOREVER_MS) {
916 		tout = K_MSEC(timeout);
917 	}
918 
919 	if ((buf == NULL) || (buf_len == 0)) {
920 		return -EINVAL;
921 	}
922 
923 	end = sys_timepoint_calc(tout);
924 
925 #if defined(CONFIG_NET_TEST)
926 	struct test_data *test_data = z_get_fd_obj(ws_sock, NULL, 0);
927 
928 	if (test_data == NULL) {
929 		return -EBADF;
930 	}
931 
932 	ctx = test_data->ctx;
933 #else
934 	ctx = z_get_fd_obj(ws_sock, NULL, 0);
935 	if (ctx == NULL) {
936 		return -EBADF;
937 	}
938 
939 	if (!PART_OF_ARRAY(contexts, ctx)) {
940 		return -ENOENT;
941 	}
942 #endif /* CONFIG_NET_TEST */
943 
944 	do {
945 		size_t parsed_count;
946 
947 		if (ctx->recv_buf.count == 0) {
948 #if defined(CONFIG_NET_TEST)
949 			size_t input_len = MIN(ctx->recv_buf.size,
950 					       test_data->input_len - test_data->input_pos);
951 
952 			if (input_len > 0) {
953 				memcpy(ctx->recv_buf.buf,
954 				       &test_data->input_buf[test_data->input_pos], input_len);
955 				test_data->input_pos += input_len;
956 				ret = input_len;
957 			} else {
958 				/* emulate timeout */
959 				ret = -EAGAIN;
960 			}
961 #else
962 			tout = sys_timepoint_timeout(end);
963 
964 			ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
965 			if (ret == 0) {
966 				ret = recv(ctx->real_sock, ctx->recv_buf.buf,
967 					   ctx->recv_buf.size, MSG_DONTWAIT);
968 				if (ret < 0) {
969 					ret = -errno;
970 				}
971 			}
972 #endif /* CONFIG_NET_TEST */
973 
974 			if (ret < 0) {
975 				if ((ret == -EAGAIN) && (payload.count > 0)) {
976 					/* go to unmasking */
977 					break;
978 				}
979 				return ret;
980 			}
981 
982 			if (ret == 0) {
983 				/* Socket closed */
984 				return -ENOTCONN;
985 			}
986 
987 			ctx->recv_buf.count = ret;
988 
989 			NET_DBG("[%p] Received %d bytes", ctx, ret);
990 		}
991 
992 		ret = websocket_parse(ctx, &payload);
993 		if (ret < 0) {
994 			return ret;
995 		}
996 		parsed_count = ret;
997 
998 		if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
999 		    (payload.count >= payload.size)) {
1000 			if (remaining != NULL) {
1001 				*remaining = ctx->parser_remaining;
1002 			}
1003 			if (message_type != NULL) {
1004 				*message_type = ctx->message_type;
1005 			}
1006 
1007 			size_t left = ctx->recv_buf.count - parsed_count;
1008 
1009 			if (left > 0) {
1010 				memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1011 			}
1012 			ctx->recv_buf.count = left;
1013 			break;
1014 		}
1015 
1016 		ctx->recv_buf.count -= parsed_count;
1017 
1018 	} while (true);
1019 
1020 	/* Unmask the data */
1021 	if (ctx->masked) {
1022 		uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1023 		size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1024 
1025 		for (size_t i = 0; i < payload.count; i++) {
1026 			size_t m = data_buf_offset % 4;
1027 
1028 			payload.buf[i] ^= mask_as_bytes[3 - m];
1029 			data_buf_offset++;
1030 		}
1031 	}
1032 
1033 	return payload.count;
1034 }
1035 
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1036 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1037 			  size_t buf_len, int32_t timeout)
1038 {
1039 	int ret;
1040 
1041 	NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1042 
1043 	ret = websocket_send_msg(ctx->sock, buf, buf_len,
1044 				 WEBSOCKET_OPCODE_DATA_TEXT,
1045 				 true, true, timeout);
1046 	if (ret < 0) {
1047 		errno = -ret;
1048 		return -1;
1049 	}
1050 
1051 	NET_DBG("[%p] Sent %d bytes", ctx, ret);
1052 
1053 	return ret;
1054 }
1055 
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1056 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1057 			  size_t buf_len, int32_t timeout)
1058 {
1059 	uint32_t message_type;
1060 	uint64_t remaining;
1061 	int ret;
1062 
1063 	NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1064 
1065 	/* TODO: add support for recvmsg() so that we could return the
1066 	 *       websocket specific information in ancillary data.
1067 	 */
1068 	ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1069 				 &remaining, timeout);
1070 	if (ret < 0) {
1071 		if (ret == -ENOTCONN) {
1072 			ret = 0;
1073 		} else {
1074 			errno = -ret;
1075 			return -1;
1076 		}
1077 	}
1078 
1079 	NET_DBG("[%p] Received %d bytes", ctx, ret);
1080 
1081 	return ret;
1082 }
1083 
websocket_read_vmeth(void * obj,void * buffer,size_t count)1084 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1085 {
1086 	return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1087 }
1088 
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1089 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1090 				     size_t count)
1091 {
1092 	return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1093 }
1094 
websocket_sendto_ctx(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)1095 static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
1096 				    int flags,
1097 				    const struct sockaddr *dest_addr,
1098 				    socklen_t addrlen)
1099 {
1100 	struct websocket_context *ctx = obj;
1101 	int32_t timeout = SYS_FOREVER_MS;
1102 
1103 	if (flags & ZSOCK_MSG_DONTWAIT) {
1104 		timeout = 0;
1105 	}
1106 
1107 	ARG_UNUSED(dest_addr);
1108 	ARG_UNUSED(addrlen);
1109 
1110 	return (ssize_t)websocket_send(ctx, buf, len, timeout);
1111 }
1112 
websocket_recvfrom_ctx(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)1113 static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
1114 				      int flags, struct sockaddr *src_addr,
1115 				      socklen_t *addrlen)
1116 {
1117 	struct websocket_context *ctx = obj;
1118 	int32_t timeout = SYS_FOREVER_MS;
1119 
1120 	if (flags & ZSOCK_MSG_DONTWAIT) {
1121 		timeout = 0;
1122 	}
1123 
1124 	ARG_UNUSED(src_addr);
1125 	ARG_UNUSED(addrlen);
1126 
1127 	return (ssize_t)websocket_recv(ctx, buf, max_len, timeout);
1128 }
1129 
1130 static const struct socket_op_vtable websocket_fd_op_vtable = {
1131 	.fd_vtable = {
1132 		.read = websocket_read_vmeth,
1133 		.write = websocket_write_vmeth,
1134 		.close = websocket_close_vmeth,
1135 		.ioctl = websocket_ioctl_vmeth,
1136 	},
1137 	.sendto = websocket_sendto_ctx,
1138 	.recvfrom = websocket_recvfrom_ctx,
1139 };
1140 
websocket_context_foreach(websocket_context_cb_t cb,void * user_data)1141 void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
1142 {
1143 	int i;
1144 
1145 	k_sem_take(&contexts_lock, K_FOREVER);
1146 
1147 	for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1148 		if (!websocket_context_is_used(&contexts[i])) {
1149 			continue;
1150 		}
1151 
1152 		k_mutex_lock(&contexts[i].lock, K_FOREVER);
1153 
1154 		cb(&contexts[i], user_data);
1155 
1156 		k_mutex_unlock(&contexts[i].lock);
1157 	}
1158 
1159 	k_sem_give(&contexts_lock);
1160 }
1161 
websocket_init(void)1162 void websocket_init(void)
1163 {
1164 	k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1165 }
1166