1 /*
2 * Copyright (c) 2020 Intel Corporation
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_test, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
9
10 #include <zephyr/ztest_assert.h>
11
12 #include <zephyr/misc/lorem_ipsum.h>
13 #include <zephyr/net/net_ip.h>
14 #include <zephyr/net/socket.h>
15 #include <zephyr/net/websocket.h>
16 #include <zephyr/sys/fdtable.h>
17
18 #include "websocket_internal.h"
19
20 /* Generated by http://www.lipsum.com/
21 * 2 paragraphs, 178 words, 1160 bytes of Lorem Ipsum
22 */
23 static const char lorem_ipsum[] = LOREM_IPSUM;
24
25 #define MAX_RECV_BUF_LEN 256
26 static uint8_t recv_buf[MAX(sizeof(lorem_ipsum), MAX_RECV_BUF_LEN)];
27
28 /* We need to allocate bigger buffer for the websocket data we receive so that
29 * the websocket header fits into it.
30 */
31 #define EXTRA_BUF_SPACE 30
32
33 static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
34 static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
35 static size_t test_msg_len;
36
/*
 * Reserve a file descriptor and bind @obj to it (with no vtable), so that
 * websocket calls made on the returned fd operate on @obj.
 * Fails the current test if no descriptor can be reserved.
 */
static int test_fd_alloc(void *obj)
{
	const int new_fd = zvfs_reserve_fd();

	zassert_not_equal(new_fd, -1, "Failed to allocate FD");

	zvfs_finalize_fd(new_fd, obj, NULL);
	return new_fd;
}
47
test_recv_buf(uint8_t * input_buf,size_t input_len,struct websocket_context * ctx,uint32_t * msg_type,uint64_t * remaining,uint8_t * recv_buffer,size_t recv_len)48 static int test_recv_buf(uint8_t *input_buf, size_t input_len,
49 struct websocket_context *ctx,
50 uint32_t *msg_type, uint64_t *remaining,
51 uint8_t *recv_buffer, size_t recv_len)
52 {
53 static struct test_data test_data;
54 int fd, ret;
55
56 test_data.ctx = ctx;
57 test_data.input_buf = input_buf;
58 test_data.input_len = input_len;
59 test_data.input_pos = 0;
60
61 fd = test_fd_alloc(&test_data);
62
63 ret = websocket_recv_msg(fd, recv_buffer, recv_len,
64 msg_type, remaining, 0);
65
66 zvfs_free_fd(fd);
67
68 return ret;
69 }
70
/* One masked websocket text frame carrying "test message":
 *   0x81                    FIN set, opcode 1 (text)
 *   0x8c                    MASK bit set, payload length 12
 *   0xe1 0x7e 0x8e 0xb9     masking key
 *   12 bytes                payload XOR-ed with the masking key
 * The header is therefore 6 bytes long.
 */
#define FRAME1_BYTES					\
	0x81, 0x8c,					\
	0xe1, 0x7e, 0x8e, 0xb9,				\
	0x95, 0x1b, 0xfd, 0xcd, 0xc1, 0x13,		\
	0xeb, 0xca, 0x92, 0x1f, 0xe9, 0xdc

static const unsigned char frame1[] = { FRAME1_BYTES };

/* Expected unmasked payload of frame1; the trailing NUL only exists so the
 * array can be printed with %s.
 */
static const unsigned char frame1_msg[] = "test message";

/* frame1 twice in a row, used to check reading one full frame followed by
 * only part of the next one.
 */
static const unsigned char frame2[] = { FRAME1_BYTES, FRAME1_BYTES };

/* Zero-length, unmasked ping frame (FIN set, opcode 9) */
static const unsigned char ping[] = { 0x89, 0x00 };

/* frame1 size minus its payload size (frame1_msg without the NUL) */
#define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1))
102
test_recv(int count)103 static void test_recv(int count)
104 {
105 struct websocket_context ctx;
106 uint32_t msg_type = -1;
107 uint64_t remaining = -1;
108 int total_read = 0;
109 int ret, i, left;
110
111 memset(&ctx, 0, sizeof(ctx));
112
113 ctx.recv_buf.buf = temp_recv_buf;
114 ctx.recv_buf.size = sizeof(temp_recv_buf);
115 ctx.recv_buf.count = 0;
116
117 memcpy(feed_buf, &frame1, sizeof(frame1));
118
119 NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
120 sizeof(frame1), FRAME1_HDR_SIZE);
121
122 /* We feed the frame N byte(s) at a time */
123 for (i = 0; i < sizeof(frame1) / count; i++) {
124 ret = test_recv_buf(&feed_buf[i * count], count,
125 &ctx, &msg_type, &remaining,
126 recv_buf + total_read,
127 sizeof(recv_buf) - total_read);
128 if (count < 7 && (i * count) < FRAME1_HDR_SIZE) {
129 zassert_equal(ret, -EAGAIN,
130 "[%d] Header parse failed (ret %d)",
131 i * count, ret);
132 } else {
133 total_read += ret;
134 }
135 }
136
137 /* Read any remaining data */
138 left = sizeof(frame1) % count;
139 if (left > 0) {
140 /* Some leftover bytes are still there */
141 ret = test_recv_buf(&feed_buf[sizeof(frame1) - left], left,
142 &ctx, &msg_type, &remaining,
143 recv_buf + total_read,
144 sizeof(recv_buf) - total_read);
145 zassert_true(ret <= (sizeof(recv_buf) - total_read),
146 "Invalid number of bytes read (%d)", ret);
147 total_read += ret;
148 zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
149 "Invalid amount of data read (%d)", ret);
150
151 } else if (total_read < (sizeof(frame1) - FRAME1_HDR_SIZE)) {
152 /* We read the whole message earlier, but we have parsed
153 * only part of the message. Parse the reset of the message
154 * here.
155 */
156 ret = test_recv_buf(&feed_buf[FRAME1_HDR_SIZE + total_read],
157 sizeof(frame1) - FRAME1_HDR_SIZE - total_read,
158 &ctx, &msg_type, &remaining,
159 recv_buf + total_read,
160 sizeof(recv_buf) - total_read);
161 total_read += ret;
162 zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
163 "Invalid amount of data read (%d)", ret);
164 }
165
166 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
167 "Invalid message, should be '%s' was '%s'",
168 frame1_msg, recv_buf);
169
170 zassert_equal(remaining, 0, "Msg not empty");
171 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
172 }
173
ZTEST(net_websocket,test_recv_1_byte)174 ZTEST(net_websocket, test_recv_1_byte)
175 {
176 test_recv(1);
177 }
178
ZTEST(net_websocket,test_recv_2_byte)179 ZTEST(net_websocket, test_recv_2_byte)
180 {
181 test_recv(2);
182 }
183
ZTEST(net_websocket,test_recv_3_byte)184 ZTEST(net_websocket, test_recv_3_byte)
185 {
186 test_recv(3);
187 }
188
ZTEST(net_websocket,test_recv_6_byte)189 ZTEST(net_websocket, test_recv_6_byte)
190 {
191 test_recv(6);
192 }
193
ZTEST(net_websocket,test_recv_7_byte)194 ZTEST(net_websocket, test_recv_7_byte)
195 {
196 test_recv(7);
197 }
198
ZTEST(net_websocket,test_recv_8_byte)199 ZTEST(net_websocket, test_recv_8_byte)
200 {
201 test_recv(8);
202 }
203
ZTEST(net_websocket,test_recv_9_byte)204 ZTEST(net_websocket, test_recv_9_byte)
205 {
206 test_recv(9);
207 }
208
ZTEST(net_websocket,test_recv_10_byte)209 ZTEST(net_websocket, test_recv_10_byte)
210 {
211 test_recv(10);
212 }
213
ZTEST(net_websocket,test_recv_12_byte)214 ZTEST(net_websocket, test_recv_12_byte)
215 {
216 test_recv(12);
217 }
218
ZTEST(net_websocket,test_recv_whole_msg)219 ZTEST(net_websocket, test_recv_whole_msg)
220 {
221 test_recv(sizeof(frame1));
222 }
223
ZTEST(net_websocket,test_recv_empty_ping)224 ZTEST(net_websocket, test_recv_empty_ping)
225 {
226 struct websocket_context ctx;
227 int total_read = 0;
228 uint32_t msg_type = -1;
229 uint64_t remaining = -1;
230
231 memset(&ctx, 0, sizeof(ctx));
232
233 ctx.recv_buf.buf = temp_recv_buf;
234 ctx.recv_buf.size = sizeof(temp_recv_buf);
235 ctx.recv_buf.count = 0;
236
237 memcpy(feed_buf, &ping, sizeof(ping));
238
239 total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining,
240 recv_buf, sizeof(recv_buf));
241
242 zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read);
243 zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping");
244 }
245
test_recv_2(int count)246 static void test_recv_2(int count)
247 {
248 struct websocket_context ctx;
249 uint32_t msg_type = -1;
250 uint64_t remaining = -1;
251 int total_read = 0;
252 int ret;
253
254 memset(&ctx, 0, sizeof(ctx));
255
256 ctx.recv_buf.buf = temp_recv_buf;
257 ctx.recv_buf.size = sizeof(temp_recv_buf);
258
259 memcpy(feed_buf, &frame2, sizeof(frame2));
260
261 NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
262 sizeof(frame2), FRAME1_HDR_SIZE);
263
264 total_read = test_recv_buf(&feed_buf[0], count, &ctx, &msg_type,
265 &remaining, recv_buf, sizeof(recv_buf));
266
267 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
268 "Invalid message, should be '%s' was '%s'",
269 frame1_msg, recv_buf);
270
271 zassert_equal(remaining, 0, "Msg not empty");
272 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
273
274 /* Then read again. Take in account that part of second frame
275 * have read from tx buffer to rx buffer.
276 */
277 ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining,
278 recv_buf, sizeof(recv_buf));
279
280 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
281 "Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf);
282
283 zassert_equal(remaining, 0, "Msg not empty");
284 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
285 }
286
ZTEST(net_websocket,test_recv_two_msg)287 ZTEST(net_websocket, test_recv_two_msg)
288 {
289 test_recv_2(sizeof(frame1) + FRAME1_HDR_SIZE / 2);
290 }
291
verify_sent_and_received_msg(struct msghdr * msg,bool split_msg)292 int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg)
293 {
294 static struct websocket_context ctx;
295 uint32_t msg_type = -1;
296 uint64_t remaining = -1;
297 size_t split_len = 0, total_read = 0;
298 int ret;
299
300 memset(&ctx, 0, sizeof(ctx));
301
302 ctx.recv_buf.buf = temp_recv_buf;
303 ctx.recv_buf.size = sizeof(temp_recv_buf);
304
305 /* Read first the header */
306 ret = test_recv_buf(msg->msg_iov[0].iov_base,
307 msg->msg_iov[0].iov_len,
308 &ctx, &msg_type, &remaining,
309 recv_buf, sizeof(recv_buf));
310 if (remaining > 0) {
311 zassert_equal(ret, -EAGAIN, "Msg header not found");
312 } else {
313 zassert_equal(ret, 0, "Msg header read error (ret %d)", ret);
314 }
315
316 /* Then the first split if it is enabled */
317 if (split_msg) {
318 split_len = msg->msg_iov[1].iov_len / 2;
319
320 ret = test_recv_buf(msg->msg_iov[1].iov_base,
321 split_len,
322 &ctx, &msg_type, &remaining,
323 recv_buf, sizeof(recv_buf));
324 zassert_true(ret > 0, "Cannot read data (%d)", ret);
325
326 total_read = ret;
327 }
328
329 /* Then the data */
330 while (remaining > 0) {
331 ret = test_recv_buf((uint8_t *)msg->msg_iov[1].iov_base +
332 total_read,
333 msg->msg_iov[1].iov_len - total_read,
334 &ctx, &msg_type, &remaining,
335 recv_buf, sizeof(recv_buf));
336 zassert_true(ret > 0, "Cannot read data (%d)", ret);
337
338 if (memcmp(recv_buf, lorem_ipsum + total_read, ret) != 0) {
339 LOG_HEXDUMP_ERR(lorem_ipsum + total_read, ret,
340 "Received message should be");
341 LOG_HEXDUMP_ERR(recv_buf, ret, "but it was instead");
342 zassert_true(false, "Invalid received message "
343 "after %d bytes", total_read);
344 }
345
346 total_read += ret;
347 }
348
349 zassert_equal(total_read, test_msg_len,
350 "Msg body not valid, received %d instead of %zd",
351 total_read, test_msg_len);
352
353 NET_DBG("Received %zd header and %zd body",
354 msg->msg_iov[0].iov_len, total_read);
355
356 return msg->msg_iov[0].iov_len + total_read;
357 }
358
ZTEST(net_websocket,test_send_and_recv_lorem_ipsum)359 ZTEST(net_websocket, test_send_and_recv_lorem_ipsum)
360 {
361 static struct websocket_context ctx;
362 int fd, ret;
363
364 memset(&ctx, 0, sizeof(ctx));
365
366 ctx.recv_buf.buf = temp_recv_buf;
367 ctx.recv_buf.size = sizeof(temp_recv_buf);
368
369 test_msg_len = sizeof(lorem_ipsum) - 1;
370
371 fd = test_fd_alloc(&ctx);
372 ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
373 WEBSOCKET_OPCODE_DATA_TEXT, true, true,
374 SYS_FOREVER_MS);
375 zassert_equal(ret, test_msg_len,
376 "Should have sent %zd bytes but sent %d instead",
377 test_msg_len, ret);
378
379 zvfs_free_fd(fd);
380 }
381
ZTEST(net_websocket,test_recv_two_large_split_msg)382 ZTEST(net_websocket, test_recv_two_large_split_msg)
383 {
384 static struct websocket_context ctx;
385 int fd, ret;
386
387 memset(&ctx, 0, sizeof(ctx));
388
389 ctx.recv_buf.buf = temp_recv_buf;
390 ctx.recv_buf.size = sizeof(temp_recv_buf);
391
392 test_msg_len = sizeof(lorem_ipsum) - 1;
393
394 fd = test_fd_alloc(&ctx);
395 ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
396 WEBSOCKET_OPCODE_DATA_TEXT, false, true,
397 SYS_FOREVER_MS);
398 zassert_equal(ret, test_msg_len,
399 "1st should have sent %zd bytes but sent %d instead",
400 test_msg_len, ret);
401
402 zvfs_free_fd(fd);
403 }
404
ZTEST(net_websocket,test_send_and_recv_empty_pong)405 ZTEST(net_websocket, test_send_and_recv_empty_pong)
406 {
407 static struct websocket_context ctx;
408 int fd, ret;
409
410 memset(&ctx, 0, sizeof(ctx));
411
412 ctx.recv_buf.buf = temp_recv_buf;
413 ctx.recv_buf.size = sizeof(temp_recv_buf);
414
415 test_msg_len = 0;
416
417 fd = test_fd_alloc(&ctx);
418 ret = websocket_send_msg(fd, NULL, test_msg_len, WEBSOCKET_OPCODE_PING,
419 true, true, SYS_FOREVER_MS);
420 zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead",
421 test_msg_len, ret);
422
423 zvfs_free_fd(fd);
424 }
425
ZTEST(net_websocket,test_recv_in_small_buffer)426 ZTEST(net_websocket, test_recv_in_small_buffer)
427 {
428 struct websocket_context ctx;
429 uint32_t msg_type = -1;
430 uint64_t remaining = -1;
431 int total_read = 0;
432 int ret;
433 const size_t frame1_msg_size = sizeof(frame1_msg) - 1;
434 const size_t recv_buf_size = 7;
435
436 memset(&ctx, 0, sizeof(ctx));
437
438 ctx.recv_buf.buf = temp_recv_buf;
439 ctx.recv_buf.size = sizeof(temp_recv_buf);
440
441 memcpy(feed_buf, &frame1, sizeof(frame1));
442
443 /* Receive first part of message */
444 ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf,
445 recv_buf_size);
446 zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d",
447 recv_buf_size, ret);
448 total_read += ret;
449
450 /* Receive second part of message */
451 ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining,
452 &recv_buf[recv_buf_size], recv_buf_size);
453 zassert_equal(ret, frame1_msg_size - recv_buf_size,
454 "Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size,
455 ret);
456 total_read += ret;
457
458 /* Check receiving whole message */
459 zassert_equal(total_read, frame1_msg_size, "Received not whole message");
460 zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size,
461 "Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf);
462 }
463
setup(void)464 static void *setup(void)
465 {
466 k_thread_system_pool_assign(k_current_get());
467 return NULL;
468 }
469
470 ZTEST_SUITE(net_websocket, NULL, setup, NULL, NULL, NULL);
471