1 /** @file
2 * @brief Websocket client API
3 *
 * An API for applications to set up websocket connections.
5 */
6
7 /*
8 * Copyright (c) 2019 Intel Corporation
9 *
10 * SPDX-License-Identifier: Apache-2.0
11 */
12
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #if defined(CONFIG_POSIX_API)
26 #include <zephyr/posix/unistd.h>
27 #include <zephyr/posix/sys/socket.h>
28 #else
29 #include <zephyr/net/socket.h>
30 #endif
31 #include <zephyr/net/http/client.h>
32 #include <zephyr/net/websocket.h>
33
34 #include <zephyr/random/random.h>
35 #include <zephyr/sys/byteorder.h>
36 #include <zephyr/sys/base64.h>
37 #include <mbedtls/sha1.h>
38
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42
/* If you want to see the data that is being sent or received, enable debug
 * logging and set the matching macro below to 1. This prints a lot of data,
 * so it is disabled by default.
 * NOTE(review): only HEXDUMP_SENT_PACKETS is referenced in the code visible
 * here (websocket_prepare_and_send()).
 */
#define HEXDUMP_SENT_PACKETS 0
#define HEXDUMP_RECV_PACKETS 0

/* Fixed pool of websocket contexts. A slot counts as in use while its
 * refcount is non-zero (see websocket_context_is_used()).
 */
static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];

/* Serializes scans of the contexts pool (websocket_get(), websocket_find(),
 * websocket_context_foreach()).
 */
static struct k_sem contexts_lock;

/* Forward declaration: the vtable is defined after the socket methods it
 * points to, but websocket_connect() and the poll offload need it earlier.
 */
static const struct socket_op_vtable websocket_fd_op_vtable;

#if defined(CONFIG_NET_TEST)
/* Test hook used instead of a real socket send in test builds; presumably
 * provided by the websocket unit test — not defined in this file.
 */
int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg);
#endif
59
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 switch (opcode) {
63 case WEBSOCKET_OPCODE_DATA_TEXT:
64 return "TEXT";
65 case WEBSOCKET_OPCODE_DATA_BINARY:
66 return "BIN";
67 case WEBSOCKET_OPCODE_CONTINUE:
68 return "CONT";
69 case WEBSOCKET_OPCODE_CLOSE:
70 return "CLOSE";
71 case WEBSOCKET_OPCODE_PING:
72 return "PING";
73 case WEBSOCKET_OPCODE_PONG:
74 return "PONG";
75 default:
76 break;
77 }
78
79 return NULL;
80 }
81
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 int old_rc = atomic_inc(&ctx->refcount);
85
86 return old_rc + 1;
87 }
88
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 int old_rc = atomic_dec(&ctx->refcount);
92
93 if (old_rc != 1) {
94 return old_rc - 1;
95 }
96
97 return 0;
98 }
99
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 NET_ASSERT(ctx);
103
104 return !!atomic_get(&ctx->refcount);
105 }
106
websocket_get(void)107 static struct websocket_context *websocket_get(void)
108 {
109 struct websocket_context *ctx = NULL;
110 int i;
111
112 k_sem_take(&contexts_lock, K_FOREVER);
113
114 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
115 if (websocket_context_is_used(&contexts[i])) {
116 continue;
117 }
118
119 websocket_context_ref(&contexts[i]);
120 ctx = &contexts[i];
121 break;
122 }
123
124 k_sem_give(&contexts_lock);
125
126 return ctx;
127 }
128
websocket_find(int real_sock)129 static struct websocket_context *websocket_find(int real_sock)
130 {
131 struct websocket_context *ctx = NULL;
132 int i;
133
134 k_sem_take(&contexts_lock, K_FOREVER);
135
136 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
137 if (!websocket_context_is_used(&contexts[i])) {
138 continue;
139 }
140
141 if (contexts[i].real_sock != real_sock) {
142 continue;
143 }
144
145 ctx = &contexts[i];
146 break;
147 }
148
149 k_sem_give(&contexts_lock);
150
151 return ctx;
152 }
153
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)154 static void response_cb(struct http_response *rsp,
155 enum http_final_call final_data,
156 void *user_data)
157 {
158 struct websocket_context *ctx = user_data;
159
160 if (final_data == HTTP_DATA_MORE) {
161 NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
162 rsp->data_len);
163 ctx->all_received = false;
164 } else if (final_data == HTTP_DATA_FINAL) {
165 NET_DBG("[%p] All the data received (%zd bytes)", ctx,
166 rsp->data_len);
167 ctx->all_received = true;
168 }
169 }
170
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 size_t length)
173 {
174 struct http_request *req = CONTAINER_OF(parser,
175 struct http_request,
176 internal.parser);
177 struct websocket_context *ctx = req->internal.user_data;
178 const char *ws_accept_str = "Sec-WebSocket-Accept";
179 uint16_t len;
180
181 len = strlen(ws_accept_str);
182 if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 ctx->sec_accept_present = true;
184 }
185
186 if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 ctx->http_cb->on_header_field(parser, at, length);
188 }
189
190 return 0;
191 }
192
/* Size of the scratch buffers holding a base64-encoded Sec-WebSocket-Accept
 * value (used by on_header_value() and websocket_connect()).
 * NOTE(review): assumes WS_SHA1_OUTPUT_LEN is the 20-byte SHA-1 digest size,
 * whose base64 form (28 characters plus terminator) fits in 32 bytes —
 * confirm if that constant ever changes.
 */
#define MAX_SEC_ACCEPT_LEN 32
194
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 size_t length)
197 {
198 struct http_request *req = CONTAINER_OF(parser,
199 struct http_request,
200 internal.parser);
201 struct websocket_context *ctx = req->internal.user_data;
202 char str[MAX_SEC_ACCEPT_LEN];
203
204 if (ctx->sec_accept_present) {
205 int ret;
206 size_t olen;
207
208 ctx->sec_accept_ok = false;
209 ctx->sec_accept_present = false;
210
211 ret = base64_encode(str, sizeof(str) - 1, &olen,
212 ctx->sec_accept_key,
213 WS_SHA1_OUTPUT_LEN);
214 if (ret == 0) {
215 if (strncmp(at, str, length)) {
216 NET_DBG("[%p] Security keys do not match "
217 "%s vs %s", ctx, str, at);
218 } else {
219 ctx->sec_accept_ok = true;
220 }
221 }
222 }
223
224 if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 ctx->http_cb->on_header_value(parser, at, length);
226 }
227
228 return 0;
229 }
230
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 int32_t timeout, void *user_data)
233 {
234 /* This is the expected Sec-WebSocket-Accept key. We are storing a
235 * pointer to this in ctx but the value is only used for the duration
236 * of this function call so there is no issue even if this variable
237 * is allocated from stack.
238 */
239 uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 struct http_parser_settings http_parser_settings;
241 struct websocket_context *ctx;
242 struct http_request req;
243 int ret, fd, key_len;
244 size_t olen;
245 char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 uint32_t rnd_value = sys_rand32_get();
247 char sec_ws_key[] =
248 "Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 char *headers[] = {
250 sec_ws_key,
251 "Upgrade: websocket\r\n",
252 "Connection: Upgrade\r\n",
253 "Sec-WebSocket-Version: 13\r\n",
254 NULL
255 };
256
257 fd = -1;
258
259 if (sock < 0 || wreq == NULL || wreq->host == NULL ||
260 wreq->url == NULL) {
261 return -EINVAL;
262 }
263
264 ctx = websocket_find(sock);
265 if (ctx) {
266 NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
267 sock);
268 return -EEXIST;
269 }
270
271 ctx = websocket_get();
272 if (!ctx) {
273 return -ENOENT;
274 }
275
276 ctx->real_sock = sock;
277 ctx->recv_buf.buf = wreq->tmp_buf;
278 ctx->recv_buf.size = wreq->tmp_buf_len;
279 ctx->sec_accept_key = sec_accept_key;
280 ctx->http_cb = wreq->http_cb;
281
282 mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value),
283 sec_accept_key);
284
285 ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
286 sizeof(sec_ws_key) -
287 sizeof("Sec-Websocket-Key: "),
288 &olen, sec_accept_key,
289 /* We are only interested in 16 first bytes so
290 * subtract 4 from the SHA-1 length
291 */
292 sizeof(sec_accept_key) - 4);
293 if (ret) {
294 NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
295 goto out;
296 }
297
298 if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
299 NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
300 olen + sizeof("Sec-Websocket-Key: ") + 2,
301 sizeof(sec_ws_key));
302 ret = -EMSGSIZE;
303 goto out;
304 }
305
306 memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
307 HTTP_CRLF, sizeof(HTTP_CRLF));
308
309 memset(&req, 0, sizeof(req));
310
311 req.method = HTTP_GET;
312 req.url = wreq->url;
313 req.host = wreq->host;
314 req.protocol = "HTTP/1.1";
315 req.header_fields = (const char **)headers;
316 req.optional_headers_cb = wreq->optional_headers_cb;
317 req.optional_headers = wreq->optional_headers;
318 req.response = response_cb;
319 req.http_cb = &http_parser_settings;
320 req.recv_buf = wreq->tmp_buf;
321 req.recv_buf_len = wreq->tmp_buf_len;
322
323 /* We need to catch the Sec-WebSocket-Accept field in order to verify
324 * that it contains the stuff that we sent in Sec-WebSocket-Key field
325 * so setup HTTP callbacks so that we will get the needed fields.
326 */
327 if (ctx->http_cb) {
328 memcpy(&http_parser_settings, ctx->http_cb,
329 sizeof(http_parser_settings));
330 } else {
331 memset(&http_parser_settings, 0, sizeof(http_parser_settings));
332 }
333
334 http_parser_settings.on_header_field = on_header_field;
335 http_parser_settings.on_header_value = on_header_value;
336
337 /* Pre-calculate the expected Sec-Websocket-Accept field */
338 key_len = MIN(sizeof(key_accept) - 1, olen);
339 strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
340 key_len);
341
342 olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
343 strncpy(key_accept + key_len, WS_MAGIC, olen);
344
345 /* This SHA-1 value is then checked when we receive the response */
346 mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
347
348 ret = http_client_req(sock, &req, timeout, ctx);
349 if (ret < 0) {
350 NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
351 wreq->host);
352 ret = -ECONNABORTED;
353 goto out;
354 }
355
356 if (!(ctx->all_received && ctx->sec_accept_ok)) {
357 NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
358 ctx->all_received, ctx->sec_accept_ok);
359 ret = -ECONNABORTED;
360 goto out;
361 }
362
363 ctx->user_data = user_data;
364
365 fd = z_reserve_fd();
366 if (fd < 0) {
367 ret = -ENOSPC;
368 goto out;
369 }
370
371 ctx->sock = fd;
372 z_finalize_fd(fd, ctx,
373 (const struct fd_op_vtable *)&websocket_fd_op_vtable);
374
375 /* Call the user specified callback and if it accepts the connection
376 * then continue.
377 */
378 if (wreq->cb) {
379 ret = wreq->cb(fd, &req, user_data);
380 if (ret < 0) {
381 NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
382 goto out;
383 }
384 }
385
386 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
387
388 /* We will re-use the temp buffer in receive function if needed but
389 * in order that to work the amount of data in buffer must be set to 0
390 */
391 ctx->recv_buf.count = 0;
392
393 /* Init parser FSM */
394 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
395
396 return fd;
397
398 out:
399 if (fd >= 0) {
400 (void)close(fd);
401 }
402
403 websocket_context_unref(ctx);
404 return ret;
405 }
406
/* Public API: close a websocket fd created by websocket_connect().
 *
 * The fd's vtable routes close() to websocket_close_vmeth(), which sends the
 * CLOSE frame and releases the context.
 *
 * @return Result of close(): 0 on success, -1 with errno set on failure.
 */
int websocket_disconnect(int ws_sock)
{
	int rc = close(ws_sock);

	return rc;
}
411
websocket_interal_disconnect(struct websocket_context * ctx)412 static int websocket_interal_disconnect(struct websocket_context *ctx)
413 {
414 int ret;
415
416 if (ctx == NULL) {
417 return -ENOENT;
418 }
419
420 NET_DBG("[%p] Disconnecting", ctx);
421
422 ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
423 true, true, SYS_FOREVER_MS);
424 if (ret < 0) {
425 NET_ERR("[%p] Failed to send close message (err %d).", ctx, ret);
426 }
427
428 websocket_context_unref(ctx);
429
430 return ret;
431 }
432
/* fd vtable close(): tear the websocket down via
 * websocket_interal_disconnect().
 *
 * @return The non-negative disconnect result, or -1 with errno set.
 */
static int websocket_close_vmeth(void *obj)
{
	int rc = websocket_interal_disconnect((struct websocket_context *)obj);

	if (rc >= 0) {
		return rc;
	}

	NET_DBG("[%p] Cannot close (%d)", obj, rc);
	errno = -rc;

	return -1;
}
448
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)449 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
450 int timeout)
451 {
452 int fd_backup[CONFIG_NET_SOCKETS_POLL_MAX];
453 const struct fd_op_vtable *vtable;
454 void *ctx;
455 int ret = 0;
456 int i;
457
458 /* Overwrite websocket file descriptors with underlying ones. */
459 for (i = 0; i < nfds; i++) {
460 fd_backup[i] = fds[i].fd;
461
462 ctx = z_get_fd_obj(fds[i].fd,
463 (const struct fd_op_vtable *)
464 &websocket_fd_op_vtable,
465 0);
466 if (ctx == NULL) {
467 continue;
468 }
469
470 fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
471 }
472
473 /* Get offloaded sockets vtable. */
474 ctx = z_get_fd_obj_and_vtable(fds[0].fd,
475 (const struct fd_op_vtable **)&vtable,
476 NULL);
477 if (ctx == NULL) {
478 errno = EINVAL;
479 ret = -1;
480 goto exit;
481 }
482
483 ret = z_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
484 fds, nfds, timeout);
485
486 exit:
487 /* Restore original fds. */
488 for (i = 0; i < nfds; i++) {
489 fds[i].fd = fd_backup[i];
490 }
491
492 return ret;
493 }
494
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)495 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
496 {
497 struct websocket_context *ctx = obj;
498
499 switch (request) {
500 case ZFD_IOCTL_POLL_OFFLOAD: {
501 struct zsock_pollfd *fds;
502 int nfds;
503 int timeout;
504
505 fds = va_arg(args, struct zsock_pollfd *);
506 nfds = va_arg(args, int);
507 timeout = va_arg(args, int);
508
509 return websocket_poll_offload(fds, nfds, timeout);
510 }
511
512 case ZFD_IOCTL_SET_LOCK:
513 /* Ignore, don't want to overwrite underlying socket lock. */
514 return 0;
515
516 default: {
517 const struct fd_op_vtable *vtable;
518 void *core_obj;
519
520 core_obj = z_get_fd_obj_and_vtable(
521 ctx->real_sock,
522 (const struct fd_op_vtable **)&vtable,
523 NULL);
524 if (core_obj == NULL) {
525 errno = EBADF;
526 return -1;
527 }
528
529 /* Pass the call to the core socket implementation. */
530 return vtable->ioctl(core_obj, request, args);
531 }
532 }
533
534 return 0;
535 }
536
#if !defined(CONFIG_NET_TEST)
/* Call zsock_sendmsg() until every byte described by @p message is sent.
 *
 * After a partial send, the iovec array referenced by @p message is advanced
 * in place: fully sent entries get a zero length, and the first partially
 * sent entry is trimmed. The next call then resumes where the previous one
 * stopped.
 *
 * @return Total number of bytes sent, or -errno from the failing send.
 */
static int sendmsg_all(int sock, const struct msghdr *message, int flags)
{
	struct iovec *iov = message->msg_iov;
	size_t total_len = 0;
	size_t sent_so_far = 0;
	int idx;

	for (idx = 0; idx < message->msg_iovlen; idx++) {
		total_len += iov[idx].iov_len;
	}

	while (sent_so_far < total_len) {
		int chunk = zsock_sendmsg(sock, message, flags);

		if (chunk < 0) {
			return -errno;
		}

		sent_so_far += chunk;
		if (sent_so_far >= total_len) {
			break;
		}

		/* Skip over the bytes that the last call consumed. */
		for (idx = 0; idx < message->msg_iovlen; idx++) {
			if (chunk < iov[idx].iov_len) {
				iov[idx].iov_len -= chunk;
				iov[idx].iov_base = (uint8_t *)iov[idx].iov_base + chunk;
				break;
			}

			chunk -= iov[idx].iov_len;
			iov[idx].iov_len = 0;
		}
	}

	return total_len;
}
#endif /* !defined(CONFIG_NET_TEST) */
576
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)577 static int websocket_prepare_and_send(struct websocket_context *ctx,
578 uint8_t *header, size_t header_len,
579 uint8_t *payload, size_t payload_len,
580 int32_t timeout)
581 {
582 struct iovec io_vector[2];
583 struct msghdr msg;
584
585 io_vector[0].iov_base = header;
586 io_vector[0].iov_len = header_len;
587 io_vector[1].iov_base = payload;
588 io_vector[1].iov_len = payload_len;
589
590 memset(&msg, 0, sizeof(msg));
591
592 msg.msg_iov = io_vector;
593 msg.msg_iovlen = ARRAY_SIZE(io_vector);
594
595 if (HEXDUMP_SENT_PACKETS) {
596 LOG_HEXDUMP_DBG(header, header_len, "Header");
597 if ((payload != NULL) && (payload_len > 0)) {
598 LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
599 } else {
600 LOG_DBG("No payload");
601 }
602 }
603
604 #if defined(CONFIG_NET_TEST)
605 /* Simulate a case where the payload is split to two. The unit test
606 * does not set mask bit in this case.
607 */
608 return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
609 #else
610 k_timeout_t tout = K_FOREVER;
611
612 if (timeout != SYS_FOREVER_MS) {
613 tout = K_MSEC(timeout);
614 }
615
616 return sendmsg_all(ctx->real_sock, &msg,
617 K_TIMEOUT_EQ(tout, K_NO_WAIT) ? MSG_DONTWAIT : 0);
618 #endif /* CONFIG_NET_TEST */
619 }
620
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)621 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
622 enum websocket_opcode opcode, bool mask, bool final,
623 int32_t timeout)
624 {
625 struct websocket_context *ctx;
626 uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
627 uint8_t *data_to_send = (uint8_t *)payload;
628 int ret;
629
630 if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
631 opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
632 opcode != WEBSOCKET_OPCODE_CONTINUE &&
633 opcode != WEBSOCKET_OPCODE_CLOSE &&
634 opcode != WEBSOCKET_OPCODE_PING &&
635 opcode != WEBSOCKET_OPCODE_PONG) {
636 return -EINVAL;
637 }
638
639 ctx = z_get_fd_obj(ws_sock, NULL, 0);
640 if (ctx == NULL) {
641 return -EBADF;
642 }
643
644 #if !defined(CONFIG_NET_TEST)
645 /* Websocket unit test does not use context from pool but allocates
646 * its own, hence skip the check.
647 */
648
649 if (!PART_OF_ARRAY(contexts, ctx)) {
650 return -ENOENT;
651 }
652 #endif /* !defined(CONFIG_NET_TEST) */
653
654 NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
655 mask, final ? "final" : "more");
656
657 memset(header, 0, sizeof(header));
658
659 /* Is this the last packet? */
660 header[0] = final ? BIT(7) : 0;
661
662 /* Text, binary, ping, pong or close ? */
663 header[0] |= opcode;
664
665 /* Masking */
666 header[1] = mask ? BIT(7) : 0;
667
668 if (payload_len < 126) {
669 header[1] |= payload_len;
670 } else if (payload_len < 65536) {
671 header[1] |= 126;
672 header[2] = payload_len >> 8;
673 header[3] = payload_len;
674 hdr_len += 2;
675 } else {
676 header[1] |= 127;
677 header[2] = 0;
678 header[3] = 0;
679 header[4] = 0;
680 header[5] = 0;
681 header[6] = payload_len >> 24;
682 header[7] = payload_len >> 16;
683 header[8] = payload_len >> 8;
684 header[9] = payload_len;
685 hdr_len += 8;
686 }
687
688 /* Add masking value if needed */
689 if (mask) {
690 int i;
691
692 ctx->masking_value = sys_rand32_get();
693
694 header[hdr_len++] |= ctx->masking_value >> 24;
695 header[hdr_len++] |= ctx->masking_value >> 16;
696 header[hdr_len++] |= ctx->masking_value >> 8;
697 header[hdr_len++] |= ctx->masking_value;
698
699 if ((payload != NULL) && (payload_len > 0)) {
700 data_to_send = k_malloc(payload_len);
701 if (!data_to_send) {
702 return -ENOMEM;
703 }
704
705 memcpy(data_to_send, payload, payload_len);
706
707 for (i = 0; i < payload_len; i++) {
708 data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
709 }
710 }
711 }
712
713 ret = websocket_prepare_and_send(ctx, header, hdr_len,
714 data_to_send, payload_len, timeout);
715 if (ret < 0) {
716 NET_DBG("Cannot send ws msg (%d)", -errno);
717 goto quit;
718 }
719
720 quit:
721 if (data_to_send != payload) {
722 k_free(data_to_send);
723 }
724
725 /* Do no math with 0 and error codes */
726 if (ret <= 0) {
727 return ret;
728 }
729
730 return ret - hdr_len;
731 }
732
websocket_opcode2flag(uint8_t data)733 static uint32_t websocket_opcode2flag(uint8_t data)
734 {
735 switch (data & 0x0f) {
736 case WEBSOCKET_OPCODE_DATA_TEXT:
737 return WEBSOCKET_FLAG_TEXT;
738 case WEBSOCKET_OPCODE_DATA_BINARY:
739 return WEBSOCKET_FLAG_BINARY;
740 case WEBSOCKET_OPCODE_CLOSE:
741 return WEBSOCKET_FLAG_CLOSE;
742 case WEBSOCKET_OPCODE_PING:
743 return WEBSOCKET_FLAG_PING;
744 case WEBSOCKET_OPCODE_PONG:
745 return WEBSOCKET_FLAG_PONG;
746 default:
747 break;
748 }
749 return 0;
750 }
751
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)752 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
753 {
754 int len;
755 uint8_t data;
756 size_t parsed_count = 0;
757
758 do {
759 if (parsed_count >= ctx->recv_buf.count) {
760 return parsed_count;
761 }
762 if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
763 data = ctx->recv_buf.buf[parsed_count++];
764
765 switch (ctx->parser_state) {
766 case WEBSOCKET_PARSER_STATE_OPCODE:
767 ctx->message_type = websocket_opcode2flag(data);
768 if ((data & 0x80) != 0) {
769 ctx->message_type |= WEBSOCKET_FLAG_FINAL;
770 }
771 ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
772 break;
773 case WEBSOCKET_PARSER_STATE_LENGTH:
774 ctx->masked = (data & 0x80) != 0;
775 len = data & 0x7f;
776 if (len < 126) {
777 ctx->message_len = len;
778 if (ctx->masked) {
779 ctx->masking_value = 0;
780 ctx->parser_remaining = 4;
781 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
782 } else {
783 ctx->parser_remaining = ctx->message_len;
784 ctx->parser_state =
785 (ctx->parser_remaining == 0)
786 ? WEBSOCKET_PARSER_STATE_OPCODE
787 : WEBSOCKET_PARSER_STATE_PAYLOAD;
788 }
789 } else {
790 ctx->message_len = 0;
791 ctx->parser_remaining = (len < 127) ? 2 : 8;
792 ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
793 }
794 break;
795 case WEBSOCKET_PARSER_STATE_EXT_LEN:
796 ctx->parser_remaining--;
797 ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
798 if (ctx->parser_remaining == 0) {
799 if (ctx->masked) {
800 ctx->masking_value = 0;
801 ctx->parser_remaining = 4;
802 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
803 } else {
804 ctx->parser_remaining = ctx->message_len;
805 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
806 }
807 }
808 break;
809 case WEBSOCKET_PARSER_STATE_MASK:
810 ctx->parser_remaining--;
811 ctx->masking_value |= (data << (ctx->parser_remaining * 8));
812 if (ctx->parser_remaining == 0) {
813 if (ctx->message_len == 0) {
814 ctx->parser_remaining = 0;
815 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
816 } else {
817 ctx->parser_remaining = ctx->message_len;
818 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
819 }
820 }
821 break;
822 default:
823 return -EFAULT;
824 }
825
826 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
827 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
828 ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
829 (ctx->message_len == 0))) {
830 NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
831 ctx->masked ? "" : "un",
832 ctx->masked ? ctx->masking_value : 0, ctx->message_type,
833 (size_t)ctx->message_len);
834 }
835 #endif
836 } else {
837 size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
838 size_t payload_in_recv_buf =
839 MIN(remaining_in_recv_buf, ctx->parser_remaining);
840 size_t free_in_payload_buf = payload->size - payload->count;
841 size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
842
843 if (free_in_payload_buf == 0) {
844 break;
845 }
846
847 memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
848 ready_to_copy);
849 parsed_count += ready_to_copy;
850 payload->count += ready_to_copy;
851 ctx->parser_remaining -= ready_to_copy;
852 if (ctx->parser_remaining == 0) {
853 ctx->parser_remaining = 0;
854 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
855 }
856 }
857
858 } while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
859
860 return parsed_count;
861 }
862
863 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)864 static int wait_rx(int sock, int timeout)
865 {
866 struct zsock_pollfd fds = {
867 .fd = sock,
868 .events = ZSOCK_POLLIN,
869 };
870 int ret;
871
872 ret = zsock_poll(&fds, 1, timeout);
873 if (ret < 0) {
874 return ret;
875 }
876
877 if (ret == 0) {
878 /* Timeout */
879 return -EAGAIN;
880 }
881
882 if (fds.revents & ZSOCK_POLLNVAL) {
883 return -EBADF;
884 }
885
886 if (fds.revents & ZSOCK_POLLERR) {
887 return -EIO;
888 }
889
890 return 0;
891 }
892
timeout_to_ms(k_timeout_t * timeout)893 static int timeout_to_ms(k_timeout_t *timeout)
894 {
895 if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
896 return 0;
897 } else if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
898 return SYS_FOREVER_MS;
899 } else {
900 return k_ticks_to_ms_floor32(timeout->ticks);
901 }
902 }
903
904 #endif /* !defined(CONFIG_NET_TEST) */
905
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)906 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
907 uint32_t *message_type, uint64_t *remaining, int32_t timeout)
908 {
909 struct websocket_context *ctx;
910 int ret;
911 k_timepoint_t end;
912 k_timeout_t tout = K_FOREVER;
913 struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
914
915 if (timeout != SYS_FOREVER_MS) {
916 tout = K_MSEC(timeout);
917 }
918
919 if ((buf == NULL) || (buf_len == 0)) {
920 return -EINVAL;
921 }
922
923 end = sys_timepoint_calc(tout);
924
925 #if defined(CONFIG_NET_TEST)
926 struct test_data *test_data = z_get_fd_obj(ws_sock, NULL, 0);
927
928 if (test_data == NULL) {
929 return -EBADF;
930 }
931
932 ctx = test_data->ctx;
933 #else
934 ctx = z_get_fd_obj(ws_sock, NULL, 0);
935 if (ctx == NULL) {
936 return -EBADF;
937 }
938
939 if (!PART_OF_ARRAY(contexts, ctx)) {
940 return -ENOENT;
941 }
942 #endif /* CONFIG_NET_TEST */
943
944 do {
945 size_t parsed_count;
946
947 if (ctx->recv_buf.count == 0) {
948 #if defined(CONFIG_NET_TEST)
949 size_t input_len = MIN(ctx->recv_buf.size,
950 test_data->input_len - test_data->input_pos);
951
952 if (input_len > 0) {
953 memcpy(ctx->recv_buf.buf,
954 &test_data->input_buf[test_data->input_pos], input_len);
955 test_data->input_pos += input_len;
956 ret = input_len;
957 } else {
958 /* emulate timeout */
959 ret = -EAGAIN;
960 }
961 #else
962 tout = sys_timepoint_timeout(end);
963
964 ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
965 if (ret == 0) {
966 ret = recv(ctx->real_sock, ctx->recv_buf.buf,
967 ctx->recv_buf.size, MSG_DONTWAIT);
968 if (ret < 0) {
969 ret = -errno;
970 }
971 }
972 #endif /* CONFIG_NET_TEST */
973
974 if (ret < 0) {
975 if ((ret == -EAGAIN) && (payload.count > 0)) {
976 /* go to unmasking */
977 break;
978 }
979 return ret;
980 }
981
982 if (ret == 0) {
983 /* Socket closed */
984 return -ENOTCONN;
985 }
986
987 ctx->recv_buf.count = ret;
988
989 NET_DBG("[%p] Received %d bytes", ctx, ret);
990 }
991
992 ret = websocket_parse(ctx, &payload);
993 if (ret < 0) {
994 return ret;
995 }
996 parsed_count = ret;
997
998 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
999 (payload.count >= payload.size)) {
1000 if (remaining != NULL) {
1001 *remaining = ctx->parser_remaining;
1002 }
1003 if (message_type != NULL) {
1004 *message_type = ctx->message_type;
1005 }
1006
1007 size_t left = ctx->recv_buf.count - parsed_count;
1008
1009 if (left > 0) {
1010 memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1011 }
1012 ctx->recv_buf.count = left;
1013 break;
1014 }
1015
1016 ctx->recv_buf.count -= parsed_count;
1017
1018 } while (true);
1019
1020 /* Unmask the data */
1021 if (ctx->masked) {
1022 uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1023 size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1024
1025 for (size_t i = 0; i < payload.count; i++) {
1026 size_t m = data_buf_offset % 4;
1027
1028 payload.buf[i] ^= mask_as_bytes[3 - m];
1029 data_buf_offset++;
1030 }
1031 }
1032
1033 return payload.count;
1034 }
1035
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1036 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1037 size_t buf_len, int32_t timeout)
1038 {
1039 int ret;
1040
1041 NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1042
1043 ret = websocket_send_msg(ctx->sock, buf, buf_len,
1044 WEBSOCKET_OPCODE_DATA_TEXT,
1045 true, true, timeout);
1046 if (ret < 0) {
1047 errno = -ret;
1048 return -1;
1049 }
1050
1051 NET_DBG("[%p] Sent %d bytes", ctx, ret);
1052
1053 return ret;
1054 }
1055
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1056 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1057 size_t buf_len, int32_t timeout)
1058 {
1059 uint32_t message_type;
1060 uint64_t remaining;
1061 int ret;
1062
1063 NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1064
1065 /* TODO: add support for recvmsg() so that we could return the
1066 * websocket specific information in ancillary data.
1067 */
1068 ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1069 &remaining, timeout);
1070 if (ret < 0) {
1071 if (ret == -ENOTCONN) {
1072 ret = 0;
1073 } else {
1074 errno = -ret;
1075 return -1;
1076 }
1077 }
1078
1079 NET_DBG("[%p] Received %d bytes", ctx, ret);
1080
1081 return ret;
1082 }
1083
websocket_read_vmeth(void * obj,void * buffer,size_t count)1084 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1085 {
1086 return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1087 }
1088
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1089 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1090 size_t count)
1091 {
1092 return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1093 }
1094
websocket_sendto_ctx(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)1095 static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
1096 int flags,
1097 const struct sockaddr *dest_addr,
1098 socklen_t addrlen)
1099 {
1100 struct websocket_context *ctx = obj;
1101 int32_t timeout = SYS_FOREVER_MS;
1102
1103 if (flags & ZSOCK_MSG_DONTWAIT) {
1104 timeout = 0;
1105 }
1106
1107 ARG_UNUSED(dest_addr);
1108 ARG_UNUSED(addrlen);
1109
1110 return (ssize_t)websocket_send(ctx, buf, len, timeout);
1111 }
1112
websocket_recvfrom_ctx(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)1113 static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
1114 int flags, struct sockaddr *src_addr,
1115 socklen_t *addrlen)
1116 {
1117 struct websocket_context *ctx = obj;
1118 int32_t timeout = SYS_FOREVER_MS;
1119
1120 if (flags & ZSOCK_MSG_DONTWAIT) {
1121 timeout = 0;
1122 }
1123
1124 ARG_UNUSED(src_addr);
1125 ARG_UNUSED(addrlen);
1126
1127 return (ssize_t)websocket_recv(ctx, buf, max_len, timeout);
1128 }
1129
1130 static const struct socket_op_vtable websocket_fd_op_vtable = {
1131 .fd_vtable = {
1132 .read = websocket_read_vmeth,
1133 .write = websocket_write_vmeth,
1134 .close = websocket_close_vmeth,
1135 .ioctl = websocket_ioctl_vmeth,
1136 },
1137 .sendto = websocket_sendto_ctx,
1138 .recvfrom = websocket_recvfrom_ctx,
1139 };
1140
/* Invoke @p cb for every in-use websocket context.
 *
 * The pool lock is held for the whole walk, and each context's own mutex is
 * held while @p cb runs on it.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	struct websocket_context *const end = contexts + ARRAY_SIZE(contexts);

	k_sem_take(&contexts_lock, K_FOREVER);

	for (struct websocket_context *it = contexts; it < end; it++) {
		if (websocket_context_is_used(it)) {
			k_mutex_lock(&it->lock, K_FOREVER);
			cb(it, user_data);
			k_mutex_unlock(&it->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1161
websocket_init(void)1162 void websocket_init(void)
1163 {
1164 k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1165 }
1166