1 /** @file
2 * @brief Websocket client API
3 *
 * An API for applications to set up websocket connections.
5 */
6
7 /*
8 * Copyright (c) 2019 Intel Corporation
9 *
10 * SPDX-License-Identifier: Apache-2.0
11 */
12
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42
43 /* If you want to see the data that is being sent or received,
44 * then you can enable debugging and set the following variables to 1.
45 * This will print a lot of data so is not enabled by default.
46 */
47 #define HEXDUMP_SENT_PACKETS 0
48 #define HEXDUMP_RECV_PACKETS 0
49
/* Fixed pool of websocket contexts; a slot counts as free while its
 * reference count is zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Taken around every scan of, and slot allocation from, contexts[] */
static struct k_sem contexts_lock;

/* Forward declaration, the table itself is defined later in this file */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Hook that replaces the real socket send path when building the tests;
 * it is not defined in this file.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 switch (opcode) {
63 case WEBSOCKET_OPCODE_DATA_TEXT:
64 return "TEXT";
65 case WEBSOCKET_OPCODE_DATA_BINARY:
66 return "BIN";
67 case WEBSOCKET_OPCODE_CONTINUE:
68 return "CONT";
69 case WEBSOCKET_OPCODE_CLOSE:
70 return "CLOSE";
71 case WEBSOCKET_OPCODE_PING:
72 return "PING";
73 case WEBSOCKET_OPCODE_PONG:
74 return "PONG";
75 default:
76 break;
77 }
78
79 return NULL;
80 }
81
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 int old_rc = atomic_inc(&ctx->refcount);
85
86 return old_rc + 1;
87 }
88
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 int old_rc = atomic_dec(&ctx->refcount);
92
93 if (old_rc != 1) {
94 return old_rc - 1;
95 }
96
97 return 0;
98 }
99
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 return !!atomic_get(&ctx->refcount);
103 }
104
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 struct websocket_context *ctx = NULL;
108 int i;
109
110 k_sem_take(&contexts_lock, K_FOREVER);
111
112 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 if (websocket_context_is_used(&contexts[i])) {
114 continue;
115 }
116
117 websocket_context_ref(&contexts[i]);
118 ctx = &contexts[i];
119 break;
120 }
121
122 k_sem_give(&contexts_lock);
123
124 return ctx;
125 }
126
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 struct websocket_context *ctx = NULL;
130 int i;
131
132 k_sem_take(&contexts_lock, K_FOREVER);
133
134 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 if (!websocket_context_is_used(&contexts[i])) {
136 continue;
137 }
138
139 if (contexts[i].real_sock != real_sock) {
140 continue;
141 }
142
143 ctx = &contexts[i];
144 break;
145 }
146
147 k_sem_give(&contexts_lock);
148
149 return ctx;
150 }
151
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static void response_cb(struct http_response *rsp,
153 enum http_final_call final_data,
154 void *user_data)
155 {
156 struct websocket_context *ctx = user_data;
157
158 if (final_data == HTTP_DATA_MORE) {
159 NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 rsp->data_len);
161 ctx->all_received = false;
162 } else if (final_data == HTTP_DATA_FINAL) {
163 NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 rsp->data_len);
165 ctx->all_received = true;
166 }
167 }
168
on_header_field(struct http_parser * parser,const char * at,size_t length)169 static int on_header_field(struct http_parser *parser, const char *at,
170 size_t length)
171 {
172 struct http_request *req = CONTAINER_OF(parser,
173 struct http_request,
174 internal.parser);
175 struct websocket_context *ctx = req->internal.user_data;
176 const char *ws_accept_str = "Sec-WebSocket-Accept";
177 uint16_t len;
178
179 len = strlen(ws_accept_str);
180 if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
181 ctx->sec_accept_present = true;
182 }
183
184 if (ctx->http_cb && ctx->http_cb->on_header_field) {
185 ctx->http_cb->on_header_field(parser, at, length);
186 }
187
188 return 0;
189 }
190
191 #define MAX_SEC_ACCEPT_LEN 32
192
on_header_value(struct http_parser * parser,const char * at,size_t length)193 static int on_header_value(struct http_parser *parser, const char *at,
194 size_t length)
195 {
196 struct http_request *req = CONTAINER_OF(parser,
197 struct http_request,
198 internal.parser);
199 struct websocket_context *ctx = req->internal.user_data;
200 char str[MAX_SEC_ACCEPT_LEN];
201
202 if (ctx->sec_accept_present) {
203 int ret;
204 size_t olen;
205
206 ctx->sec_accept_ok = false;
207 ctx->sec_accept_present = false;
208
209 ret = base64_encode(str, sizeof(str) - 1, &olen,
210 ctx->sec_accept_key,
211 WS_SHA1_OUTPUT_LEN);
212 if (ret == 0) {
213 if (strncmp(at, str, length)) {
214 NET_DBG("[%p] Security keys do not match "
215 "%s vs %s", ctx, str, at);
216 } else {
217 ctx->sec_accept_ok = true;
218 }
219 }
220 }
221
222 if (ctx->http_cb && ctx->http_cb->on_header_value) {
223 ctx->http_cb->on_header_value(parser, at, length);
224 }
225
226 return 0;
227 }
228
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)229 int websocket_connect(int sock, struct websocket_request *wreq,
230 int32_t timeout, void *user_data)
231 {
232 /* This is the expected Sec-WebSocket-Accept key. We are storing a
233 * pointer to this in ctx but the value is only used for the duration
234 * of this function call so there is no issue even if this variable
235 * is allocated from stack.
236 */
237 uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
238 struct http_parser_settings http_parser_settings;
239 struct websocket_context *ctx;
240 struct http_request req;
241 int ret, fd, key_len;
242 size_t olen;
243 char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
244 uint32_t rnd_value = sys_rand32_get();
245 char sec_ws_key[] =
246 "Sec-WebSocket-Key: 0123456789012345678901==\r\n";
247 char *headers[] = {
248 sec_ws_key,
249 "Upgrade: websocket\r\n",
250 "Connection: Upgrade\r\n",
251 "Sec-WebSocket-Version: 13\r\n",
252 NULL
253 };
254
255 fd = -1;
256
257 if (sock < 0 || wreq == NULL || wreq->host == NULL ||
258 wreq->url == NULL) {
259 return -EINVAL;
260 }
261
262 ctx = websocket_find(sock);
263 if (ctx) {
264 NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
265 sock);
266 return -EEXIST;
267 }
268
269 ctx = websocket_get();
270 if (!ctx) {
271 return -ENOENT;
272 }
273
274 ctx->real_sock = sock;
275 ctx->recv_buf.buf = wreq->tmp_buf;
276 ctx->recv_buf.size = wreq->tmp_buf_len;
277 ctx->sec_accept_key = sec_accept_key;
278 ctx->http_cb = wreq->http_cb;
279 ctx->is_client = 1;
280
281 mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
282 sec_accept_key);
283
284 ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
285 sizeof(sec_ws_key) -
286 sizeof("Sec-Websocket-Key: "),
287 &olen, sec_accept_key,
288 /* We are only interested in 16 first bytes so
289 * subtract 4 from the SHA-1 length
290 */
291 sizeof(sec_accept_key) - 4);
292 if (ret) {
293 NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
294 goto out;
295 }
296
297 if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
298 NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
299 olen + sizeof("Sec-Websocket-Key: ") + 2,
300 sizeof(sec_ws_key));
301 ret = -EMSGSIZE;
302 goto out;
303 }
304
305 memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
306 HTTP_CRLF, sizeof(HTTP_CRLF));
307
308 memset(&req, 0, sizeof(req));
309
310 req.method = HTTP_GET;
311 req.url = wreq->url;
312 req.host = wreq->host;
313 req.protocol = "HTTP/1.1";
314 req.header_fields = (const char **)headers;
315 req.optional_headers_cb = wreq->optional_headers_cb;
316 req.optional_headers = wreq->optional_headers;
317 req.response = response_cb;
318 req.http_cb = &http_parser_settings;
319 req.recv_buf = wreq->tmp_buf;
320 req.recv_buf_len = wreq->tmp_buf_len;
321
322 /* We need to catch the Sec-WebSocket-Accept field in order to verify
323 * that it contains the stuff that we sent in Sec-WebSocket-Key field
324 * so setup HTTP callbacks so that we will get the needed fields.
325 */
326 if (ctx->http_cb) {
327 memcpy(&http_parser_settings, ctx->http_cb,
328 sizeof(http_parser_settings));
329 } else {
330 memset(&http_parser_settings, 0, sizeof(http_parser_settings));
331 }
332
333 http_parser_settings.on_header_field = on_header_field;
334 http_parser_settings.on_header_value = on_header_value;
335
336 /* Pre-calculate the expected Sec-Websocket-Accept field */
337 key_len = MIN(sizeof(key_accept) - 1, olen);
338 strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
339 key_len);
340
341 olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
342 strncpy(key_accept + key_len, WS_MAGIC, olen);
343
344 /* This SHA-1 value is then checked when we receive the response */
345 mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
346
347 ret = http_client_req(sock, &req, timeout, ctx);
348 if (ret < 0) {
349 NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
350 wreq->host);
351 ret = -ECONNABORTED;
352 goto out;
353 }
354
355 if (!(ctx->all_received && ctx->sec_accept_ok)) {
356 NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
357 ctx->all_received, ctx->sec_accept_ok);
358 ret = -ECONNABORTED;
359 goto out;
360 }
361
362 ctx->user_data = user_data;
363
364 fd = zvfs_reserve_fd();
365 if (fd < 0) {
366 ret = -ENOSPC;
367 goto out;
368 }
369
370 ctx->sock = fd;
371 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
372 ZVFS_MODE_IFSOCK);
373
374 /* Call the user specified callback and if it accepts the connection
375 * then continue.
376 */
377 if (wreq->cb) {
378 ret = wreq->cb(fd, &req, user_data);
379 if (ret < 0) {
380 NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
381 goto out;
382 }
383 }
384
385 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
386
387 /* We will re-use the temp buffer in receive function if needed but
388 * in order that to work the amount of data in buffer must be set to 0
389 */
390 ctx->recv_buf.count = 0;
391
392 /* Init parser FSM */
393 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
394
395 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
396
397 return fd;
398
399 out:
400 if (fd >= 0) {
401 (void)zsock_close(fd);
402 }
403
404 websocket_context_unref(ctx);
405 return ret;
406 }
407
/* Close a websocket descriptor.
 *
 * This simply closes the descriptor with zsock_close(); the result of that
 * call is returned unchanged.
 */
int websocket_disconnect(int ws_sock)
{
	int status = zsock_close(ws_sock);

	return status;
}
412
websocket_interal_disconnect(struct websocket_context * ctx)413 static int websocket_interal_disconnect(struct websocket_context *ctx)
414 {
415 int ret;
416
417 if (ctx == NULL) {
418 return -ENOENT;
419 }
420
421 NET_DBG("[%p] Disconnecting", ctx);
422
423 ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
424 true, true, SYS_FOREVER_MS);
425 if (ret < 0) {
426 NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
427 }
428
429 (void)sock_obj_core_dealloc(ctx->sock);
430
431 websocket_context_unref(ctx);
432
433 return ret;
434 }
435
/* close() handler of the websocket descriptor.
 *
 * Disconnects the context. -ENOTCONN from the disconnect is treated as
 * success; any other failure is reported as -1 with errno set.
 */
static int websocket_close_vmeth(void *obj)
{
	struct websocket_context *ctx = obj;
	int status = websocket_interal_disconnect(ctx);

	if (status >= 0) {
		return status;
	}

	if (status == -ENOTCONN) {
		/* Ignore error if we are not connected */
		return 0;
	}

	NET_DBG("[%p] Cannot close (%d)", obj, status);

	errno = -status;
	return -1;
}
456
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)457 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
458 int timeout)
459 {
460 int fd_backup[CONFIG_ZVFS_POLL_MAX];
461 const struct fd_op_vtable *vtable;
462 void *ctx;
463 int ret = 0;
464 int i;
465
466 /* Overwrite websocket file descriptors with underlying ones. */
467 for (i = 0; i < nfds; i++) {
468 fd_backup[i] = fds[i].fd;
469
470 ctx = zvfs_get_fd_obj(fds[i].fd,
471 (const struct fd_op_vtable *)
472 &websocket_fd_op_vtable,
473 0);
474 if (ctx == NULL) {
475 continue;
476 }
477
478 fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
479 }
480
481 /* Get offloaded sockets vtable. */
482 ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
483 (const struct fd_op_vtable **)&vtable,
484 NULL);
485 if (ctx == NULL) {
486 errno = EINVAL;
487 ret = -1;
488 goto exit;
489 }
490
491 ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
492 fds, nfds, timeout);
493
494 exit:
495 /* Restore original fds. */
496 for (i = 0; i < nfds; i++) {
497 fds[i].fd = fd_backup[i];
498 }
499
500 return ret;
501 }
502
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)503 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
504 {
505 struct websocket_context *ctx = obj;
506
507 switch (request) {
508 case ZFD_IOCTL_POLL_OFFLOAD: {
509 struct zsock_pollfd *fds;
510 int nfds;
511 int timeout;
512
513 fds = va_arg(args, struct zsock_pollfd *);
514 nfds = va_arg(args, int);
515 timeout = va_arg(args, int);
516
517 return websocket_poll_offload(fds, nfds, timeout);
518 }
519
520 case ZFD_IOCTL_SET_LOCK:
521 /* Ignore, don't want to overwrite underlying socket lock. */
522 return 0;
523
524 default: {
525 const struct fd_op_vtable *vtable;
526 void *core_obj;
527
528 core_obj = zvfs_get_fd_obj_and_vtable(
529 ctx->real_sock,
530 (const struct fd_op_vtable **)&vtable,
531 NULL);
532 if (core_obj == NULL) {
533 errno = EBADF;
534 return -1;
535 }
536
537 /* Pass the call to the core socket implementation. */
538 return vtable->ioctl(core_obj, request, args);
539 }
540 }
541
542 return 0;
543 }
544
545 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct msghdr * message,int flags,const k_timepoint_t req_end_timepoint)546 static int sendmsg_all(int sock, const struct msghdr *message, int flags,
547 const k_timepoint_t req_end_timepoint)
548 {
549 int ret, i;
550 size_t offset = 0;
551 size_t total_len = 0;
552
553 for (i = 0; i < message->msg_iovlen; i++) {
554 total_len += message->msg_iov[i].iov_len;
555 }
556
557 while (offset < total_len) {
558 ret = zsock_sendmsg(sock, message, flags);
559
560 if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
561 struct zsock_pollfd pfd;
562 int pollres;
563 k_ticks_t req_timeout_ticks =
564 sys_timepoint_timeout(req_end_timepoint).ticks;
565 int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
566
567 pfd.fd = sock;
568 pfd.events = ZSOCK_POLLOUT;
569 pollres = zsock_poll(&pfd, 1, req_timeout_ms);
570 if (pollres == 0) {
571 return -ETIMEDOUT;
572 } else if (pollres > 0) {
573 continue;
574 } else {
575 return -errno;
576 }
577 } else if (ret < 0) {
578 return -errno;
579 }
580
581 offset += ret;
582 if (offset >= total_len) {
583 break;
584 }
585
586 /* Update msghdr for the next iteration. */
587 for (i = 0; i < message->msg_iovlen; i++) {
588 if (ret < message->msg_iov[i].iov_len) {
589 message->msg_iov[i].iov_len -= ret;
590 message->msg_iov[i].iov_base =
591 (uint8_t *)message->msg_iov[i].iov_base + ret;
592 break;
593 }
594
595 ret -= message->msg_iov[i].iov_len;
596 message->msg_iov[i].iov_len = 0;
597 }
598 }
599
600 return total_len;
601 }
602 #endif /* !defined(CONFIG_NET_TEST) */
603
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)604 static int websocket_prepare_and_send(struct websocket_context *ctx,
605 uint8_t *header, size_t header_len,
606 uint8_t *payload, size_t payload_len,
607 int32_t timeout)
608 {
609 struct iovec io_vector[2];
610 struct msghdr msg;
611
612 io_vector[0].iov_base = header;
613 io_vector[0].iov_len = header_len;
614 io_vector[1].iov_base = payload;
615 io_vector[1].iov_len = payload_len;
616
617 memset(&msg, 0, sizeof(msg));
618
619 msg.msg_iov = io_vector;
620 msg.msg_iovlen = ARRAY_SIZE(io_vector);
621
622 if (HEXDUMP_SENT_PACKETS) {
623 LOG_HEXDUMP_DBG(header, header_len, "Header");
624 if ((payload != NULL) && (payload_len > 0)) {
625 LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
626 } else {
627 LOG_DBG("No payload");
628 }
629 }
630
631 #if defined(CONFIG_NET_TEST)
632 /* Simulate a case where the payload is split to two. The unit test
633 * does not set mask bit in this case.
634 */
635 return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
636 #else
637 k_timeout_t tout = K_FOREVER;
638
639 if (timeout != SYS_FOREVER_MS) {
640 tout = K_MSEC(timeout);
641 }
642
643 k_timeout_t req_timeout = K_MSEC(timeout);
644 k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
645
646 return sendmsg_all(ctx->real_sock, &msg,
647 K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
648 req_end_timepoint);
649 #endif /* CONFIG_NET_TEST */
650 }
651
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)652 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
653 enum websocket_opcode opcode, bool mask, bool final,
654 int32_t timeout)
655 {
656 struct websocket_context *ctx;
657 uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
658 uint8_t *data_to_send = (uint8_t *)payload;
659 int ret;
660
661 if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
662 opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
663 opcode != WEBSOCKET_OPCODE_CONTINUE &&
664 opcode != WEBSOCKET_OPCODE_CLOSE &&
665 opcode != WEBSOCKET_OPCODE_PING &&
666 opcode != WEBSOCKET_OPCODE_PONG) {
667 return -EINVAL;
668 }
669
670 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
671 if (ctx == NULL) {
672 return -EBADF;
673 }
674
675 #if !defined(CONFIG_NET_TEST)
676 /* Websocket unit test does not use context from pool but allocates
677 * its own, hence skip the check.
678 */
679
680 if (!PART_OF_ARRAY(contexts, ctx)) {
681 return -ENOENT;
682 }
683 #endif /* !defined(CONFIG_NET_TEST) */
684
685 NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
686 mask, final ? "final" : "more");
687
688 memset(header, 0, sizeof(header));
689
690 /* Is this the last packet? */
691 header[0] = final ? BIT(7) : 0;
692
693 /* Text, binary, ping, pong or close ? */
694 header[0] |= opcode;
695
696 /* Masking */
697 header[1] = mask ? BIT(7) : 0;
698
699 if (payload_len < 126) {
700 header[1] |= payload_len;
701 } else if (payload_len < 65536) {
702 header[1] |= 126;
703 header[2] = payload_len >> 8;
704 header[3] = payload_len;
705 hdr_len += 2;
706 } else {
707 header[1] |= 127;
708 header[2] = 0;
709 header[3] = 0;
710 header[4] = 0;
711 header[5] = 0;
712 header[6] = payload_len >> 24;
713 header[7] = payload_len >> 16;
714 header[8] = payload_len >> 8;
715 header[9] = payload_len;
716 hdr_len += 8;
717 }
718
719 /* Add masking value if needed */
720 if (mask) {
721 int i;
722
723 ctx->masking_value = sys_rand32_get();
724
725 header[hdr_len++] |= ctx->masking_value >> 24;
726 header[hdr_len++] |= ctx->masking_value >> 16;
727 header[hdr_len++] |= ctx->masking_value >> 8;
728 header[hdr_len++] |= ctx->masking_value;
729
730 if ((payload != NULL) && (payload_len > 0)) {
731 data_to_send = k_malloc(payload_len);
732 if (!data_to_send) {
733 return -ENOMEM;
734 }
735
736 memcpy(data_to_send, payload, payload_len);
737
738 for (i = 0; i < payload_len; i++) {
739 data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
740 }
741 }
742 }
743
744 ret = websocket_prepare_and_send(ctx, header, hdr_len,
745 data_to_send, payload_len, timeout);
746 if (ret < 0) {
747 NET_DBG("Cannot send ws msg (%d)", -errno);
748 goto quit;
749 }
750
751 quit:
752 if (data_to_send != payload) {
753 k_free(data_to_send);
754 }
755
756 /* Do no math with 0 and error codes */
757 if (ret <= 0) {
758 return ret;
759 }
760
761 return ret - hdr_len;
762 }
763
websocket_opcode2flag(uint8_t data)764 static uint32_t websocket_opcode2flag(uint8_t data)
765 {
766 switch (data & 0x0f) {
767 case WEBSOCKET_OPCODE_DATA_TEXT:
768 return WEBSOCKET_FLAG_TEXT;
769 case WEBSOCKET_OPCODE_DATA_BINARY:
770 return WEBSOCKET_FLAG_BINARY;
771 case WEBSOCKET_OPCODE_CLOSE:
772 return WEBSOCKET_FLAG_CLOSE;
773 case WEBSOCKET_OPCODE_PING:
774 return WEBSOCKET_FLAG_PING;
775 case WEBSOCKET_OPCODE_PONG:
776 return WEBSOCKET_FLAG_PONG;
777 default:
778 break;
779 }
780 return 0;
781 }
782
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)783 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
784 {
785 int len;
786 uint8_t data;
787 size_t parsed_count = 0;
788
789 do {
790 if (parsed_count >= ctx->recv_buf.count) {
791 return parsed_count;
792 }
793 if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
794 data = ctx->recv_buf.buf[parsed_count++];
795
796 switch (ctx->parser_state) {
797 case WEBSOCKET_PARSER_STATE_OPCODE:
798 ctx->message_type = websocket_opcode2flag(data);
799 if ((data & 0x80) != 0) {
800 ctx->message_type |= WEBSOCKET_FLAG_FINAL;
801 }
802 ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
803 break;
804 case WEBSOCKET_PARSER_STATE_LENGTH:
805 ctx->masked = (data & 0x80) != 0;
806 len = data & 0x7f;
807 if (len < 126) {
808 ctx->message_len = len;
809 if (ctx->masked) {
810 ctx->masking_value = 0;
811 ctx->parser_remaining = 4;
812 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
813 } else {
814 ctx->parser_remaining = ctx->message_len;
815 ctx->parser_state =
816 (ctx->parser_remaining == 0)
817 ? WEBSOCKET_PARSER_STATE_OPCODE
818 : WEBSOCKET_PARSER_STATE_PAYLOAD;
819 }
820 } else {
821 ctx->message_len = 0;
822 ctx->parser_remaining = (len < 127) ? 2 : 8;
823 ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
824 }
825 break;
826 case WEBSOCKET_PARSER_STATE_EXT_LEN:
827 ctx->parser_remaining--;
828 ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
829 if (ctx->parser_remaining == 0) {
830 if (ctx->masked) {
831 ctx->masking_value = 0;
832 ctx->parser_remaining = 4;
833 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
834 } else {
835 ctx->parser_remaining = ctx->message_len;
836 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
837 }
838 }
839 break;
840 case WEBSOCKET_PARSER_STATE_MASK:
841 ctx->parser_remaining--;
842 ctx->masking_value |= (data << (ctx->parser_remaining * 8));
843 if (ctx->parser_remaining == 0) {
844 if (ctx->message_len == 0) {
845 ctx->parser_remaining = 0;
846 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
847 } else {
848 ctx->parser_remaining = ctx->message_len;
849 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
850 }
851 }
852 break;
853 default:
854 return -EFAULT;
855 }
856
857 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
858 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
859 ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
860 (ctx->message_len == 0))) {
861 NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
862 ctx->masked ? "" : "un",
863 ctx->masked ? ctx->masking_value : 0, ctx->message_type,
864 (size_t)ctx->message_len);
865 }
866 #endif
867 } else {
868 size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
869 size_t payload_in_recv_buf =
870 MIN(remaining_in_recv_buf, ctx->parser_remaining);
871 size_t free_in_payload_buf = payload->size - payload->count;
872 size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
873
874 if (free_in_payload_buf == 0) {
875 break;
876 }
877
878 memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
879 ready_to_copy);
880 parsed_count += ready_to_copy;
881 payload->count += ready_to_copy;
882 ctx->parser_remaining -= ready_to_copy;
883 if (ctx->parser_remaining == 0) {
884 ctx->parser_remaining = 0;
885 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
886 }
887 }
888
889 } while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
890
891 return parsed_count;
892 }
893
894 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)895 static int wait_rx(int sock, int timeout)
896 {
897 struct zsock_pollfd fds = {
898 .fd = sock,
899 .events = ZSOCK_POLLIN,
900 };
901 int ret;
902
903 ret = zsock_poll(&fds, 1, timeout);
904 if (ret < 0) {
905 return ret;
906 }
907
908 if (ret == 0) {
909 /* Timeout */
910 return -EAGAIN;
911 }
912
913 if (fds.revents & ZSOCK_POLLNVAL) {
914 return -EBADF;
915 }
916
917 if (fds.revents & ZSOCK_POLLERR) {
918 return -EIO;
919 }
920
921 return 0;
922 }
923
/* Convert a kernel timeout into the millisecond value used for polling:
 * 0 for K_NO_WAIT, SYS_FOREVER_MS for K_FOREVER and the tick count rounded
 * down to milliseconds for everything else.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
934
935 #endif /* !defined(CONFIG_NET_TEST) */
936
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)937 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
938 uint32_t *message_type, uint64_t *remaining, int32_t timeout)
939 {
940 struct websocket_context *ctx;
941 int ret;
942 k_timepoint_t end;
943 k_timeout_t tout = K_FOREVER;
944 struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
945
946 if (timeout != SYS_FOREVER_MS) {
947 tout = K_MSEC(timeout);
948 }
949
950 if ((buf == NULL) || (buf_len == 0)) {
951 return -EINVAL;
952 }
953
954 end = sys_timepoint_calc(tout);
955
956 #if defined(CONFIG_NET_TEST)
957 struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
958
959 if (test_data == NULL) {
960 return -EBADF;
961 }
962
963 ctx = test_data->ctx;
964 #else
965 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
966 if (ctx == NULL) {
967 return -EBADF;
968 }
969
970 if (!PART_OF_ARRAY(contexts, ctx)) {
971 return -ENOENT;
972 }
973 #endif /* CONFIG_NET_TEST */
974
975 do {
976 size_t parsed_count;
977
978 if (ctx->recv_buf.count == 0) {
979 #if defined(CONFIG_NET_TEST)
980 size_t input_len = MIN(ctx->recv_buf.size,
981 test_data->input_len - test_data->input_pos);
982
983 if (input_len > 0) {
984 memcpy(ctx->recv_buf.buf,
985 &test_data->input_buf[test_data->input_pos], input_len);
986 test_data->input_pos += input_len;
987 ret = input_len;
988 } else {
989 /* emulate timeout */
990 ret = -EAGAIN;
991 }
992 #else
993 tout = sys_timepoint_timeout(end);
994
995 ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
996 if (ret == 0) {
997 ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
998 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
999 if (ret < 0) {
1000 ret = -errno;
1001 }
1002 }
1003 #endif /* CONFIG_NET_TEST */
1004
1005 if (ret < 0) {
1006 if ((ret == -EAGAIN) && (payload.count > 0)) {
1007 /* go to unmasking */
1008 break;
1009 }
1010 return ret;
1011 }
1012
1013 if (ret == 0) {
1014 /* Socket closed */
1015 return -ENOTCONN;
1016 }
1017
1018 ctx->recv_buf.count = ret;
1019
1020 NET_DBG("[%p] Received %d bytes", ctx, ret);
1021 }
1022
1023 ret = websocket_parse(ctx, &payload);
1024 if (ret < 0) {
1025 return ret;
1026 }
1027 parsed_count = ret;
1028
1029 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1030 (payload.count >= payload.size)) {
1031 if (remaining != NULL) {
1032 *remaining = ctx->parser_remaining;
1033 }
1034 if (message_type != NULL) {
1035 *message_type = ctx->message_type;
1036 }
1037
1038 size_t left = ctx->recv_buf.count - parsed_count;
1039
1040 if (left > 0) {
1041 memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1042 }
1043 ctx->recv_buf.count = left;
1044 break;
1045 }
1046
1047 ctx->recv_buf.count -= parsed_count;
1048
1049 } while (true);
1050
1051 /* Unmask the data */
1052 if (ctx->masked) {
1053 uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1054 size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1055
1056 for (size_t i = 0; i < payload.count; i++) {
1057 size_t m = data_buf_offset % 4;
1058
1059 payload.buf[i] ^= mask_as_bytes[3 - m];
1060 data_buf_offset++;
1061 }
1062 }
1063
1064 return payload.count;
1065 }
1066
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1067 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1068 size_t buf_len, int32_t timeout)
1069 {
1070 int ret;
1071
1072 NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1073
1074 ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1075 ctx->is_client, true, timeout);
1076 if (ret < 0) {
1077 errno = -ret;
1078 return -1;
1079 }
1080
1081 NET_DBG("[%p] Sent %d bytes", ctx, ret);
1082
1083 sock_obj_core_update_send_stats(ctx->sock, ret);
1084
1085 return ret;
1086 }
1087
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1088 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1089 size_t buf_len, int32_t timeout)
1090 {
1091 uint32_t message_type;
1092 uint64_t remaining;
1093 int ret;
1094
1095 NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1096
1097 /* TODO: add support for recvmsg() so that we could return the
1098 * websocket specific information in ancillary data.
1099 */
1100 ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1101 &remaining, timeout);
1102 if (ret < 0) {
1103 if (ret == -ENOTCONN) {
1104 ret = 0;
1105 } else {
1106 errno = -ret;
1107 return -1;
1108 }
1109 }
1110
1111 NET_DBG("[%p] Received %d bytes", ctx, ret);
1112
1113 sock_obj_core_update_recv_stats(ctx->sock, ret);
1114
1115 return ret;
1116 }
1117
websocket_read_vmeth(void * obj,void * buffer,size_t count)1118 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1119 {
1120 return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1121 }
1122
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1123 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1124 size_t count)
1125 {
1126 return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1127 }
1128
/* sendto() handler of the websocket descriptor.
 *
 * The destination address is ignored. ZSOCK_MSG_DONTWAIT selects a zero
 * timeout, otherwise the send waits forever.
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct sockaddr *dest_addr,
				    socklen_t addrlen)
{
	struct websocket_context *ctx = obj;
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send(ctx, buf, len, timeout);
}
1146
/* recvfrom() handler of the websocket descriptor.
 *
 * The source address outputs are left untouched. ZSOCK_MSG_DONTWAIT selects
 * a zero timeout, otherwise the receive waits forever.
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct sockaddr *src_addr,
				      socklen_t *addrlen)
{
	struct websocket_context *ctx = obj;
	int32_t timeout = (flags & ZSOCK_MSG_DONTWAIT) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv(ctx, buf, max_len, timeout);
}
1163
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1164 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1165 {
1166 struct websocket_context *ctx;
1167 int ret, fd;
1168
1169 if (sock < 0) {
1170 return -EINVAL;
1171 }
1172
1173 ctx = websocket_find(sock);
1174 if (ctx) {
1175 NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1176 return -EEXIST;
1177 }
1178
1179 ctx = websocket_get();
1180 if (!ctx) {
1181 return -ENOENT;
1182 }
1183
1184 ctx->real_sock = sock;
1185 ctx->recv_buf.buf = recv_buf;
1186 ctx->recv_buf.size = recv_buf_len;
1187 ctx->is_client = 0;
1188
1189 fd = zvfs_reserve_fd();
1190 if (fd < 0) {
1191 ret = -ENOSPC;
1192 goto out;
1193 }
1194
1195 ctx->sock = fd;
1196 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1197 ZVFS_MODE_IFSOCK);
1198
1199 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1200
1201 ctx->recv_buf.count = 0;
1202 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1203
1204 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, SOCK_STREAM);
1205
1206 return fd;
1207
1208 out:
1209 websocket_context_unref(ctx);
1210
1211 return ret;
1212 }
1213
websocket_search(int sock)1214 static struct websocket_context *websocket_search(int sock)
1215 {
1216 struct websocket_context *ctx = NULL;
1217 int i;
1218
1219 k_sem_take(&contexts_lock, K_FOREVER);
1220
1221 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1222 if (!websocket_context_is_used(&contexts[i])) {
1223 continue;
1224 }
1225
1226 if (contexts[i].sock != sock) {
1227 continue;
1228 }
1229
1230 ctx = &contexts[i];
1231 break;
1232 }
1233
1234 k_sem_give(&contexts_lock);
1235
1236 return ctx;
1237 }
1238
websocket_unregister(int sock)1239 int websocket_unregister(int sock)
1240 {
1241 struct websocket_context *ctx;
1242
1243 if (sock < 0) {
1244 return -EINVAL;
1245 }
1246
1247 ctx = websocket_search(sock);
1248 if (ctx == NULL) {
1249 NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1250 return -ENOENT;
1251 }
1252
1253 if (ctx->real_sock < 0) {
1254 return -EALREADY;
1255 }
1256
1257 (void)zsock_close(sock);
1258 (void)zsock_close(ctx->real_sock);
1259
1260 ctx->real_sock = -1;
1261 ctx->sock = -1;
1262
1263 return 0;
1264 }
1265
/* Operations installed on every websocket file descriptor. Only the entries
 * listed here are provided; all remaining members of the table are left
 * zero-initialized.
 */
static const struct socket_op_vtable websocket_fd_op_vtable = {
	.fd_vtable = {
		/* Generic descriptor operations */
		.read = websocket_read_vmeth,
		.write = websocket_write_vmeth,
		.close = websocket_close_vmeth,
		.ioctl = websocket_ioctl_vmeth,
	},
	/* Socket operations */
	.sendto = websocket_sendto_ctx,
	.recvfrom = websocket_recvfrom_ctx,
};
1276
/* Invoke a callback for every context that is currently in use.
 *
 * The pool semaphore is held for the whole walk and each context's own lock
 * is held while the callback runs on it.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	struct websocket_context *entry;

	k_sem_take(&contexts_lock, K_FOREVER);

	for (entry = contexts; entry < &contexts[ARRAY_SIZE(contexts)]; entry++) {
		if (websocket_context_is_used(entry)) {
			k_mutex_lock(&entry->lock, K_FOREVER);

			cb(entry, user_data);

			k_mutex_unlock(&entry->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1297
websocket_init(void)1298 void websocket_init(void)
1299 {
1300 k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1301 }
1302