1 /*
2 * Copyright (c) 2020 Intel Corporation
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/logging/log.h>
/* Log as module "net_test", filtered at the websocket library's configured log level. */
LOG_MODULE_REGISTER(net_test, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
9
10 #include <zephyr/ztest_assert.h>
11
12 #include <zephyr/net/net_ip.h>
13 #include <zephyr/net/socket.h>
14 #include <zephyr/net/websocket.h>
15 #include <zephyr/sys/fdtable.h>
16
17 #include "websocket_internal.h"
18
/* Generated by http://www.lipsum.com/
 * 2 paragraphs, 178 words, 1160 bytes of Lorem Ipsum
 *
 * This is the payload of the send tests: they send sizeof(lorem_ipsum) - 1
 * bytes (everything except the terminating NUL, so the trailing "\n" is
 * included), and verify_sent_and_received_msg() compares the received data
 * against this array.
 */
static const char lorem_ipsum[] =
	"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
	"Vestibulum ultricies sapien tellus, ac viverra dolor bibendum "
	"lacinia. Vestibulum et nisl tristique tellus finibus gravida "
	"vitae sit amet nunc. Suspendisse maximus justo mi, vitae porta "
	"risus suscipit vitae. Curabitur ut fringilla velit. Donec ac nisi "
	"in dui semper lobortis sed nec ante. Sed nec luctus dui. Sed ut "
	"ante nisi. Mauris congue euismod felis, et maximus ex pellentesque "
	"nec. Proin nibh nisl, semper at nunc in, mattis pharetra metus. Nam "
	"turpis risus, pulvinar sit amet varius ac, pellentesque quis purus."
	" "
	"Nam consequat purus in lacinia fringilla. Morbi volutpat, tellus "
	"nec tempus dapibus, ante sem aliquam dui, eu feugiat libero diam "
	"at leo. Sed suscipit egestas orci in ultrices. Integer in elementum "
	"ligula, vel sollicitudin velit. Nullam sit amet eleifend libero. "
	"Proin sit amet consequat tellus, vel vulputate arcu. Curabitur quis "
	"lobortis lacus. Sed faucibus vestibulum enim vel elementum. Vivamus "
	"enim nunc, auctor in purus at, aliquet pulvinar eros. Cras dapibus "
	"nec quam laoreet sagittis. Quisque dictum ante odio, at imperdiet "
	"est convallis a. Morbi mattis ut orci vitae volutpat."
	"\n";

/* Base size of the receive-side buffers below */
#define MAX_RECV_BUF_LEN 256

/* Destination for received payload data: large enough to hold all of
 * lorem_ipsum and never smaller than MAX_RECV_BUF_LEN.
 */
static uint8_t recv_buf[MAX(sizeof(lorem_ipsum), MAX_RECV_BUF_LEN)];

/* We need to allocate a bigger buffer for the websocket data we receive so
 * that the websocket header fits into it.
 */
#define EXTRA_BUF_SPACE 30

/* Backing storage assigned to ctx.recv_buf by every test */
static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
/* Staging copy of the raw frame bytes that the receive tests feed in */
static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
/* Payload length verify_sent_and_received_msg() expects; set by the send tests */
static size_t test_msg_len;
55
/*
 * Reserve a file descriptor and attach @p obj to it (no fd vtable).
 * The current test fails if no descriptor is available.
 *
 * @return The newly reserved descriptor.
 */
static int test_fd_alloc(void *obj)
{
	int new_fd = z_reserve_fd();

	zassert_not_equal(new_fd, -1, "Failed to allocate FD");

	z_finalize_fd(new_fd, obj, NULL);
	return new_fd;
}
66
test_recv_buf(uint8_t * input_buf,size_t input_len,struct websocket_context * ctx,uint32_t * msg_type,uint64_t * remaining,uint8_t * recv_buffer,size_t recv_len)67 static int test_recv_buf(uint8_t *input_buf, size_t input_len,
68 struct websocket_context *ctx,
69 uint32_t *msg_type, uint64_t *remaining,
70 uint8_t *recv_buffer, size_t recv_len)
71 {
72 static struct test_data test_data;
73 int fd, ret;
74
75 test_data.ctx = ctx;
76 test_data.input_buf = input_buf;
77 test_data.input_len = input_len;
78 test_data.input_pos = 0;
79
80 fd = test_fd_alloc(&test_data);
81
82 ret = websocket_recv_msg(fd, recv_buffer, recv_len,
83 msg_type, remaining, 0);
84
85 z_free_fd(fd);
86
87 return ret;
88 }
89
/* Websocket frame, header is 6 bytes, FIN bit is set, opcode is text (1),
 * payload length is 12, masking key is e17e8eb9,
 * unmasked data is "test message"
 *
 * Byte 0 (0x81) is FIN | opcode 1, byte 1 (0x8c) is MASK | length 12,
 * bytes 2-5 are the masking key and the last 12 bytes are the payload
 * XORed with that key.
 */
static const unsigned char frame1[] = {
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc
};

/* Expected unmasked payload of frame1 */
static const unsigned char frame1_msg[] = {
	/* Null added for printing purposes */
	't', 'e', 's', 't', ' ', 'm', 'e', 's', 's', 'a', 'g', 'e', '\0'
};

/* frame2 is frame1 twice, back to back. The idea is to test a case where
 * we read the full frame1 and then part of the second frame.
 */
static const unsigned char frame2[] = {
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc,
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc
};

/* Empty websocket frame: FIN | opcode ping (0x89), no mask, length 0 */
static const unsigned char ping[] = {0x89, 0x00};

/* Header size of frame1: the whole frame minus its payload
 * (frame1_msg without the NUL terminator), i.e. 6 bytes.
 */
#define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1))
121
test_recv(int count)122 static void test_recv(int count)
123 {
124 struct websocket_context ctx;
125 uint32_t msg_type = -1;
126 uint64_t remaining = -1;
127 int total_read = 0;
128 int ret, i, left;
129
130 memset(&ctx, 0, sizeof(ctx));
131
132 ctx.recv_buf.buf = temp_recv_buf;
133 ctx.recv_buf.size = sizeof(temp_recv_buf);
134 ctx.recv_buf.count = 0;
135
136 memcpy(feed_buf, &frame1, sizeof(frame1));
137
138 NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
139 sizeof(frame1), FRAME1_HDR_SIZE);
140
141 /* We feed the frame N byte(s) at a time */
142 for (i = 0; i < sizeof(frame1) / count; i++) {
143 ret = test_recv_buf(&feed_buf[i * count], count,
144 &ctx, &msg_type, &remaining,
145 recv_buf + total_read,
146 sizeof(recv_buf) - total_read);
147 if (count < 7 && (i * count) < FRAME1_HDR_SIZE) {
148 zassert_equal(ret, -EAGAIN,
149 "[%d] Header parse failed (ret %d)",
150 i * count, ret);
151 } else {
152 total_read += ret;
153 }
154 }
155
156 /* Read any remaining data */
157 left = sizeof(frame1) % count;
158 if (left > 0) {
159 /* Some leftover bytes are still there */
160 ret = test_recv_buf(&feed_buf[sizeof(frame1) - left], left,
161 &ctx, &msg_type, &remaining,
162 recv_buf + total_read,
163 sizeof(recv_buf) - total_read);
164 zassert_true(ret <= (sizeof(recv_buf) - total_read),
165 "Invalid number of bytes read (%d)", ret);
166 total_read += ret;
167 zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
168 "Invalid amount of data read (%d)", ret);
169
170 } else if (total_read < (sizeof(frame1) - FRAME1_HDR_SIZE)) {
171 /* We read the whole message earlier, but we have parsed
172 * only part of the message. Parse the reset of the message
173 * here.
174 */
175 ret = test_recv_buf(&feed_buf[FRAME1_HDR_SIZE + total_read],
176 sizeof(frame1) - FRAME1_HDR_SIZE - total_read,
177 &ctx, &msg_type, &remaining,
178 recv_buf + total_read,
179 sizeof(recv_buf) - total_read);
180 total_read += ret;
181 zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
182 "Invalid amount of data read (%d)", ret);
183 }
184
185 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
186 "Invalid message, should be '%s' was '%s'",
187 frame1_msg, recv_buf);
188
189 zassert_equal(remaining, 0, "Msg not empty");
190 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
191 }
192
/*
 * Feed frame1 to the parser in fixed-size chunks via test_recv().
 *
 * Chunk sizes 1, 2, 3 and 6 divide the 6-byte header evenly, so the header
 * completes exactly at a chunk boundary. Sizes of 7 or more deliver the whole
 * header plus some payload in the first chunk. Sizes that do not divide
 * sizeof(frame1) (18), namely 7, 8, 10 and 12, also exercise the
 * leftover-bytes path in test_recv().
 */
ZTEST(net_websocket, test_recv_1_byte)
{
	test_recv(1);
}

ZTEST(net_websocket, test_recv_2_byte)
{
	test_recv(2);
}

ZTEST(net_websocket, test_recv_3_byte)
{
	test_recv(3);
}

ZTEST(net_websocket, test_recv_6_byte)
{
	test_recv(6);
}

ZTEST(net_websocket, test_recv_7_byte)
{
	test_recv(7);
}

ZTEST(net_websocket, test_recv_8_byte)
{
	test_recv(8);
}

ZTEST(net_websocket, test_recv_9_byte)
{
	test_recv(9);
}

ZTEST(net_websocket, test_recv_10_byte)
{
	test_recv(10);
}

ZTEST(net_websocket, test_recv_12_byte)
{
	test_recv(12);
}

/* The whole frame in a single chunk */
ZTEST(net_websocket, test_recv_whole_msg)
{
	test_recv(sizeof(frame1));
}
242
ZTEST(net_websocket,test_recv_empty_ping)243 ZTEST(net_websocket, test_recv_empty_ping)
244 {
245 struct websocket_context ctx;
246 int total_read = 0;
247 uint32_t msg_type = -1;
248 uint64_t remaining = -1;
249
250 memset(&ctx, 0, sizeof(ctx));
251
252 ctx.recv_buf.buf = temp_recv_buf;
253 ctx.recv_buf.size = sizeof(temp_recv_buf);
254 ctx.recv_buf.count = 0;
255
256 memcpy(feed_buf, &ping, sizeof(ping));
257
258 total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining,
259 recv_buf, sizeof(recv_buf));
260
261 zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read);
262 zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping");
263 }
264
test_recv_2(int count)265 static void test_recv_2(int count)
266 {
267 struct websocket_context ctx;
268 uint32_t msg_type = -1;
269 uint64_t remaining = -1;
270 int total_read = 0;
271 int ret;
272
273 memset(&ctx, 0, sizeof(ctx));
274
275 ctx.recv_buf.buf = temp_recv_buf;
276 ctx.recv_buf.size = sizeof(temp_recv_buf);
277
278 memcpy(feed_buf, &frame2, sizeof(frame2));
279
280 NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
281 sizeof(frame2), FRAME1_HDR_SIZE);
282
283 total_read = test_recv_buf(&feed_buf[0], count, &ctx, &msg_type,
284 &remaining, recv_buf, sizeof(recv_buf));
285
286 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
287 "Invalid message, should be '%s' was '%s'",
288 frame1_msg, recv_buf);
289
290 zassert_equal(remaining, 0, "Msg not empty");
291 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
292
293 /* Then read again. Take in account that part of second frame
294 * have read from tx buffer to rx buffer.
295 */
296 ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining,
297 recv_buf, sizeof(recv_buf));
298
299 zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
300 "Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf);
301
302 zassert_equal(remaining, 0, "Msg not empty");
303 zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
304 }
305
ZTEST(net_websocket,test_recv_two_msg)306 ZTEST(net_websocket, test_recv_two_msg)
307 {
308 test_recv_2(sizeof(frame1) + FRAME1_HDR_SIZE / 2);
309 }
310
verify_sent_and_received_msg(struct msghdr * msg,bool split_msg)311 int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg)
312 {
313 static struct websocket_context ctx;
314 uint32_t msg_type = -1;
315 uint64_t remaining = -1;
316 size_t split_len = 0, total_read = 0;
317 int ret;
318
319 memset(&ctx, 0, sizeof(ctx));
320
321 ctx.recv_buf.buf = temp_recv_buf;
322 ctx.recv_buf.size = sizeof(temp_recv_buf);
323
324 /* Read first the header */
325 ret = test_recv_buf(msg->msg_iov[0].iov_base,
326 msg->msg_iov[0].iov_len,
327 &ctx, &msg_type, &remaining,
328 recv_buf, sizeof(recv_buf));
329 if (remaining > 0) {
330 zassert_equal(ret, -EAGAIN, "Msg header not found");
331 } else {
332 zassert_equal(ret, 0, "Msg header read error (ret %d)", ret);
333 }
334
335 /* Then the first split if it is enabled */
336 if (split_msg) {
337 split_len = msg->msg_iov[1].iov_len / 2;
338
339 ret = test_recv_buf(msg->msg_iov[1].iov_base,
340 split_len,
341 &ctx, &msg_type, &remaining,
342 recv_buf, sizeof(recv_buf));
343 zassert_true(ret > 0, "Cannot read data (%d)", ret);
344
345 total_read = ret;
346 }
347
348 /* Then the data */
349 while (remaining > 0) {
350 ret = test_recv_buf((uint8_t *)msg->msg_iov[1].iov_base +
351 total_read,
352 msg->msg_iov[1].iov_len - total_read,
353 &ctx, &msg_type, &remaining,
354 recv_buf, sizeof(recv_buf));
355 zassert_true(ret > 0, "Cannot read data (%d)", ret);
356
357 if (memcmp(recv_buf, lorem_ipsum + total_read, ret) != 0) {
358 LOG_HEXDUMP_ERR(lorem_ipsum + total_read, ret,
359 "Received message should be");
360 LOG_HEXDUMP_ERR(recv_buf, ret, "but it was instead");
361 zassert_true(false, "Invalid received message "
362 "after %d bytes", total_read);
363 }
364
365 total_read += ret;
366 }
367
368 zassert_equal(total_read, test_msg_len,
369 "Msg body not valid, received %d instead of %zd",
370 total_read, test_msg_len);
371
372 NET_DBG("Received %zd header and %zd body",
373 msg->msg_iov[0].iov_len, total_read);
374
375 return msg->msg_iov[0].iov_len + total_read;
376 }
377
ZTEST(net_websocket,test_send_and_recv_lorem_ipsum)378 ZTEST(net_websocket, test_send_and_recv_lorem_ipsum)
379 {
380 static struct websocket_context ctx;
381 int fd, ret;
382
383 memset(&ctx, 0, sizeof(ctx));
384
385 ctx.recv_buf.buf = temp_recv_buf;
386 ctx.recv_buf.size = sizeof(temp_recv_buf);
387
388 test_msg_len = sizeof(lorem_ipsum) - 1;
389
390 fd = test_fd_alloc(&ctx);
391 ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
392 WEBSOCKET_OPCODE_DATA_TEXT, true, true,
393 SYS_FOREVER_MS);
394 zassert_equal(ret, test_msg_len,
395 "Should have sent %zd bytes but sent %d instead",
396 test_msg_len, ret);
397
398 z_free_fd(fd);
399 }
400
ZTEST(net_websocket,test_recv_two_large_split_msg)401 ZTEST(net_websocket, test_recv_two_large_split_msg)
402 {
403 static struct websocket_context ctx;
404 int fd, ret;
405
406 memset(&ctx, 0, sizeof(ctx));
407
408 ctx.recv_buf.buf = temp_recv_buf;
409 ctx.recv_buf.size = sizeof(temp_recv_buf);
410
411 test_msg_len = sizeof(lorem_ipsum) - 1;
412
413 fd = test_fd_alloc(&ctx);
414 ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
415 WEBSOCKET_OPCODE_DATA_TEXT, false, true,
416 SYS_FOREVER_MS);
417 zassert_equal(ret, test_msg_len,
418 "1st should have sent %zd bytes but sent %d instead",
419 test_msg_len, ret);
420
421 z_free_fd(fd);
422 }
423
ZTEST(net_websocket,test_send_and_recv_empty_pong)424 ZTEST(net_websocket, test_send_and_recv_empty_pong)
425 {
426 static struct websocket_context ctx;
427 int fd, ret;
428
429 memset(&ctx, 0, sizeof(ctx));
430
431 ctx.recv_buf.buf = temp_recv_buf;
432 ctx.recv_buf.size = sizeof(temp_recv_buf);
433
434 test_msg_len = 0;
435
436 fd = test_fd_alloc(&ctx);
437 ret = websocket_send_msg(fd, NULL, test_msg_len, WEBSOCKET_OPCODE_PING,
438 true, true, SYS_FOREVER_MS);
439 zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead",
440 test_msg_len, ret);
441
442 z_free_fd(fd);
443 }
444
ZTEST(net_websocket,test_recv_in_small_buffer)445 ZTEST(net_websocket, test_recv_in_small_buffer)
446 {
447 struct websocket_context ctx;
448 uint32_t msg_type = -1;
449 uint64_t remaining = -1;
450 int total_read = 0;
451 int ret;
452 const size_t frame1_msg_size = sizeof(frame1_msg) - 1;
453 const size_t recv_buf_size = 7;
454
455 memset(&ctx, 0, sizeof(ctx));
456
457 ctx.recv_buf.buf = temp_recv_buf;
458 ctx.recv_buf.size = sizeof(temp_recv_buf);
459
460 memcpy(feed_buf, &frame1, sizeof(frame1));
461
462 /* Receive first part of message */
463 ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf,
464 recv_buf_size);
465 zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d",
466 recv_buf_size, ret);
467 total_read += ret;
468
469 /* Receive second part of message */
470 ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining,
471 &recv_buf[recv_buf_size], recv_buf_size);
472 zassert_equal(ret, frame1_msg_size - recv_buf_size,
473 "Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size,
474 ret);
475 total_read += ret;
476
477 /* Check receiving whole message */
478 zassert_equal(total_read, frame1_msg_size, "Received not whole message");
479 zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size,
480 "Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf);
481 }
482
setup(void)483 static void *setup(void)
484 {
485 k_thread_system_pool_assign(k_current_get());
486 return NULL;
487 }
488
489 ZTEST_SUITE(net_websocket, NULL, setup, NULL, NULL, NULL);
490