1 /** @file
2 * @brief Websocket client API
3 *
 * An API for applications to set up websocket connections.
5 */
6
7 /*
8 * Copyright (c) 2019 Intel Corporation
9 *
10 * SPDX-License-Identifier: Apache-2.0
11 */
12
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42
/* If you want to see the data that is being sent or received,
 * enable debug logging and set the following macros to 1.
 * This prints a lot of data, so it is not enabled by default.
 */
47 #define HEXDUMP_SENT_PACKETS 0
48 #define HEXDUMP_RECV_PACKETS 0
49
/* Pool of websocket contexts. A slot counts as free while its refcount
 * is zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Semaphore taken around every scan of the contexts[] pool. */
static struct k_sem contexts_lock;

/* Forward declaration; the table itself is defined further down, after the
 * socket methods it refers to.
 */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Implemented outside this file. When CONFIG_NET_TEST is set, outgoing
 * messages are handed to this function instead of the real socket.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 switch (opcode) {
63 case WEBSOCKET_OPCODE_DATA_TEXT:
64 return "TEXT";
65 case WEBSOCKET_OPCODE_DATA_BINARY:
66 return "BIN";
67 case WEBSOCKET_OPCODE_CONTINUE:
68 return "CONT";
69 case WEBSOCKET_OPCODE_CLOSE:
70 return "CLOSE";
71 case WEBSOCKET_OPCODE_PING:
72 return "PING";
73 case WEBSOCKET_OPCODE_PONG:
74 return "PONG";
75 default:
76 break;
77 }
78
79 return NULL;
80 }
81
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 int old_rc = atomic_inc(&ctx->refcount);
85
86 return old_rc + 1;
87 }
88
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 int old_rc = atomic_dec(&ctx->refcount);
92
93 if (old_rc != 1) {
94 return old_rc - 1;
95 }
96
97 return 0;
98 }
99
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 NET_ASSERT(ctx);
103
104 return !!atomic_get(&ctx->refcount);
105 }
106
websocket_get(void)107 static struct websocket_context *websocket_get(void)
108 {
109 struct websocket_context *ctx = NULL;
110 int i;
111
112 k_sem_take(&contexts_lock, K_FOREVER);
113
114 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
115 if (websocket_context_is_used(&contexts[i])) {
116 continue;
117 }
118
119 websocket_context_ref(&contexts[i]);
120 ctx = &contexts[i];
121 break;
122 }
123
124 k_sem_give(&contexts_lock);
125
126 return ctx;
127 }
128
websocket_find(int real_sock)129 static struct websocket_context *websocket_find(int real_sock)
130 {
131 struct websocket_context *ctx = NULL;
132 int i;
133
134 k_sem_take(&contexts_lock, K_FOREVER);
135
136 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
137 if (!websocket_context_is_used(&contexts[i])) {
138 continue;
139 }
140
141 if (contexts[i].real_sock != real_sock) {
142 continue;
143 }
144
145 ctx = &contexts[i];
146 break;
147 }
148
149 k_sem_give(&contexts_lock);
150
151 return ctx;
152 }
153
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)154 static void response_cb(struct http_response *rsp,
155 enum http_final_call final_data,
156 void *user_data)
157 {
158 struct websocket_context *ctx = user_data;
159
160 if (final_data == HTTP_DATA_MORE) {
161 NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
162 rsp->data_len);
163 ctx->all_received = false;
164 } else if (final_data == HTTP_DATA_FINAL) {
165 NET_DBG("[%p] All the data received (%zd bytes)", ctx,
166 rsp->data_len);
167 ctx->all_received = true;
168 }
169 }
170
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 size_t length)
173 {
174 struct http_request *req = CONTAINER_OF(parser,
175 struct http_request,
176 internal.parser);
177 struct websocket_context *ctx = req->internal.user_data;
178 const char *ws_accept_str = "Sec-WebSocket-Accept";
179 uint16_t len;
180
181 len = strlen(ws_accept_str);
182 if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 ctx->sec_accept_present = true;
184 }
185
186 if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 ctx->http_cb->on_header_field(parser, at, length);
188 }
189
190 return 0;
191 }
192
193 #define MAX_SEC_ACCEPT_LEN 32
194
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 size_t length)
197 {
198 struct http_request *req = CONTAINER_OF(parser,
199 struct http_request,
200 internal.parser);
201 struct websocket_context *ctx = req->internal.user_data;
202 char str[MAX_SEC_ACCEPT_LEN];
203
204 if (ctx->sec_accept_present) {
205 int ret;
206 size_t olen;
207
208 ctx->sec_accept_ok = false;
209 ctx->sec_accept_present = false;
210
211 ret = base64_encode(str, sizeof(str) - 1, &olen,
212 ctx->sec_accept_key,
213 WS_SHA1_OUTPUT_LEN);
214 if (ret == 0) {
215 if (strncmp(at, str, length)) {
216 NET_DBG("[%p] Security keys do not match "
217 "%s vs %s", ctx, str, at);
218 } else {
219 ctx->sec_accept_ok = true;
220 }
221 }
222 }
223
224 if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 ctx->http_cb->on_header_value(parser, at, length);
226 }
227
228 return 0;
229 }
230
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 int32_t timeout, void *user_data)
233 {
234 /* This is the expected Sec-WebSocket-Accept key. We are storing a
235 * pointer to this in ctx but the value is only used for the duration
236 * of this function call so there is no issue even if this variable
237 * is allocated from stack.
238 */
239 uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 struct http_parser_settings http_parser_settings;
241 struct websocket_context *ctx;
242 struct http_request req;
243 int ret, fd, key_len;
244 size_t olen;
245 char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 uint32_t rnd_value = sys_rand32_get();
247 char sec_ws_key[] =
248 "Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 char *headers[] = {
250 sec_ws_key,
251 "Upgrade: websocket\r\n",
252 "Connection: Upgrade\r\n",
253 "Sec-WebSocket-Version: 13\r\n",
254 NULL
255 };
256
257 fd = -1;
258
259 if (sock < 0 || wreq == NULL || wreq->host == NULL ||
260 wreq->url == NULL) {
261 return -EINVAL;
262 }
263
264 ctx = websocket_find(sock);
265 if (ctx) {
266 NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
267 sock);
268 return -EEXIST;
269 }
270
271 ctx = websocket_get();
272 if (!ctx) {
273 return -ENOENT;
274 }
275
276 ctx->real_sock = sock;
277 ctx->recv_buf.buf = wreq->tmp_buf;
278 ctx->recv_buf.size = wreq->tmp_buf_len;
279 ctx->sec_accept_key = sec_accept_key;
280 ctx->http_cb = wreq->http_cb;
281
282 mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
283 sec_accept_key);
284
285 ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
286 sizeof(sec_ws_key) -
287 sizeof("Sec-Websocket-Key: "),
288 &olen, sec_accept_key,
289 /* We are only interested in 16 first bytes so
290 * subtract 4 from the SHA-1 length
291 */
292 sizeof(sec_accept_key) - 4);
293 if (ret) {
294 NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
295 goto out;
296 }
297
298 if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
299 NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
300 olen + sizeof("Sec-Websocket-Key: ") + 2,
301 sizeof(sec_ws_key));
302 ret = -EMSGSIZE;
303 goto out;
304 }
305
306 memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
307 HTTP_CRLF, sizeof(HTTP_CRLF));
308
309 memset(&req, 0, sizeof(req));
310
311 req.method = HTTP_GET;
312 req.url = wreq->url;
313 req.host = wreq->host;
314 req.protocol = "HTTP/1.1";
315 req.header_fields = (const char **)headers;
316 req.optional_headers_cb = wreq->optional_headers_cb;
317 req.optional_headers = wreq->optional_headers;
318 req.response = response_cb;
319 req.http_cb = &http_parser_settings;
320 req.recv_buf = wreq->tmp_buf;
321 req.recv_buf_len = wreq->tmp_buf_len;
322
323 /* We need to catch the Sec-WebSocket-Accept field in order to verify
324 * that it contains the stuff that we sent in Sec-WebSocket-Key field
325 * so setup HTTP callbacks so that we will get the needed fields.
326 */
327 if (ctx->http_cb) {
328 memcpy(&http_parser_settings, ctx->http_cb,
329 sizeof(http_parser_settings));
330 } else {
331 memset(&http_parser_settings, 0, sizeof(http_parser_settings));
332 }
333
334 http_parser_settings.on_header_field = on_header_field;
335 http_parser_settings.on_header_value = on_header_value;
336
337 /* Pre-calculate the expected Sec-Websocket-Accept field */
338 key_len = MIN(sizeof(key_accept) - 1, olen);
339 strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
340 key_len);
341
342 olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
343 strncpy(key_accept + key_len, WS_MAGIC, olen);
344
345 /* This SHA-1 value is then checked when we receive the response */
346 mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
347
348 ret = http_client_req(sock, &req, timeout, ctx);
349 if (ret < 0) {
350 NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
351 wreq->host);
352 ret = -ECONNABORTED;
353 goto out;
354 }
355
356 if (!(ctx->all_received && ctx->sec_accept_ok)) {
357 NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
358 ctx->all_received, ctx->sec_accept_ok);
359 ret = -ECONNABORTED;
360 goto out;
361 }
362
363 ctx->user_data = user_data;
364
365 fd = zvfs_reserve_fd();
366 if (fd < 0) {
367 ret = -ENOSPC;
368 goto out;
369 }
370
371 ctx->sock = fd;
372 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
373 ZVFS_MODE_IFSOCK);
374
375 /* Call the user specified callback and if it accepts the connection
376 * then continue.
377 */
378 if (wreq->cb) {
379 ret = wreq->cb(fd, &req, user_data);
380 if (ret < 0) {
381 NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
382 goto out;
383 }
384 }
385
386 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
387
388 /* We will re-use the temp buffer in receive function if needed but
389 * in order that to work the amount of data in buffer must be set to 0
390 */
391 ctx->recv_buf.count = 0;
392
393 /* Init parser FSM */
394 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
395
396 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
397
398 return fd;
399
400 out:
401 if (fd >= 0) {
402 (void)zsock_close(fd);
403 }
404
405 websocket_context_unref(ctx);
406 return ret;
407 }
408
/* Close a websocket descriptor.
 *
 * Thin wrapper: the result of zsock_close() on the websocket fd is
 * returned unchanged.
 */
int websocket_disconnect(int ws_sock)
{
	int status = zsock_close(ws_sock);

	return status;
}
413
websocket_interal_disconnect(struct websocket_context * ctx)414 static int websocket_interal_disconnect(struct websocket_context *ctx)
415 {
416 int ret;
417
418 if (ctx == NULL) {
419 return -ENOENT;
420 }
421
422 NET_DBG("[%p] Disconnecting", ctx);
423
424 ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
425 true, true, SYS_FOREVER_MS);
426 if (ret < 0) {
427 NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
428 }
429
430 (void)sock_obj_core_dealloc(ctx->sock);
431
432 websocket_context_unref(ctx);
433
434 return ret;
435 }
436
/* close() method of the websocket fd vtable.
 *
 * Disconnects the context. -ENOTCONN from the disconnect is treated as
 * success; any other negative result is reported as -1 with errno set.
 */
static int websocket_close_vmeth(void *obj)
{
	struct websocket_context *ctx = obj;
	int status = websocket_interal_disconnect(ctx);

	if (status >= 0) {
		return status;
	}

	/* Ignore error if we are not connected */
	if (status == -ENOTCONN) {
		return 0;
	}

	NET_DBG("[%p] Cannot close (%d)", obj, status);

	errno = -status;
	return -1;
}
457
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)458 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
459 int timeout)
460 {
461 int fd_backup[CONFIG_NET_SOCKETS_POLL_MAX];
462 const struct fd_op_vtable *vtable;
463 void *ctx;
464 int ret = 0;
465 int i;
466
467 /* Overwrite websocket file descriptors with underlying ones. */
468 for (i = 0; i < nfds; i++) {
469 fd_backup[i] = fds[i].fd;
470
471 ctx = zvfs_get_fd_obj(fds[i].fd,
472 (const struct fd_op_vtable *)
473 &websocket_fd_op_vtable,
474 0);
475 if (ctx == NULL) {
476 continue;
477 }
478
479 fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
480 }
481
482 /* Get offloaded sockets vtable. */
483 ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
484 (const struct fd_op_vtable **)&vtable,
485 NULL);
486 if (ctx == NULL) {
487 errno = EINVAL;
488 ret = -1;
489 goto exit;
490 }
491
492 ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
493 fds, nfds, timeout);
494
495 exit:
496 /* Restore original fds. */
497 for (i = 0; i < nfds; i++) {
498 fds[i].fd = fd_backup[i];
499 }
500
501 return ret;
502 }
503
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)504 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
505 {
506 struct websocket_context *ctx = obj;
507
508 switch (request) {
509 case ZFD_IOCTL_POLL_OFFLOAD: {
510 struct zsock_pollfd *fds;
511 int nfds;
512 int timeout;
513
514 fds = va_arg(args, struct zsock_pollfd *);
515 nfds = va_arg(args, int);
516 timeout = va_arg(args, int);
517
518 return websocket_poll_offload(fds, nfds, timeout);
519 }
520
521 case ZFD_IOCTL_SET_LOCK:
522 /* Ignore, don't want to overwrite underlying socket lock. */
523 return 0;
524
525 default: {
526 const struct fd_op_vtable *vtable;
527 void *core_obj;
528
529 core_obj = zvfs_get_fd_obj_and_vtable(
530 ctx->real_sock,
531 (const struct fd_op_vtable **)&vtable,
532 NULL);
533 if (core_obj == NULL) {
534 errno = EBADF;
535 return -1;
536 }
537
538 /* Pass the call to the core socket implementation. */
539 return vtable->ioctl(core_obj, request, args);
540 }
541 }
542
543 return 0;
544 }
545
546 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags,const k_timepoint_t req_end_timepoint)547 static int sendmsg_all(int sock, const struct msghdr *message, int flags,
548 const k_timepoint_t req_end_timepoint)
549 {
550 int ret, i;
551 size_t offset = 0;
552 size_t total_len = 0;
553
554 for (i = 0; i < message->msg_iovlen; i++) {
555 total_len += message->msg_iov[i].iov_len;
556 }
557
558 while (offset < total_len) {
559 ret = zsock_sendmsg(sock, message, flags);
560
561 if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
562 struct zsock_pollfd pfd;
563 int pollres;
564 k_ticks_t req_timeout_ticks =
565 sys_timepoint_timeout(req_end_timepoint).ticks;
566 int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
567
568 pfd.fd = sock;
569 pfd.events = ZSOCK_POLLOUT;
570 pollres = zsock_poll(&pfd, 1, req_timeout_ms);
571 if (pollres == 0) {
572 return -ETIMEDOUT;
573 } else if (pollres > 0) {
574 continue;
575 } else {
576 return -errno;
577 }
578 } else if (ret < 0) {
579 return -errno;
580 }
581
582 offset += ret;
583 if (offset >= total_len) {
584 break;
585 }
586
587 /* Update msghdr for the next iteration. */
588 for (i = 0; i < message->msg_iovlen; i++) {
589 if (ret < message->msg_iov[i].iov_len) {
590 message->msg_iov[i].iov_len -= ret;
591 message->msg_iov[i].iov_base =
592 (uint8_t *)message->msg_iov[i].iov_base + ret;
593 break;
594 }
595
596 ret -= message->msg_iov[i].iov_len;
597 message->msg_iov[i].iov_len = 0;
598 }
599 }
600
601 return total_len;
602 }
603 #endif /* !defined(CONFIG_NET_TEST) */
604
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)605 static int websocket_prepare_and_send(struct websocket_context *ctx,
606 uint8_t *header, size_t header_len,
607 uint8_t *payload, size_t payload_len,
608 int32_t timeout)
609 {
610 struct iovec io_vector[2];
611 struct msghdr msg;
612
613 io_vector[0].iov_base = header;
614 io_vector[0].iov_len = header_len;
615 io_vector[1].iov_base = payload;
616 io_vector[1].iov_len = payload_len;
617
618 memset(&msg, 0, sizeof(msg));
619
620 msg.msg_iov = io_vector;
621 msg.msg_iovlen = ARRAY_SIZE(io_vector);
622
623 if (HEXDUMP_SENT_PACKETS) {
624 LOG_HEXDUMP_DBG(header, header_len, "Header");
625 if ((payload != NULL) && (payload_len > 0)) {
626 LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
627 } else {
628 LOG_DBG("No payload");
629 }
630 }
631
632 #if defined(CONFIG_NET_TEST)
633 /* Simulate a case where the payload is split to two. The unit test
634 * does not set mask bit in this case.
635 */
636 return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
637 #else
638 k_timeout_t tout = K_FOREVER;
639
640 if (timeout != SYS_FOREVER_MS) {
641 tout = K_MSEC(timeout);
642 }
643
644 k_timeout_t req_timeout = K_MSEC(timeout);
645 k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
646
647 return sendmsg_all(ctx->real_sock, &msg,
648 K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0, req_end_timepoint);
649 #endif /* CONFIG_NET_TEST */
650 }
651
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)652 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
653 enum websocket_opcode opcode, bool mask, bool final,
654 int32_t timeout)
655 {
656 struct websocket_context *ctx;
657 uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
658 uint8_t *data_to_send = (uint8_t *)payload;
659 int ret;
660
661 if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
662 opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
663 opcode != WEBSOCKET_OPCODE_CONTINUE &&
664 opcode != WEBSOCKET_OPCODE_CLOSE &&
665 opcode != WEBSOCKET_OPCODE_PING &&
666 opcode != WEBSOCKET_OPCODE_PONG) {
667 return -EINVAL;
668 }
669
670 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
671 if (ctx == NULL) {
672 return -EBADF;
673 }
674
675 #if !defined(CONFIG_NET_TEST)
676 /* Websocket unit test does not use context from pool but allocates
677 * its own, hence skip the check.
678 */
679
680 if (!PART_OF_ARRAY(contexts, ctx)) {
681 return -ENOENT;
682 }
683 #endif /* !defined(CONFIG_NET_TEST) */
684
685 NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
686 mask, final ? "final" : "more");
687
688 memset(header, 0, sizeof(header));
689
690 /* Is this the last packet? */
691 header[0] = final ? BIT(7) : 0;
692
693 /* Text, binary, ping, pong or close ? */
694 header[0] |= opcode;
695
696 /* Masking */
697 header[1] = mask ? BIT(7) : 0;
698
699 if (payload_len < 126) {
700 header[1] |= payload_len;
701 } else if (payload_len < 65536) {
702 header[1] |= 126;
703 header[2] = payload_len >> 8;
704 header[3] = payload_len;
705 hdr_len += 2;
706 } else {
707 header[1] |= 127;
708 header[2] = 0;
709 header[3] = 0;
710 header[4] = 0;
711 header[5] = 0;
712 header[6] = payload_len >> 24;
713 header[7] = payload_len >> 16;
714 header[8] = payload_len >> 8;
715 header[9] = payload_len;
716 hdr_len += 8;
717 }
718
719 /* Add masking value if needed */
720 if (mask) {
721 int i;
722
723 ctx->masking_value = sys_rand32_get();
724
725 header[hdr_len++] |= ctx->masking_value >> 24;
726 header[hdr_len++] |= ctx->masking_value >> 16;
727 header[hdr_len++] |= ctx->masking_value >> 8;
728 header[hdr_len++] |= ctx->masking_value;
729
730 if ((payload != NULL) && (payload_len > 0)) {
731 data_to_send = k_malloc(payload_len);
732 if (!data_to_send) {
733 return -ENOMEM;
734 }
735
736 memcpy(data_to_send, payload, payload_len);
737
738 for (i = 0; i < payload_len; i++) {
739 data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
740 }
741 }
742 }
743
744 ret = websocket_prepare_and_send(ctx, header, hdr_len,
745 data_to_send, payload_len, timeout);
746 if (ret < 0) {
747 NET_DBG("Cannot send ws msg (%d)", -errno);
748 goto quit;
749 }
750
751 quit:
752 if (data_to_send != payload) {
753 k_free(data_to_send);
754 }
755
756 /* Do no math with 0 and error codes */
757 if (ret <= 0) {
758 return ret;
759 }
760
761 return ret - hdr_len;
762 }
763
websocket_opcode2flag(uint8_t data)764 static uint32_t websocket_opcode2flag(uint8_t data)
765 {
766 switch (data & 0x0f) {
767 case WEBSOCKET_OPCODE_DATA_TEXT:
768 return WEBSOCKET_FLAG_TEXT;
769 case WEBSOCKET_OPCODE_DATA_BINARY:
770 return WEBSOCKET_FLAG_BINARY;
771 case WEBSOCKET_OPCODE_CLOSE:
772 return WEBSOCKET_FLAG_CLOSE;
773 case WEBSOCKET_OPCODE_PING:
774 return WEBSOCKET_FLAG_PING;
775 case WEBSOCKET_OPCODE_PONG:
776 return WEBSOCKET_FLAG_PONG;
777 default:
778 break;
779 }
780 return 0;
781 }
782
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)783 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
784 {
785 int len;
786 uint8_t data;
787 size_t parsed_count = 0;
788
789 do {
790 if (parsed_count >= ctx->recv_buf.count) {
791 return parsed_count;
792 }
793 if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
794 data = ctx->recv_buf.buf[parsed_count++];
795
796 switch (ctx->parser_state) {
797 case WEBSOCKET_PARSER_STATE_OPCODE:
798 ctx->message_type = websocket_opcode2flag(data);
799 if ((data & 0x80) != 0) {
800 ctx->message_type |= WEBSOCKET_FLAG_FINAL;
801 }
802 ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
803 break;
804 case WEBSOCKET_PARSER_STATE_LENGTH:
805 ctx->masked = (data & 0x80) != 0;
806 len = data & 0x7f;
807 if (len < 126) {
808 ctx->message_len = len;
809 if (ctx->masked) {
810 ctx->masking_value = 0;
811 ctx->parser_remaining = 4;
812 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
813 } else {
814 ctx->parser_remaining = ctx->message_len;
815 ctx->parser_state =
816 (ctx->parser_remaining == 0)
817 ? WEBSOCKET_PARSER_STATE_OPCODE
818 : WEBSOCKET_PARSER_STATE_PAYLOAD;
819 }
820 } else {
821 ctx->message_len = 0;
822 ctx->parser_remaining = (len < 127) ? 2 : 8;
823 ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
824 }
825 break;
826 case WEBSOCKET_PARSER_STATE_EXT_LEN:
827 ctx->parser_remaining--;
828 ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
829 if (ctx->parser_remaining == 0) {
830 if (ctx->masked) {
831 ctx->masking_value = 0;
832 ctx->parser_remaining = 4;
833 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
834 } else {
835 ctx->parser_remaining = ctx->message_len;
836 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
837 }
838 }
839 break;
840 case WEBSOCKET_PARSER_STATE_MASK:
841 ctx->parser_remaining--;
842 ctx->masking_value |= (data << (ctx->parser_remaining * 8));
843 if (ctx->parser_remaining == 0) {
844 if (ctx->message_len == 0) {
845 ctx->parser_remaining = 0;
846 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
847 } else {
848 ctx->parser_remaining = ctx->message_len;
849 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
850 }
851 }
852 break;
853 default:
854 return -EFAULT;
855 }
856
857 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
858 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
859 ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
860 (ctx->message_len == 0))) {
861 NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
862 ctx->masked ? "" : "un",
863 ctx->masked ? ctx->masking_value : 0, ctx->message_type,
864 (size_t)ctx->message_len);
865 }
866 #endif
867 } else {
868 size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
869 size_t payload_in_recv_buf =
870 MIN(remaining_in_recv_buf, ctx->parser_remaining);
871 size_t free_in_payload_buf = payload->size - payload->count;
872 size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
873
874 if (free_in_payload_buf == 0) {
875 break;
876 }
877
878 memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
879 ready_to_copy);
880 parsed_count += ready_to_copy;
881 payload->count += ready_to_copy;
882 ctx->parser_remaining -= ready_to_copy;
883 if (ctx->parser_remaining == 0) {
884 ctx->parser_remaining = 0;
885 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
886 }
887 }
888
889 } while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
890
891 return parsed_count;
892 }
893
894 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)895 static int wait_rx(int sock, int timeout)
896 {
897 struct zsock_pollfd fds = {
898 .fd = sock,
899 .events = ZSOCK_POLLIN,
900 };
901 int ret;
902
903 ret = zsock_poll(&fds, 1, timeout);
904 if (ret < 0) {
905 return ret;
906 }
907
908 if (ret == 0) {
909 /* Timeout */
910 return -EAGAIN;
911 }
912
913 if (fds.revents & ZSOCK_POLLNVAL) {
914 return -EBADF;
915 }
916
917 if (fds.revents & ZSOCK_POLLERR) {
918 return -EIO;
919 }
920
921 return 0;
922 }
923
/* Convert a kernel timeout into milliseconds.
 *
 * K_FOREVER maps to SYS_FOREVER_MS and K_NO_WAIT to 0; any other value is
 * converted from ticks, rounding down.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
934
935 #endif /* !defined(CONFIG_NET_TEST) */
936
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)937 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
938 uint32_t *message_type, uint64_t *remaining, int32_t timeout)
939 {
940 struct websocket_context *ctx;
941 int ret;
942 k_timepoint_t end;
943 k_timeout_t tout = K_FOREVER;
944 struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
945
946 if (timeout != SYS_FOREVER_MS) {
947 tout = K_MSEC(timeout);
948 }
949
950 if ((buf == NULL) || (buf_len == 0)) {
951 return -EINVAL;
952 }
953
954 end = sys_timepoint_calc(tout);
955
956 #if defined(CONFIG_NET_TEST)
957 struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
958
959 if (test_data == NULL) {
960 return -EBADF;
961 }
962
963 ctx = test_data->ctx;
964 #else
965 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
966 if (ctx == NULL) {
967 return -EBADF;
968 }
969
970 if (!PART_OF_ARRAY(contexts, ctx)) {
971 return -ENOENT;
972 }
973 #endif /* CONFIG_NET_TEST */
974
975 do {
976 size_t parsed_count;
977
978 if (ctx->recv_buf.count == 0) {
979 #if defined(CONFIG_NET_TEST)
980 size_t input_len = MIN(ctx->recv_buf.size,
981 test_data->input_len - test_data->input_pos);
982
983 if (input_len > 0) {
984 memcpy(ctx->recv_buf.buf,
985 &test_data->input_buf[test_data->input_pos], input_len);
986 test_data->input_pos += input_len;
987 ret = input_len;
988 } else {
989 /* emulate timeout */
990 ret = -EAGAIN;
991 }
992 #else
993 tout = sys_timepoint_timeout(end);
994
995 ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
996 if (ret == 0) {
997 ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
998 ctx->recv_buf.size, MSG_DONTWAIT);
999 if (ret < 0) {
1000 ret = -errno;
1001 }
1002 }
1003 #endif /* CONFIG_NET_TEST */
1004
1005 if (ret < 0) {
1006 if ((ret == -EAGAIN) && (payload.count > 0)) {
1007 /* go to unmasking */
1008 break;
1009 }
1010 return ret;
1011 }
1012
1013 if (ret == 0) {
1014 /* Socket closed */
1015 return -ENOTCONN;
1016 }
1017
1018 ctx->recv_buf.count = ret;
1019
1020 NET_DBG("[%p] Received %d bytes", ctx, ret);
1021 }
1022
1023 ret = websocket_parse(ctx, &payload);
1024 if (ret < 0) {
1025 return ret;
1026 }
1027 parsed_count = ret;
1028
1029 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1030 (payload.count >= payload.size)) {
1031 if (remaining != NULL) {
1032 *remaining = ctx->parser_remaining;
1033 }
1034 if (message_type != NULL) {
1035 *message_type = ctx->message_type;
1036 }
1037
1038 size_t left = ctx->recv_buf.count - parsed_count;
1039
1040 if (left > 0) {
1041 memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1042 }
1043 ctx->recv_buf.count = left;
1044 break;
1045 }
1046
1047 ctx->recv_buf.count -= parsed_count;
1048
1049 } while (true);
1050
1051 /* Unmask the data */
1052 if (ctx->masked) {
1053 uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1054 size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1055
1056 for (size_t i = 0; i < payload.count; i++) {
1057 size_t m = data_buf_offset % 4;
1058
1059 payload.buf[i] ^= mask_as_bytes[3 - m];
1060 data_buf_offset++;
1061 }
1062 }
1063
1064 return payload.count;
1065 }
1066
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1067 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1068 size_t buf_len, int32_t timeout)
1069 {
1070 int ret;
1071
1072 NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1073
1074 ret = websocket_send_msg(ctx->sock, buf, buf_len,
1075 WEBSOCKET_OPCODE_DATA_TEXT,
1076 true, true, timeout);
1077 if (ret < 0) {
1078 errno = -ret;
1079 return -1;
1080 }
1081
1082 NET_DBG("[%p] Sent %d bytes", ctx, ret);
1083
1084 sock_obj_core_update_send_stats(ctx->sock, ret);
1085
1086 return ret;
1087 }
1088
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1089 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1090 size_t buf_len, int32_t timeout)
1091 {
1092 uint32_t message_type;
1093 uint64_t remaining;
1094 int ret;
1095
1096 NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1097
1098 /* TODO: add support for recvmsg() so that we could return the
1099 * websocket specific information in ancillary data.
1100 */
1101 ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1102 &remaining, timeout);
1103 if (ret < 0) {
1104 if (ret == -ENOTCONN) {
1105 ret = 0;
1106 } else {
1107 errno = -ret;
1108 return -1;
1109 }
1110 }
1111
1112 NET_DBG("[%p] Received %d bytes", ctx, ret);
1113
1114 sock_obj_core_update_recv_stats(ctx->sock, ret);
1115
1116 return ret;
1117 }
1118
websocket_read_vmeth(void * obj,void * buffer,size_t count)1119 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1120 {
1121 return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1122 }
1123
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1124 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1125 size_t count)
1126 {
1127 return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1128 }
1129
/* sendto() method of the websocket socket vtable.
 *
 * The destination address is ignored. ZSOCK_MSG_DONTWAIT selects a zero
 * timeout, otherwise the send waits forever.
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct sockaddr *dest_addr,
				    socklen_t addrlen)
{
	struct websocket_context *ctx = obj;
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send(ctx, buf, len, timeout);
}
1147
/* recvfrom() method of the websocket socket vtable.
 *
 * The source address arguments are ignored. ZSOCK_MSG_DONTWAIT selects a
 * zero timeout, otherwise the receive waits forever.
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct sockaddr *src_addr,
				      socklen_t *addrlen)
{
	struct websocket_context *ctx = obj;
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv(ctx, buf, max_len, timeout);
}
1164
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1165 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1166 {
1167 struct websocket_context *ctx;
1168 int ret, fd;
1169
1170 if (sock < 0) {
1171 return -EINVAL;
1172 }
1173
1174 ctx = websocket_find(sock);
1175 if (ctx) {
1176 NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1177 return -EEXIST;
1178 }
1179
1180 ctx = websocket_get();
1181 if (!ctx) {
1182 return -ENOENT;
1183 }
1184
1185 ctx->real_sock = sock;
1186 ctx->recv_buf.buf = recv_buf;
1187 ctx->recv_buf.size = recv_buf_len;
1188
1189 fd = zvfs_reserve_fd();
1190 if (fd < 0) {
1191 ret = -ENOSPC;
1192 goto out;
1193 }
1194
1195 ctx->sock = fd;
1196 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1197 ZVFS_MODE_IFSOCK);
1198
1199 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1200
1201 ctx->recv_buf.count = 0;
1202 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1203
1204 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
1205
1206 return fd;
1207
1208 out:
1209 websocket_context_unref(ctx);
1210
1211 return ret;
1212 }
1213
websocket_search(int sock)1214 static struct websocket_context *websocket_search(int sock)
1215 {
1216 struct websocket_context *ctx = NULL;
1217 int i;
1218
1219 k_sem_take(&contexts_lock, K_FOREVER);
1220
1221 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1222 if (!websocket_context_is_used(&contexts[i])) {
1223 continue;
1224 }
1225
1226 if (contexts[i].sock != sock) {
1227 continue;
1228 }
1229
1230 ctx = &contexts[i];
1231 break;
1232 }
1233
1234 k_sem_give(&contexts_lock);
1235
1236 return ctx;
1237 }
1238
websocket_unregister(int sock)1239 int websocket_unregister(int sock)
1240 {
1241 struct websocket_context *ctx;
1242
1243 if (sock < 0) {
1244 return -EINVAL;
1245 }
1246
1247 ctx = websocket_search(sock);
1248 if (ctx == NULL) {
1249 NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1250 return -ENOENT;
1251 }
1252
1253 if (ctx->real_sock < 0) {
1254 return -EALREADY;
1255 }
1256
1257 (void)zsock_close(sock);
1258 (void)zsock_close(ctx->real_sock);
1259
1260 ctx->real_sock = -1;
1261 ctx->sock = -1;
1262
1263 return 0;
1264 }
1265
/* Operations table attached to every websocket file descriptor. Members
 * that are not listed here are left zero-initialized.
 */
static const struct socket_op_vtable websocket_fd_op_vtable = {
	.fd_vtable = {
		/* Generic file descriptor operations. */
		.read = websocket_read_vmeth,
		.write = websocket_write_vmeth,
		.close = websocket_close_vmeth,
		.ioctl = websocket_ioctl_vmeth,
	},
	/* Socket operations; the address arguments are ignored by both. */
	.sendto = websocket_sendto_ctx,
	.recvfrom = websocket_recvfrom_ctx,
};
1276
/* Invoke a callback for every context that is currently in use.
 *
 * The pool semaphore is held for the whole iteration and each context's
 * own mutex is held while the callback runs on it.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	k_sem_take(&contexts_lock, K_FOREVER);

	for (int idx = 0; idx < ARRAY_SIZE(contexts); idx++) {
		struct websocket_context *entry = &contexts[idx];

		if (websocket_context_is_used(entry)) {
			k_mutex_lock(&entry->lock, K_FOREVER);

			cb(entry, user_data);

			k_mutex_unlock(&entry->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1297
/* Initialize the websocket module.
 *
 * Sets up the semaphore that guards the context pool with an initial
 * count of 1, so that the first taker gets it immediately.
 */
void websocket_init(void)
{
	k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
}
1302