1 /** @file
2 * @brief Websocket client API
3 *
 * An API for applications to set up websocket connections.
5 */
6
7 /*
8 * Copyright (c) 2019 Intel Corporation
9 *
10 * SPDX-License-Identifier: Apache-2.0
11 */
12
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_websocket, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
15
16 #include <zephyr/kernel.h>
17 #include <strings.h>
18 #include <errno.h>
19 #include <stdbool.h>
20 #include <stdlib.h>
21
22 #include <zephyr/sys/fdtable.h>
23 #include <zephyr/net/net_core.h>
24 #include <zephyr/net/net_ip.h>
25 #include <zephyr/net/socket.h>
26 #include <zephyr/net/http/client.h>
27 #include <zephyr/net/websocket.h>
28
29 #include <zephyr/random/random.h>
30 #include <zephyr/sys/byteorder.h>
31 #include <zephyr/sys/base64.h>
32
33 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
34 #include <psa/crypto.h>
35 #else
36 #include <mbedtls/sha1.h>
37 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
38
39 #include "net_private.h"
40 #include "sockets_internal.h"
41 #include "websocket_internal.h"
42
43 /* If you want to see the data that is being sent or received,
44 * then you can enable debugging and set the following variables to 1.
45 * This will print a lot of data so is not enabled by default.
46 */
47 #define HEXDUMP_SENT_PACKETS 0
48 #define HEXDUMP_RECV_PACKETS 0
49
50 static struct websocket_context contexts[CONFIG_WEBSOCKET_MAX_CONTEXTS];
51
52 static struct k_sem contexts_lock;
53
54 static const struct socket_op_vtable websocket_fd_op_vtable;
55
56 #if defined(CONFIG_NET_TEST)
57 int verify_sent_and_received_msg(struct net_msghdr *msg, bool split_msg);
58 #endif
59
opcode2str(enum websocket_opcode opcode)60 static const char *opcode2str(enum websocket_opcode opcode)
61 {
62 switch (opcode) {
63 case WEBSOCKET_OPCODE_DATA_TEXT:
64 return "TEXT";
65 case WEBSOCKET_OPCODE_DATA_BINARY:
66 return "BIN";
67 case WEBSOCKET_OPCODE_CONTINUE:
68 return "CONT";
69 case WEBSOCKET_OPCODE_CLOSE:
70 return "CLOSE";
71 case WEBSOCKET_OPCODE_PING:
72 return "PING";
73 case WEBSOCKET_OPCODE_PONG:
74 return "PONG";
75 default:
76 break;
77 }
78
79 return NULL;
80 }
81
websocket_context_ref(struct websocket_context * ctx)82 static int websocket_context_ref(struct websocket_context *ctx)
83 {
84 int old_rc = atomic_inc(&ctx->refcount);
85
86 return old_rc + 1;
87 }
88
websocket_context_unref(struct websocket_context * ctx)89 static int websocket_context_unref(struct websocket_context *ctx)
90 {
91 int old_rc = atomic_dec(&ctx->refcount);
92
93 if (old_rc != 1) {
94 return old_rc - 1;
95 }
96
97 return 0;
98 }
99
websocket_context_is_used(struct websocket_context * ctx)100 static inline bool websocket_context_is_used(struct websocket_context *ctx)
101 {
102 return !!atomic_get(&ctx->refcount);
103 }
104
websocket_get(void)105 static struct websocket_context *websocket_get(void)
106 {
107 struct websocket_context *ctx = NULL;
108 int i;
109
110 k_sem_take(&contexts_lock, K_FOREVER);
111
112 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
113 if (websocket_context_is_used(&contexts[i])) {
114 continue;
115 }
116
117 websocket_context_ref(&contexts[i]);
118 ctx = &contexts[i];
119 break;
120 }
121
122 k_sem_give(&contexts_lock);
123
124 return ctx;
125 }
126
websocket_find(int real_sock)127 static struct websocket_context *websocket_find(int real_sock)
128 {
129 struct websocket_context *ctx = NULL;
130 int i;
131
132 k_sem_take(&contexts_lock, K_FOREVER);
133
134 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
135 if (!websocket_context_is_used(&contexts[i])) {
136 continue;
137 }
138
139 if (contexts[i].real_sock != real_sock) {
140 continue;
141 }
142
143 ctx = &contexts[i];
144 break;
145 }
146
147 k_sem_give(&contexts_lock);
148
149 return ctx;
150 }
151
response_cb(struct http_response * rsp,enum http_final_call final_data,void * user_data)152 static int response_cb(struct http_response *rsp,
153 enum http_final_call final_data,
154 void *user_data)
155 {
156 struct websocket_context *ctx = user_data;
157
158 if (final_data == HTTP_DATA_MORE) {
159 NET_DBG("[%p] Partial data received (%zd bytes)", ctx,
160 rsp->data_len);
161 ctx->all_received = false;
162 } else if (final_data == HTTP_DATA_FINAL) {
163 NET_DBG("[%p] All the data received (%zd bytes)", ctx,
164 rsp->data_len);
165 ctx->all_received = true;
166 }
167
168 return 0;
169 }
170
on_header_field(struct http_parser * parser,const char * at,size_t length)171 static int on_header_field(struct http_parser *parser, const char *at,
172 size_t length)
173 {
174 struct http_request *req = CONTAINER_OF(parser,
175 struct http_request,
176 internal.parser);
177 struct websocket_context *ctx = req->internal.user_data;
178 const char *ws_accept_str = "Sec-WebSocket-Accept";
179 uint16_t len;
180
181 len = strlen(ws_accept_str);
182 if (length >= len && strncasecmp(at, ws_accept_str, len) == 0) {
183 ctx->sec_accept_present = true;
184 }
185
186 if (ctx->http_cb && ctx->http_cb->on_header_field) {
187 ctx->http_cb->on_header_field(parser, at, length);
188 }
189
190 return 0;
191 }
192
193 #define MAX_SEC_ACCEPT_LEN 32
194
on_header_value(struct http_parser * parser,const char * at,size_t length)195 static int on_header_value(struct http_parser *parser, const char *at,
196 size_t length)
197 {
198 struct http_request *req = CONTAINER_OF(parser,
199 struct http_request,
200 internal.parser);
201 struct websocket_context *ctx = req->internal.user_data;
202 char str[MAX_SEC_ACCEPT_LEN];
203
204 if (ctx->sec_accept_present) {
205 int ret;
206 size_t olen;
207
208 ctx->sec_accept_ok = false;
209 ctx->sec_accept_present = false;
210
211 ret = base64_encode(str, sizeof(str) - 1, &olen,
212 ctx->sec_accept_key,
213 WS_SHA1_OUTPUT_LEN);
214 if (ret == 0) {
215 if (strncmp(at, str, length)) {
216 NET_DBG("[%p] Security keys do not match "
217 "%s vs %s", ctx, str, at);
218 } else {
219 ctx->sec_accept_ok = true;
220 }
221 }
222 }
223
224 if (ctx->http_cb && ctx->http_cb->on_header_value) {
225 ctx->http_cb->on_header_value(parser, at, length);
226 }
227
228 return 0;
229 }
230
websocket_connect(int sock,struct websocket_request * wreq,int32_t timeout,void * user_data)231 int websocket_connect(int sock, struct websocket_request *wreq,
232 int32_t timeout, void *user_data)
233 {
234 /* This is the expected Sec-WebSocket-Accept key. We are storing a
235 * pointer to this in ctx but the value is only used for the duration
236 * of this function call so there is no issue even if this variable
237 * is allocated from stack.
238 */
239 uint8_t sec_accept_key[WS_SHA1_OUTPUT_LEN];
240 struct http_parser_settings http_parser_settings;
241 struct websocket_context *ctx;
242 struct http_request req;
243 int ret, fd, key_len;
244 size_t olen;
245 char key_accept[MAX_SEC_ACCEPT_LEN + sizeof(WS_MAGIC)];
246 uint32_t rnd_value = sys_rand32_get();
247 char sec_ws_key[] =
248 "Sec-WebSocket-Key: 0123456789012345678901==\r\n";
249 char *headers[] = {
250 sec_ws_key,
251 "Upgrade: websocket\r\n",
252 "Connection: Upgrade\r\n",
253 "Sec-WebSocket-Version: 13\r\n",
254 NULL
255 };
256 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
257 psa_status_t psa_status;
258 size_t hash_length;
259 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
260
261 fd = -1;
262
263 if (sock < 0 || wreq == NULL || wreq->host == NULL ||
264 wreq->url == NULL) {
265 return -EINVAL;
266 }
267
268 ctx = websocket_find(sock);
269 if (ctx) {
270 NET_DBG("[%p] Websocket for sock %d already exists!", ctx,
271 sock);
272 return -EEXIST;
273 }
274
275 ctx = websocket_get();
276 if (!ctx) {
277 return -ENOENT;
278 }
279
280 ctx->real_sock = sock;
281 ctx->recv_buf.buf = wreq->tmp_buf;
282 ctx->recv_buf.size = wreq->tmp_buf_len;
283 ctx->sec_accept_key = sec_accept_key;
284 ctx->http_cb = wreq->http_cb;
285 ctx->is_client = 1;
286
287 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
288 psa_status = psa_hash_compute(PSA_ALG_SHA_1, (const uint8_t *)&rnd_value, sizeof(rnd_value),
289 sec_accept_key, sizeof(sec_accept_key), &hash_length);
290 if (psa_status != PSA_SUCCESS) {
291 NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, psa_status);
292 ret = -EPROTO;
293 goto out;
294 }
295 #else
296 ret = mbedtls_sha1((const unsigned char *)&rnd_value, sizeof(rnd_value), sec_accept_key);
297 if (ret != 0) {
298 NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, ret);
299 ret = -EPROTO;
300 goto out;
301 }
302 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
303
304
305 ret = base64_encode(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
306 sizeof(sec_ws_key) -
307 sizeof("Sec-Websocket-Key: "),
308 &olen, sec_accept_key,
309 /* We are only interested in 16 first bytes so
310 * subtract 4 from the SHA-1 length
311 */
312 sizeof(sec_accept_key) - 4);
313 if (ret) {
314 NET_DBG("[%p] Cannot encode base64 (%d)", ctx, ret);
315 goto out;
316 }
317
318 if ((olen + sizeof("Sec-Websocket-Key: ") + 2) > sizeof(sec_ws_key)) {
319 NET_DBG("[%p] Too long message (%zd > %zd)", ctx,
320 olen + sizeof("Sec-Websocket-Key: ") + 2,
321 sizeof(sec_ws_key));
322 ret = -EMSGSIZE;
323 goto out;
324 }
325
326 memcpy(sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1 + olen,
327 HTTP_CRLF, sizeof(HTTP_CRLF));
328
329 memset(&req, 0, sizeof(req));
330
331 req.method = HTTP_GET;
332 req.url = wreq->url;
333 req.host = wreq->host;
334 req.protocol = "HTTP/1.1";
335 req.header_fields = (const char **)headers;
336 req.optional_headers_cb = wreq->optional_headers_cb;
337 req.optional_headers = wreq->optional_headers;
338 req.response = response_cb;
339 req.http_cb = &http_parser_settings;
340 req.recv_buf = wreq->tmp_buf;
341 req.recv_buf_len = wreq->tmp_buf_len;
342
343 /* We need to catch the Sec-WebSocket-Accept field in order to verify
344 * that it contains the stuff that we sent in Sec-WebSocket-Key field
345 * so setup HTTP callbacks so that we will get the needed fields.
346 */
347 if (ctx->http_cb) {
348 memcpy(&http_parser_settings, ctx->http_cb,
349 sizeof(http_parser_settings));
350 } else {
351 memset(&http_parser_settings, 0, sizeof(http_parser_settings));
352 }
353
354 http_parser_settings.on_header_field = on_header_field;
355 http_parser_settings.on_header_value = on_header_value;
356
357 /* Pre-calculate the expected Sec-Websocket-Accept field */
358 key_len = MIN(sizeof(key_accept) - 1, olen);
359 strncpy(key_accept, sec_ws_key + sizeof("Sec-Websocket-Key: ") - 1,
360 key_len);
361
362 olen = MIN(sizeof(key_accept) - 1 - key_len, sizeof(WS_MAGIC) - 1);
363 memcpy(key_accept + key_len, WS_MAGIC, olen);
364
365 /* This SHA-1 value is then checked when we receive the response */
366 #ifdef CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT
367 psa_status = psa_hash_compute(PSA_ALG_SHA_1, (const uint8_t *)key_accept, olen + key_len,
368 sec_accept_key, sizeof(sec_accept_key), &hash_length);
369 if (psa_status != PSA_SUCCESS) {
370 NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, psa_status);
371 ret = -EPROTO;
372 goto out;
373 }
374 #else
375 ret = mbedtls_sha1(key_accept, olen + key_len, sec_accept_key);
376 if (ret != 0) {
377 NET_DBG("[%p] Cannot calculate sha1 (%d)", ctx, ret);
378 ret = -EPROTO;
379 goto out;
380 }
381 #endif /* CONFIG_MBEDTLS_PSA_CRYPTO_CLIENT */
382
383 ret = http_client_req(sock, &req, timeout, ctx);
384 if (ret < 0) {
385 NET_DBG("[%p] Cannot connect to Websocket host %s", ctx,
386 wreq->host);
387 ret = -ECONNABORTED;
388 goto out;
389 }
390
391 if (!(ctx->all_received && ctx->sec_accept_ok)) {
392 NET_DBG("[%p] WS handshake failed (%d/%d)", ctx,
393 ctx->all_received, ctx->sec_accept_ok);
394 ret = -ECONNABORTED;
395 goto out;
396 }
397
398 ctx->user_data = user_data;
399
400 fd = zvfs_reserve_fd();
401 if (fd < 0) {
402 ret = -ENOSPC;
403 goto out;
404 }
405
406 ctx->sock = fd;
407 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
408 ZVFS_MODE_IFSOCK);
409
410 /* Call the user specified callback and if it accepts the connection
411 * then continue.
412 */
413 if (wreq->cb) {
414 ret = wreq->cb(fd, &req, user_data);
415 if (ret < 0) {
416 NET_DBG("[%p] Connection aborted (%d)", ctx, ret);
417 goto out;
418 }
419 }
420
421 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
422
423 /* We will re-use the temp buffer in receive function. If there were
424 * any leftover data from HTTP headers processing, we need to reflect
425 * this in the count variable.
426 */
427 ctx->recv_buf.count = req.data_len;
428
429 /* Init parser FSM */
430 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
431
432 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, NET_SOCK_STREAM);
433
434 return fd;
435
436 out:
437 if (fd >= 0) {
438 (void)zsock_close(fd);
439 }
440
441 websocket_context_unref(ctx);
442 return ret;
443 }
444
/* Close a websocket fd.
 *
 * This is a thin wrapper around zsock_close(), which invokes the fd's close
 * handler. Returns whatever zsock_close() returns.
 */
int websocket_disconnect(int ws_sock)
{
	int rc = zsock_close(ws_sock);

	return rc;
}
449
websocket_internal_disconnect(struct websocket_context * ctx)450 static int websocket_internal_disconnect(struct websocket_context *ctx)
451 {
452 int ret;
453
454 if (ctx == NULL) {
455 return -ENOENT;
456 }
457
458 NET_DBG("[%p] Disconnecting", ctx);
459
460 ret = websocket_send_msg(ctx->sock, NULL, 0, WEBSOCKET_OPCODE_CLOSE,
461 ctx->is_client, true, SYS_FOREVER_MS);
462 if (ret < 0) {
463 NET_DBG("[%p] Failed to send close message (err %d).", ctx, ret);
464 }
465
466 (void)sock_obj_core_dealloc(ctx->sock);
467
468 websocket_context_unref(ctx);
469
470 return ret;
471 }
472
/* fd close() handler for websocket fds.
 *
 * Disconnects the context. A -ENOTCONN result is treated as success (0).
 * Any other failure sets errno and returns -1. A non-negative result is
 * returned unchanged.
 */
static int websocket_close_vmeth(void *obj)
{
	int ret = websocket_internal_disconnect((struct websocket_context *)obj);

	if (ret == -ENOTCONN) {
		/* Not being connected is fine when closing. */
		return 0;
	}

	if (ret < 0) {
		NET_DBG("[%p] Cannot close (%d)", obj, ret);
		errno = -ret;
		return -1;
	}

	return ret;
}
493
websocket_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)494 static inline int websocket_poll_offload(struct zsock_pollfd *fds, int nfds,
495 int timeout)
496 {
497 int fd_backup[CONFIG_ZVFS_POLL_MAX];
498 const struct fd_op_vtable *vtable;
499 void *ctx;
500 int ret = 0;
501 int i;
502
503 /* Overwrite websocket file descriptors with underlying ones. */
504 for (i = 0; i < nfds; i++) {
505 fd_backup[i] = fds[i].fd;
506
507 ctx = zvfs_get_fd_obj(fds[i].fd,
508 (const struct fd_op_vtable *)
509 &websocket_fd_op_vtable,
510 0);
511 if (ctx == NULL) {
512 continue;
513 }
514
515 fds[i].fd = ((struct websocket_context *)ctx)->real_sock;
516 }
517
518 /* Get offloaded sockets vtable. */
519 ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
520 (const struct fd_op_vtable **)&vtable,
521 NULL);
522 if (ctx == NULL) {
523 errno = EINVAL;
524 ret = -1;
525 goto exit;
526 }
527
528 ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
529 fds, nfds, timeout);
530
531 exit:
532 /* Restore original fds. */
533 for (i = 0; i < nfds; i++) {
534 fds[i].fd = fd_backup[i];
535 }
536
537 return ret;
538 }
539
websocket_ioctl_vmeth(void * obj,unsigned int request,va_list args)540 static int websocket_ioctl_vmeth(void *obj, unsigned int request, va_list args)
541 {
542 struct websocket_context *ctx = obj;
543
544 switch (request) {
545 case ZFD_IOCTL_POLL_OFFLOAD: {
546 struct zsock_pollfd *fds;
547 int nfds;
548 int timeout;
549
550 fds = va_arg(args, struct zsock_pollfd *);
551 nfds = va_arg(args, int);
552 timeout = va_arg(args, int);
553
554 return websocket_poll_offload(fds, nfds, timeout);
555 }
556
557 case ZFD_IOCTL_SET_LOCK:
558 /* Ignore, don't want to overwrite underlying socket lock. */
559 return 0;
560
561 default: {
562 const struct fd_op_vtable *vtable;
563 void *core_obj;
564
565 core_obj = zvfs_get_fd_obj_and_vtable(
566 ctx->real_sock,
567 (const struct fd_op_vtable **)&vtable,
568 NULL);
569 if (core_obj == NULL) {
570 errno = EBADF;
571 return -1;
572 }
573
574 /* Pass the call to the core socket implementation. */
575 return vtable->ioctl(core_obj, request, args);
576 }
577 }
578
579 return 0;
580 }
581
582 #if !defined(CONFIG_NET_TEST)
sendmsg_all(int sock,const struct net_msghdr * message,int flags,const k_timepoint_t req_end_timepoint)583 static int sendmsg_all(int sock, const struct net_msghdr *message, int flags,
584 const k_timepoint_t req_end_timepoint)
585 {
586 int ret, i;
587 size_t offset = 0;
588 size_t total_len = 0;
589
590 for (i = 0; i < message->msg_iovlen; i++) {
591 total_len += message->msg_iov[i].iov_len;
592 }
593
594 while (offset < total_len) {
595 ret = zsock_sendmsg(sock, message, flags);
596
597 if ((ret == 0) || (ret < 0 && errno == EAGAIN)) {
598 struct zsock_pollfd pfd;
599 int pollres;
600 k_ticks_t req_timeout_ticks =
601 sys_timepoint_timeout(req_end_timepoint).ticks;
602 int req_timeout_ms = k_ticks_to_ms_floor32(req_timeout_ticks);
603
604 pfd.fd = sock;
605 pfd.events = ZSOCK_POLLOUT;
606 pollres = zsock_poll(&pfd, 1, req_timeout_ms);
607 if (pollres == 0) {
608 return -ETIMEDOUT;
609 } else if (pollres > 0) {
610 continue;
611 } else {
612 return -errno;
613 }
614 } else if (ret < 0) {
615 return -errno;
616 }
617
618 offset += ret;
619 if (offset >= total_len) {
620 break;
621 }
622
623 /* Update msghdr for the next iteration. */
624 for (i = 0; i < message->msg_iovlen; i++) {
625 if (ret < message->msg_iov[i].iov_len) {
626 message->msg_iov[i].iov_len -= ret;
627 message->msg_iov[i].iov_base =
628 (uint8_t *)message->msg_iov[i].iov_base + ret;
629 break;
630 }
631
632 ret -= message->msg_iov[i].iov_len;
633 message->msg_iov[i].iov_len = 0;
634 }
635 }
636
637 return total_len;
638 }
639 #endif /* !defined(CONFIG_NET_TEST) */
640
websocket_prepare_and_send(struct websocket_context * ctx,uint8_t * header,size_t header_len,uint8_t * payload,size_t payload_len,int32_t timeout)641 static int websocket_prepare_and_send(struct websocket_context *ctx,
642 uint8_t *header, size_t header_len,
643 uint8_t *payload, size_t payload_len,
644 int32_t timeout)
645 {
646 struct net_iovec io_vector[2];
647 struct net_msghdr msg;
648
649 io_vector[0].iov_base = header;
650 io_vector[0].iov_len = header_len;
651 io_vector[1].iov_base = payload;
652 io_vector[1].iov_len = payload_len;
653
654 memset(&msg, 0, sizeof(msg));
655
656 msg.msg_iov = io_vector;
657 msg.msg_iovlen = ARRAY_SIZE(io_vector);
658
659 if (HEXDUMP_SENT_PACKETS) {
660 LOG_HEXDUMP_DBG(header, header_len, "Header");
661 if ((payload != NULL) && (payload_len > 0)) {
662 LOG_HEXDUMP_DBG(payload, payload_len, "Payload");
663 } else {
664 LOG_DBG("No payload");
665 }
666 }
667
668 #if defined(CONFIG_NET_TEST)
669 /* Simulate a case where the payload is split to two. The unit test
670 * does not set mask bit in this case.
671 */
672 return verify_sent_and_received_msg(&msg, !(header[1] & BIT(7)));
673 #else
674 k_timeout_t tout = K_FOREVER;
675
676 if (timeout != SYS_FOREVER_MS) {
677 tout = K_MSEC(timeout);
678 }
679
680 k_timeout_t req_timeout = K_MSEC(timeout);
681 k_timepoint_t req_end_timepoint = sys_timepoint_calc(req_timeout);
682
683 return sendmsg_all(ctx->real_sock, &msg,
684 K_TIMEOUT_EQ(tout, K_NO_WAIT) ? ZSOCK_MSG_DONTWAIT : 0,
685 req_end_timepoint);
686 #endif /* CONFIG_NET_TEST */
687 }
688
websocket_send_msg(int ws_sock,const uint8_t * payload,size_t payload_len,enum websocket_opcode opcode,bool mask,bool final,int32_t timeout)689 int websocket_send_msg(int ws_sock, const uint8_t *payload, size_t payload_len,
690 enum websocket_opcode opcode, bool mask, bool final,
691 int32_t timeout)
692 {
693 struct websocket_context *ctx;
694 uint8_t header[MAX_HEADER_LEN], hdr_len = 2;
695 uint8_t *data_to_send = (uint8_t *)payload;
696 int ret;
697
698 if (opcode != WEBSOCKET_OPCODE_DATA_TEXT &&
699 opcode != WEBSOCKET_OPCODE_DATA_BINARY &&
700 opcode != WEBSOCKET_OPCODE_CONTINUE &&
701 opcode != WEBSOCKET_OPCODE_CLOSE &&
702 opcode != WEBSOCKET_OPCODE_PING &&
703 opcode != WEBSOCKET_OPCODE_PONG) {
704 return -EINVAL;
705 }
706
707 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
708 if (ctx == NULL) {
709 return -EBADF;
710 }
711
712 #if !defined(CONFIG_NET_TEST)
713 /* Websocket unit test does not use context from pool but allocates
714 * its own, hence skip the check.
715 */
716
717 if (!PART_OF_ARRAY(contexts, ctx)) {
718 return -ENOENT;
719 }
720 #endif /* !defined(CONFIG_NET_TEST) */
721
722 NET_DBG("[%p] Len %zd %s/%d/%s", ctx, payload_len, opcode2str(opcode),
723 mask, final ? "final" : "more");
724
725 memset(header, 0, sizeof(header));
726
727 /* Is this the last packet? */
728 header[0] = final ? BIT(7) : 0;
729
730 /* Text, binary, ping, pong or close ? */
731 header[0] |= opcode;
732
733 /* Masking */
734 header[1] = mask ? BIT(7) : 0;
735
736 if (payload_len < 126) {
737 header[1] |= payload_len;
738 } else if (payload_len < 65536) {
739 header[1] |= 126;
740 header[2] = payload_len >> 8;
741 header[3] = payload_len;
742 hdr_len += 2;
743 } else {
744 header[1] |= 127;
745 header[2] = 0;
746 header[3] = 0;
747 header[4] = 0;
748 header[5] = 0;
749 header[6] = payload_len >> 24;
750 header[7] = payload_len >> 16;
751 header[8] = payload_len >> 8;
752 header[9] = payload_len;
753 hdr_len += 8;
754 }
755
756 /* Add masking value if needed */
757 if (mask) {
758 int i;
759
760 ctx->masking_value = sys_rand32_get();
761
762 header[hdr_len++] |= ctx->masking_value >> 24;
763 header[hdr_len++] |= ctx->masking_value >> 16;
764 header[hdr_len++] |= ctx->masking_value >> 8;
765 header[hdr_len++] |= ctx->masking_value;
766
767 if ((payload != NULL) && (payload_len > 0)) {
768 data_to_send = k_malloc(payload_len);
769 if (!data_to_send) {
770 return -ENOMEM;
771 }
772
773 memcpy(data_to_send, payload, payload_len);
774
775 for (i = 0; i < payload_len; i++) {
776 data_to_send[i] ^= ctx->masking_value >> (8 * (3 - i % 4));
777 }
778 }
779 }
780
781 ret = websocket_prepare_and_send(ctx, header, hdr_len,
782 data_to_send, payload_len, timeout);
783 if (ret < 0) {
784 NET_DBG("Cannot send ws msg (%d)", -errno);
785 goto quit;
786 }
787
788 quit:
789 if (data_to_send != payload) {
790 k_free(data_to_send);
791 }
792
793 /* Do no math with 0 and error codes */
794 if (ret <= 0) {
795 return ret;
796 }
797
798 return ret - hdr_len;
799 }
800
websocket_opcode2flag(uint8_t data)801 static uint32_t websocket_opcode2flag(uint8_t data)
802 {
803 switch (data & 0x0f) {
804 case WEBSOCKET_OPCODE_DATA_TEXT:
805 return WEBSOCKET_FLAG_TEXT;
806 case WEBSOCKET_OPCODE_DATA_BINARY:
807 return WEBSOCKET_FLAG_BINARY;
808 case WEBSOCKET_OPCODE_CLOSE:
809 return WEBSOCKET_FLAG_CLOSE;
810 case WEBSOCKET_OPCODE_PING:
811 return WEBSOCKET_FLAG_PING;
812 case WEBSOCKET_OPCODE_PONG:
813 return WEBSOCKET_FLAG_PONG;
814 default:
815 break;
816 }
817 return 0;
818 }
819
websocket_parse(struct websocket_context * ctx,struct websocket_buffer * payload)820 static int websocket_parse(struct websocket_context *ctx, struct websocket_buffer *payload)
821 {
822 int len;
823 uint8_t data;
824 size_t parsed_count = 0;
825
826 do {
827 if (parsed_count >= ctx->recv_buf.count) {
828 return parsed_count;
829 }
830 if (ctx->parser_state != WEBSOCKET_PARSER_STATE_PAYLOAD) {
831 data = ctx->recv_buf.buf[parsed_count++];
832
833 switch (ctx->parser_state) {
834 case WEBSOCKET_PARSER_STATE_OPCODE:
835 ctx->message_type = websocket_opcode2flag(data);
836 if ((data & 0x80) != 0) {
837 ctx->message_type |= WEBSOCKET_FLAG_FINAL;
838 }
839 ctx->parser_state = WEBSOCKET_PARSER_STATE_LENGTH;
840 break;
841 case WEBSOCKET_PARSER_STATE_LENGTH:
842 ctx->masked = (data & 0x80) != 0;
843 len = data & 0x7f;
844 if (len < 126) {
845 ctx->message_len = len;
846 if (ctx->masked) {
847 ctx->masking_value = 0;
848 ctx->parser_remaining = 4;
849 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
850 } else {
851 ctx->parser_remaining = ctx->message_len;
852 ctx->parser_state =
853 (ctx->parser_remaining == 0)
854 ? WEBSOCKET_PARSER_STATE_OPCODE
855 : WEBSOCKET_PARSER_STATE_PAYLOAD;
856 }
857 } else {
858 ctx->message_len = 0;
859 ctx->parser_remaining = (len < 127) ? 2 : 8;
860 ctx->parser_state = WEBSOCKET_PARSER_STATE_EXT_LEN;
861 }
862 break;
863 case WEBSOCKET_PARSER_STATE_EXT_LEN:
864 ctx->parser_remaining--;
865 ctx->message_len |= ((uint64_t)data << (ctx->parser_remaining * 8));
866 if (ctx->parser_remaining == 0) {
867 if (ctx->masked) {
868 ctx->masking_value = 0;
869 ctx->parser_remaining = 4;
870 ctx->parser_state = WEBSOCKET_PARSER_STATE_MASK;
871 } else {
872 ctx->parser_remaining = ctx->message_len;
873 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
874 }
875 }
876 break;
877 case WEBSOCKET_PARSER_STATE_MASK:
878 ctx->parser_remaining--;
879 ctx->masking_value |=
880 (uint32_t)((uint64_t)data << (ctx->parser_remaining * 8));
881 if (ctx->parser_remaining == 0) {
882 if (ctx->message_len == 0) {
883 ctx->parser_remaining = 0;
884 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
885 } else {
886 ctx->parser_remaining = ctx->message_len;
887 ctx->parser_state = WEBSOCKET_PARSER_STATE_PAYLOAD;
888 }
889 }
890 break;
891 default:
892 return -EFAULT;
893 }
894
895 #if (LOG_LEVEL >= LOG_LEVEL_DBG)
896 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_PAYLOAD) ||
897 ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) &&
898 (ctx->message_len == 0))) {
899 NET_DBG("[%p] %smasked, mask 0x%08x, type 0x%02x, msg %zd", ctx,
900 ctx->masked ? "" : "un",
901 ctx->masked ? ctx->masking_value : 0, ctx->message_type,
902 (size_t)ctx->message_len);
903 }
904 #endif
905 } else {
906 size_t remaining_in_recv_buf = ctx->recv_buf.count - parsed_count;
907 size_t payload_in_recv_buf =
908 MIN(remaining_in_recv_buf, ctx->parser_remaining);
909 size_t free_in_payload_buf = payload->size - payload->count;
910 size_t ready_to_copy = MIN(payload_in_recv_buf, free_in_payload_buf);
911
912 if (free_in_payload_buf == 0) {
913 break;
914 }
915
916 memcpy(&payload->buf[payload->count], &ctx->recv_buf.buf[parsed_count],
917 ready_to_copy);
918 parsed_count += ready_to_copy;
919 payload->count += ready_to_copy;
920 ctx->parser_remaining -= ready_to_copy;
921 if (ctx->parser_remaining == 0) {
922 ctx->parser_remaining = 0;
923 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
924 }
925 }
926
927 } while (ctx->parser_state != WEBSOCKET_PARSER_STATE_OPCODE);
928
929 return parsed_count;
930 }
931
932 #if !defined(CONFIG_NET_TEST)
wait_rx(int sock,int timeout)933 static int wait_rx(int sock, int timeout)
934 {
935 struct zsock_pollfd fds = {
936 .fd = sock,
937 .events = ZSOCK_POLLIN,
938 };
939 int ret;
940
941 ret = zsock_poll(&fds, 1, timeout);
942 if (ret < 0) {
943 return ret;
944 }
945
946 if (ret == 0) {
947 /* Timeout */
948 return -EAGAIN;
949 }
950
951 if (fds.revents & ZSOCK_POLLNVAL) {
952 return -EBADF;
953 }
954
955 if (fds.revents & ZSOCK_POLLERR) {
956 return -EIO;
957 }
958
959 return 0;
960 }
961
/* Convert a kernel timeout into the millisecond form passed to wait_rx():
 * K_FOREVER becomes SYS_FOREVER_MS, K_NO_WAIT becomes 0, and anything else
 * becomes its tick count rounded down to whole milliseconds.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	return K_TIMEOUT_EQ(*timeout, K_NO_WAIT) ? 0
						 : k_ticks_to_ms_floor32(timeout->ticks);
}
972
973 #endif /* !defined(CONFIG_NET_TEST) */
974
websocket_recv_msg(int ws_sock,uint8_t * buf,size_t buf_len,uint32_t * message_type,uint64_t * remaining,int32_t timeout)975 int websocket_recv_msg(int ws_sock, uint8_t *buf, size_t buf_len,
976 uint32_t *message_type, uint64_t *remaining, int32_t timeout)
977 {
978 struct websocket_context *ctx;
979 int ret;
980 k_timepoint_t end;
981 k_timeout_t tout = K_FOREVER;
982 struct websocket_buffer payload = {.buf = buf, .size = buf_len, .count = 0};
983
984 if (timeout != SYS_FOREVER_MS) {
985 tout = K_MSEC(timeout);
986 }
987
988 if ((buf == NULL) || (buf_len == 0)) {
989 return -EINVAL;
990 }
991
992 end = sys_timepoint_calc(tout);
993
994 #if defined(CONFIG_NET_TEST)
995 struct test_data *test_data = zvfs_get_fd_obj(ws_sock, NULL, 0);
996
997 if (test_data == NULL) {
998 return -EBADF;
999 }
1000
1001 ctx = test_data->ctx;
1002 #else
1003 ctx = zvfs_get_fd_obj(ws_sock, NULL, 0);
1004 if (ctx == NULL) {
1005 return -EBADF;
1006 }
1007
1008 if (!PART_OF_ARRAY(contexts, ctx)) {
1009 return -ENOENT;
1010 }
1011 #endif /* CONFIG_NET_TEST */
1012
1013 do {
1014 size_t parsed_count;
1015
1016 if (ctx->recv_buf.count == 0) {
1017 #if defined(CONFIG_NET_TEST)
1018 size_t input_len = MIN(ctx->recv_buf.size,
1019 test_data->input_len - test_data->input_pos);
1020
1021 if (input_len > 0) {
1022 memcpy(ctx->recv_buf.buf,
1023 &test_data->input_buf[test_data->input_pos], input_len);
1024 test_data->input_pos += input_len;
1025 ret = input_len;
1026 } else {
1027 /* emulate timeout */
1028 ret = -EAGAIN;
1029 }
1030 #else
1031 tout = sys_timepoint_timeout(end);
1032
1033 ret = wait_rx(ctx->real_sock, timeout_to_ms(&tout));
1034 if (ret == 0) {
1035 ret = zsock_recv(ctx->real_sock, ctx->recv_buf.buf,
1036 ctx->recv_buf.size, ZSOCK_MSG_DONTWAIT);
1037 if (ret < 0) {
1038 ret = -errno;
1039 }
1040 }
1041 #endif /* CONFIG_NET_TEST */
1042
1043 if (ret < 0) {
1044 if ((ret == -EAGAIN) && (payload.count > 0)) {
1045 /* go to unmasking */
1046 break;
1047 }
1048 return ret;
1049 }
1050
1051 if (ret == 0) {
1052 /* Socket closed */
1053 return -ENOTCONN;
1054 }
1055
1056 ctx->recv_buf.count = ret;
1057
1058 NET_DBG("[%p] Received %d bytes", ctx, ret);
1059 }
1060
1061 ret = websocket_parse(ctx, &payload);
1062 if (ret < 0) {
1063 return ret;
1064 }
1065 parsed_count = ret;
1066
1067 if ((ctx->parser_state == WEBSOCKET_PARSER_STATE_OPCODE) ||
1068 (payload.count >= payload.size)) {
1069 if (remaining != NULL) {
1070 *remaining = ctx->parser_remaining;
1071 }
1072 if (message_type != NULL) {
1073 *message_type = ctx->message_type;
1074 }
1075
1076 size_t left = ctx->recv_buf.count - parsed_count;
1077
1078 if (left > 0) {
1079 memmove(ctx->recv_buf.buf, &ctx->recv_buf.buf[parsed_count], left);
1080 }
1081 ctx->recv_buf.count = left;
1082 break;
1083 }
1084
1085 ctx->recv_buf.count -= parsed_count;
1086
1087 } while (true);
1088
1089 /* Unmask the data */
1090 if (ctx->masked) {
1091 uint8_t *mask_as_bytes = (uint8_t *)&ctx->masking_value;
1092 size_t data_buf_offset = ctx->message_len - ctx->parser_remaining - payload.count;
1093
1094 for (size_t i = 0; i < payload.count; i++) {
1095 size_t m = data_buf_offset % 4;
1096
1097 payload.buf[i] ^= mask_as_bytes[3 - m];
1098 data_buf_offset++;
1099 }
1100 }
1101
1102 if (ctx->message_type == WEBSOCKET_FLAG_CLOSE) {
1103 return websocket_internal_disconnect(ctx);
1104 }
1105
1106 return payload.count;
1107 }
1108
websocket_send(struct websocket_context * ctx,const uint8_t * buf,size_t buf_len,int32_t timeout)1109 static int websocket_send(struct websocket_context *ctx, const uint8_t *buf,
1110 size_t buf_len, int32_t timeout)
1111 {
1112 int ret;
1113
1114 NET_DBG("[%p] Sending %zd bytes", ctx, buf_len);
1115
1116 ret = websocket_send_msg(ctx->sock, buf, buf_len, WEBSOCKET_OPCODE_DATA_TEXT,
1117 ctx->is_client, true, timeout);
1118 if (ret < 0) {
1119 errno = -ret;
1120 return -1;
1121 }
1122
1123 NET_DBG("[%p] Sent %d bytes", ctx, ret);
1124
1125 sock_obj_core_update_send_stats(ctx->sock, ret);
1126
1127 return ret;
1128 }
1129
websocket_recv(struct websocket_context * ctx,uint8_t * buf,size_t buf_len,int32_t timeout)1130 static int websocket_recv(struct websocket_context *ctx, uint8_t *buf,
1131 size_t buf_len, int32_t timeout)
1132 {
1133 uint32_t message_type;
1134 uint64_t remaining;
1135 int ret;
1136
1137 NET_DBG("[%p] Waiting data, buf len %zd bytes", ctx, buf_len);
1138
1139 /* TODO: add support for recvmsg() so that we could return the
1140 * websocket specific information in ancillary data.
1141 */
1142 ret = websocket_recv_msg(ctx->sock, buf, buf_len, &message_type,
1143 &remaining, timeout);
1144 if (ret < 0) {
1145 if (ret == -ENOTCONN) {
1146 ret = 0;
1147 } else {
1148 errno = -ret;
1149 return -1;
1150 }
1151 }
1152
1153 NET_DBG("[%p] Received %d bytes", ctx, ret);
1154
1155 sock_obj_core_update_recv_stats(ctx->sock, ret);
1156
1157 return ret;
1158 }
1159
websocket_read_vmeth(void * obj,void * buffer,size_t count)1160 static ssize_t websocket_read_vmeth(void *obj, void *buffer, size_t count)
1161 {
1162 return (ssize_t)websocket_recv(obj, buffer, count, SYS_FOREVER_MS);
1163 }
1164
websocket_write_vmeth(void * obj,const void * buffer,size_t count)1165 static ssize_t websocket_write_vmeth(void *obj, const void *buffer,
1166 size_t count)
1167 {
1168 return (ssize_t)websocket_send(obj, buffer, count, SYS_FOREVER_MS);
1169 }
1170
/* socket_op_vtable .sendto handler.
 *
 * ZSOCK_MSG_DONTWAIT in @p flags selects a zero timeout; otherwise the send
 * blocks (SYS_FOREVER_MS). @p dest_addr and @p addrlen are ignored; data is
 * always sent through the context's own websocket fd.
 */
static ssize_t websocket_sendto_ctx(void *obj, const void *buf, size_t len,
				    int flags,
				    const struct net_sockaddr *dest_addr,
				    net_socklen_t addrlen)
{
	const int32_t wait_ms = ((flags & ZSOCK_MSG_DONTWAIT) != 0) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(dest_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_send((struct websocket_context *)obj, buf, len, wait_ms);
}
1188
/* socket_op_vtable .recvfrom handler.
 *
 * ZSOCK_MSG_DONTWAIT in @p flags selects a zero timeout; otherwise the
 * receive blocks (SYS_FOREVER_MS). @p src_addr and @p addrlen are not
 * filled in.
 */
static ssize_t websocket_recvfrom_ctx(void *obj, void *buf, size_t max_len,
				      int flags, struct net_sockaddr *src_addr,
				      net_socklen_t *addrlen)
{
	const int32_t wait_ms = ((flags & ZSOCK_MSG_DONTWAIT) != 0) ? 0 : SYS_FOREVER_MS;

	ARG_UNUSED(src_addr);
	ARG_UNUSED(addrlen);

	return (ssize_t)websocket_recv((struct websocket_context *)obj, buf, max_len, wait_ms);
}
1205
websocket_register(int sock,uint8_t * recv_buf,size_t recv_buf_len)1206 int websocket_register(int sock, uint8_t *recv_buf, size_t recv_buf_len)
1207 {
1208 struct websocket_context *ctx;
1209 int ret, fd;
1210
1211 if (sock < 0) {
1212 return -EINVAL;
1213 }
1214
1215 ctx = websocket_find(sock);
1216 if (ctx) {
1217 NET_DBG("[%p] Websocket for sock %d already exists!", ctx, sock);
1218 return -EEXIST;
1219 }
1220
1221 ctx = websocket_get();
1222 if (!ctx) {
1223 return -ENOENT;
1224 }
1225
1226 ctx->real_sock = sock;
1227 ctx->recv_buf.buf = recv_buf;
1228 ctx->recv_buf.size = recv_buf_len;
1229 ctx->is_client = 0;
1230
1231 fd = zvfs_reserve_fd();
1232 if (fd < 0) {
1233 ret = -ENOSPC;
1234 goto out;
1235 }
1236
1237 ctx->sock = fd;
1238 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&websocket_fd_op_vtable,
1239 ZVFS_MODE_IFSOCK);
1240
1241 NET_DBG("[%p] WS connection to peer established (fd %d)", ctx, fd);
1242
1243 ctx->recv_buf.count = 0;
1244 ctx->parser_state = WEBSOCKET_PARSER_STATE_OPCODE;
1245
1246 (void)sock_obj_core_alloc_find(ctx->real_sock, fd, NET_SOCK_STREAM);
1247
1248 return fd;
1249
1250 out:
1251 websocket_context_unref(ctx);
1252
1253 return ret;
1254 }
1255
websocket_search(int sock)1256 static struct websocket_context *websocket_search(int sock)
1257 {
1258 struct websocket_context *ctx = NULL;
1259 int i;
1260
1261 k_sem_take(&contexts_lock, K_FOREVER);
1262
1263 for (i = 0; i < ARRAY_SIZE(contexts); i++) {
1264 if (!websocket_context_is_used(&contexts[i])) {
1265 continue;
1266 }
1267
1268 if (contexts[i].sock != sock) {
1269 continue;
1270 }
1271
1272 ctx = &contexts[i];
1273 break;
1274 }
1275
1276 k_sem_give(&contexts_lock);
1277
1278 return ctx;
1279 }
1280
websocket_unregister(int sock)1281 int websocket_unregister(int sock)
1282 {
1283 struct websocket_context *ctx;
1284
1285 if (sock < 0) {
1286 return -EINVAL;
1287 }
1288
1289 ctx = websocket_search(sock);
1290 if (ctx == NULL) {
1291 NET_DBG("[%p] Real socket for websocket sock %d not found!", ctx, sock);
1292 return -ENOENT;
1293 }
1294
1295 if (ctx->real_sock < 0) {
1296 return -EALREADY;
1297 }
1298
1299 (void)zsock_close(sock);
1300 (void)zsock_close(ctx->real_sock);
1301
1302 ctx->real_sock = -1;
1303 ctx->sock = -1;
1304
1305 return 0;
1306 }
1307
1308 static const struct socket_op_vtable websocket_fd_op_vtable = {
1309 .fd_vtable = {
1310 .read = websocket_read_vmeth,
1311 .write = websocket_write_vmeth,
1312 .close = websocket_close_vmeth,
1313 .ioctl = websocket_ioctl_vmeth,
1314 },
1315 .sendto = websocket_sendto_ctx,
1316 .recvfrom = websocket_recvfrom_ctx,
1317 };
1318
/*
 * Invoke @p cb with @p user_data for every in-use websocket context.
 * contexts_lock is held for the whole walk, and each context's own mutex
 * is held while its callback runs.
 */
void websocket_context_foreach(websocket_context_cb_t cb, void *user_data)
{
	k_sem_take(&contexts_lock, K_FOREVER);

	for (size_t idx = 0; idx < ARRAY_SIZE(contexts); idx++) {
		struct websocket_context *cur = &contexts[idx];

		if (websocket_context_is_used(cur)) {
			k_mutex_lock(&cur->lock, K_FOREVER);
			cb(cur, user_data);
			k_mutex_unlock(&cur->lock);
		}
	}

	k_sem_give(&contexts_lock);
}
1339
websocket_init(void)1340 void websocket_init(void)
1341 {
1342 k_sem_init(&contexts_lock, 1, K_SEM_MAX_LIMIT);
1343 }
1344