1 /*
2  * Copyright (c) 2020 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_test, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
9 
10 #include <zephyr/ztest_assert.h>
11 
12 #include <zephyr/net/net_ip.h>
13 #include <zephyr/net/socket.h>
14 #include <zephyr/net/websocket.h>
15 #include <zephyr/sys/fdtable.h>
16 
17 #include "websocket_internal.h"
18 
/* Generated by http://www.lipsum.com/
 * 2 paragraphs, 178 words, 1160 bytes of Lorem Ipsum
 *
 * Reference payload for the send/receive tests: the two paragraphs are
 * joined by a single space and the text ends with a newline. The tests take
 * the payload length from sizeof(lorem_ipsum) - 1 (i.e. without the NUL
 * terminator), so the byte count above is informational only.
 */
static const char lorem_ipsum[] =
	"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
	"Vestibulum ultricies sapien tellus, ac viverra dolor bibendum "
	"lacinia. Vestibulum et nisl tristique tellus finibus gravida "
	"vitae sit amet nunc. Suspendisse maximus justo mi, vitae porta "
	"risus suscipit vitae. Curabitur ut fringilla velit. Donec ac nisi "
	"in dui semper lobortis sed nec ante. Sed nec luctus dui. Sed ut "
	"ante nisi. Mauris congue euismod felis, et maximus ex pellentesque "
	"nec. Proin nibh nisl, semper at nunc in, mattis pharetra metus. Nam "
	"turpis risus, pulvinar sit amet varius ac, pellentesque quis purus."
	" "
	"Nam consequat purus in lacinia fringilla. Morbi volutpat, tellus "
	"nec tempus dapibus, ante sem aliquam dui, eu feugiat libero diam "
	"at leo. Sed suscipit egestas orci in ultrices. Integer in elementum "
	"ligula, vel sollicitudin velit. Nullam sit amet eleifend libero. "
	"Proin sit amet consequat tellus, vel vulputate arcu. Curabitur quis "
	"lobortis lacus. Sed faucibus vestibulum enim vel elementum. Vivamus "
	"enim nunc, auctor in purus at, aliquet pulvinar eros. Cras dapibus "
	"nec quam laoreet sagittis. Quisque dictum ante odio, at imperdiet "
	"est convallis a. Morbi mattis ut orci vitae volutpat."
	"\n";
43 
44 #define MAX_RECV_BUF_LEN 256
45 static uint8_t recv_buf[MAX(sizeof(lorem_ipsum), MAX_RECV_BUF_LEN)];
46 
47 /* We need to allocate bigger buffer for the websocket data we receive so that
48  * the websocket header fits into it.
49  */
50 #define EXTRA_BUF_SPACE 30
51 
52 static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
53 static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
54 static size_t test_msg_len;
55 
/* Reserve a file descriptor and attach @p obj to it (with a NULL vtable), so
 * that the websocket calls under test can be given a descriptor.
 *
 * Fails the current test if no descriptor can be reserved.
 *
 * @return The reserved descriptor; release it with z_free_fd().
 */
static int test_fd_alloc(void *obj)
{
	int new_fd = z_reserve_fd();

	zassert_not_equal(new_fd, -1, "Failed to allocate FD");

	z_finalize_fd(new_fd, obj, NULL);

	return new_fd;
}
66 
test_recv_buf(uint8_t * input_buf,size_t input_len,struct websocket_context * ctx,uint32_t * msg_type,uint64_t * remaining,uint8_t * recv_buffer,size_t recv_len)67 static int test_recv_buf(uint8_t *input_buf, size_t input_len,
68 			 struct websocket_context *ctx,
69 			 uint32_t *msg_type, uint64_t *remaining,
70 			 uint8_t *recv_buffer, size_t recv_len)
71 {
72 	static struct test_data test_data;
73 	int fd, ret;
74 
75 	test_data.ctx = ctx;
76 	test_data.input_buf = input_buf;
77 	test_data.input_len = input_len;
78 	test_data.input_pos = 0;
79 
80 	fd = test_fd_alloc(&test_data);
81 
82 	ret = websocket_recv_msg(fd, recv_buffer, recv_len,
83 				 msg_type, remaining, 0);
84 
85 	z_free_fd(fd);
86 
87 	return ret;
88 }
89 
/* Raw bytes of frame1, a single complete websocket frame:
 *
 *   0x81         FIN bit set, opcode 1 (text)
 *   0x8c         MASK bit set, payload length 12
 *   e1 7e 8e b9  masking key
 *   12 bytes     "test message" XOR-ed with the masking key
 *
 * so the header is 6 bytes long.
 */
#define FRAME1_BYTES						\
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,		\
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,		\
	0xe9, 0xdc

static const unsigned char frame1[] = { FRAME1_BYTES };

/* Expected unmasked payload of frame1. The trailing NUL is only there so
 * that the array can be printed with %s; it is not part of the payload.
 */
static const unsigned char frame1_msg[] = "test message";

/* Two copies of frame1 back to back, for the case where one read consumes
 * the whole first frame and part of the second one.
 */
static const unsigned char frame2[] = { FRAME1_BYTES, FRAME1_BYTES };

/* Unmasked ping frame (FIN bit set, opcode 9) with an empty payload */
static const unsigned char ping[] = { 0x89, 0x00 };

/* Size of frame1's header: whole frame minus its payload */
#define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1))
121 
test_recv(int count)122 static void test_recv(int count)
123 {
124 	struct websocket_context ctx;
125 	uint32_t msg_type = -1;
126 	uint64_t remaining = -1;
127 	int total_read = 0;
128 	int ret, i, left;
129 
130 	memset(&ctx, 0, sizeof(ctx));
131 
132 	ctx.recv_buf.buf = temp_recv_buf;
133 	ctx.recv_buf.size = sizeof(temp_recv_buf);
134 	ctx.recv_buf.count = 0;
135 
136 	memcpy(feed_buf, &frame1, sizeof(frame1));
137 
138 	NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
139 		sizeof(frame1), FRAME1_HDR_SIZE);
140 
141 	/* We feed the frame N byte(s) at a time */
142 	for (i = 0; i < sizeof(frame1) / count; i++) {
143 		ret = test_recv_buf(&feed_buf[i * count], count,
144 				    &ctx, &msg_type, &remaining,
145 				    recv_buf + total_read,
146 				    sizeof(recv_buf) - total_read);
147 		if (count < 7 && (i * count) < FRAME1_HDR_SIZE) {
148 			zassert_equal(ret, -EAGAIN,
149 				      "[%d] Header parse failed (ret %d)",
150 				      i * count, ret);
151 		} else {
152 			total_read += ret;
153 		}
154 	}
155 
156 	/* Read any remaining data */
157 	left = sizeof(frame1) % count;
158 	if (left > 0) {
159 		/* Some leftover bytes are still there */
160 		ret = test_recv_buf(&feed_buf[sizeof(frame1) - left], left,
161 				    &ctx, &msg_type, &remaining,
162 				    recv_buf + total_read,
163 				    sizeof(recv_buf) - total_read);
164 		zassert_true(ret <= (sizeof(recv_buf) - total_read),
165 			     "Invalid number of bytes read (%d)", ret);
166 		total_read += ret;
167 		zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
168 			      "Invalid amount of data read (%d)", ret);
169 
170 	} else if (total_read < (sizeof(frame1) - FRAME1_HDR_SIZE)) {
171 		/* We read the whole message earlier, but we have parsed
172 		 * only part of the message. Parse the reset of the message
173 		 * here.
174 		 */
175 		ret = test_recv_buf(&feed_buf[FRAME1_HDR_SIZE + total_read],
176 			    sizeof(frame1) - FRAME1_HDR_SIZE - total_read,
177 				    &ctx, &msg_type, &remaining,
178 				    recv_buf + total_read,
179 				    sizeof(recv_buf) - total_read);
180 		total_read += ret;
181 		zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
182 			      "Invalid amount of data read (%d)", ret);
183 	}
184 
185 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
186 			  "Invalid message, should be '%s' was '%s'",
187 			  frame1_msg, recv_buf);
188 
189 	zassert_equal(remaining, 0, "Msg not empty");
190 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
191 }
192 
ZTEST(net_websocket,test_recv_1_byte)193 ZTEST(net_websocket, test_recv_1_byte)
194 {
195 	test_recv(1);
196 }
197 
ZTEST(net_websocket,test_recv_2_byte)198 ZTEST(net_websocket, test_recv_2_byte)
199 {
200 	test_recv(2);
201 }
202 
ZTEST(net_websocket,test_recv_3_byte)203 ZTEST(net_websocket, test_recv_3_byte)
204 {
205 	test_recv(3);
206 }
207 
ZTEST(net_websocket,test_recv_6_byte)208 ZTEST(net_websocket, test_recv_6_byte)
209 {
210 	test_recv(6);
211 }
212 
ZTEST(net_websocket,test_recv_7_byte)213 ZTEST(net_websocket, test_recv_7_byte)
214 {
215 	test_recv(7);
216 }
217 
ZTEST(net_websocket,test_recv_8_byte)218 ZTEST(net_websocket, test_recv_8_byte)
219 {
220 	test_recv(8);
221 }
222 
ZTEST(net_websocket,test_recv_9_byte)223 ZTEST(net_websocket, test_recv_9_byte)
224 {
225 	test_recv(9);
226 }
227 
ZTEST(net_websocket,test_recv_10_byte)228 ZTEST(net_websocket, test_recv_10_byte)
229 {
230 	test_recv(10);
231 }
232 
ZTEST(net_websocket,test_recv_12_byte)233 ZTEST(net_websocket, test_recv_12_byte)
234 {
235 	test_recv(12);
236 }
237 
ZTEST(net_websocket,test_recv_whole_msg)238 ZTEST(net_websocket, test_recv_whole_msg)
239 {
240 	test_recv(sizeof(frame1));
241 }
242 
ZTEST(net_websocket,test_recv_empty_ping)243 ZTEST(net_websocket, test_recv_empty_ping)
244 {
245 	struct websocket_context ctx;
246 	int total_read = 0;
247 	uint32_t msg_type = -1;
248 	uint64_t remaining = -1;
249 
250 	memset(&ctx, 0, sizeof(ctx));
251 
252 	ctx.recv_buf.buf = temp_recv_buf;
253 	ctx.recv_buf.size = sizeof(temp_recv_buf);
254 	ctx.recv_buf.count = 0;
255 
256 	memcpy(feed_buf, &ping, sizeof(ping));
257 
258 	total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining,
259 				   recv_buf, sizeof(recv_buf));
260 
261 	zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read);
262 	zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping");
263 }
264 
test_recv_2(int count)265 static void test_recv_2(int count)
266 {
267 	struct websocket_context ctx;
268 	uint32_t msg_type = -1;
269 	uint64_t remaining = -1;
270 	int total_read = 0;
271 	int ret;
272 
273 	memset(&ctx, 0, sizeof(ctx));
274 
275 	ctx.recv_buf.buf = temp_recv_buf;
276 	ctx.recv_buf.size = sizeof(temp_recv_buf);
277 
278 	memcpy(feed_buf, &frame2, sizeof(frame2));
279 
280 	NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
281 		sizeof(frame2), FRAME1_HDR_SIZE);
282 
283 	total_read = test_recv_buf(&feed_buf[0], count, &ctx, &msg_type,
284 				   &remaining, recv_buf, sizeof(recv_buf));
285 
286 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
287 			  "Invalid message, should be '%s' was '%s'",
288 			  frame1_msg, recv_buf);
289 
290 	zassert_equal(remaining, 0, "Msg not empty");
291 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
292 
293 	/* Then read again. Take in account that part of second frame
294 	 * have read from tx buffer to rx buffer.
295 	 */
296 	ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining,
297 			    recv_buf, sizeof(recv_buf));
298 
299 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
300 			  "Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf);
301 
302 	zassert_equal(remaining, 0, "Msg not empty");
303 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
304 }
305 
ZTEST(net_websocket,test_recv_two_msg)306 ZTEST(net_websocket, test_recv_two_msg)
307 {
308 	test_recv_2(sizeof(frame1) + FRAME1_HDR_SIZE / 2);
309 }
310 
verify_sent_and_received_msg(struct msghdr * msg,bool split_msg)311 int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg)
312 {
313 	static struct websocket_context ctx;
314 	uint32_t msg_type = -1;
315 	uint64_t remaining = -1;
316 	size_t split_len = 0, total_read = 0;
317 	int ret;
318 
319 	memset(&ctx, 0, sizeof(ctx));
320 
321 	ctx.recv_buf.buf = temp_recv_buf;
322 	ctx.recv_buf.size = sizeof(temp_recv_buf);
323 
324 	/* Read first the header */
325 	ret = test_recv_buf(msg->msg_iov[0].iov_base,
326 			    msg->msg_iov[0].iov_len,
327 			    &ctx, &msg_type, &remaining,
328 			    recv_buf, sizeof(recv_buf));
329 	if (remaining > 0) {
330 		zassert_equal(ret, -EAGAIN, "Msg header not found");
331 	} else {
332 		zassert_equal(ret, 0, "Msg header read error (ret %d)", ret);
333 	}
334 
335 	/* Then the first split if it is enabled */
336 	if (split_msg) {
337 		split_len = msg->msg_iov[1].iov_len / 2;
338 
339 		ret = test_recv_buf(msg->msg_iov[1].iov_base,
340 				    split_len,
341 				    &ctx, &msg_type, &remaining,
342 				    recv_buf, sizeof(recv_buf));
343 		zassert_true(ret > 0, "Cannot read data (%d)", ret);
344 
345 		total_read = ret;
346 	}
347 
348 	/* Then the data */
349 	while (remaining > 0) {
350 		ret = test_recv_buf((uint8_t *)msg->msg_iov[1].iov_base +
351 								total_read,
352 				    msg->msg_iov[1].iov_len - total_read,
353 				    &ctx, &msg_type, &remaining,
354 				    recv_buf, sizeof(recv_buf));
355 		zassert_true(ret > 0, "Cannot read data (%d)", ret);
356 
357 		if (memcmp(recv_buf, lorem_ipsum + total_read, ret) != 0) {
358 			LOG_HEXDUMP_ERR(lorem_ipsum + total_read, ret,
359 					"Received message should be");
360 			LOG_HEXDUMP_ERR(recv_buf, ret, "but it was instead");
361 			zassert_true(false, "Invalid received message "
362 				     "after %d bytes", total_read);
363 		}
364 
365 		total_read += ret;
366 	}
367 
368 	zassert_equal(total_read, test_msg_len,
369 		      "Msg body not valid, received %d instead of %zd",
370 		      total_read, test_msg_len);
371 
372 	NET_DBG("Received %zd header and %zd body",
373 		msg->msg_iov[0].iov_len, total_read);
374 
375 	return msg->msg_iov[0].iov_len + total_read;
376 }
377 
ZTEST(net_websocket,test_send_and_recv_lorem_ipsum)378 ZTEST(net_websocket, test_send_and_recv_lorem_ipsum)
379 {
380 	static struct websocket_context ctx;
381 	int fd, ret;
382 
383 	memset(&ctx, 0, sizeof(ctx));
384 
385 	ctx.recv_buf.buf = temp_recv_buf;
386 	ctx.recv_buf.size = sizeof(temp_recv_buf);
387 
388 	test_msg_len = sizeof(lorem_ipsum) - 1;
389 
390 	fd = test_fd_alloc(&ctx);
391 	ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
392 				 WEBSOCKET_OPCODE_DATA_TEXT, true, true,
393 				 SYS_FOREVER_MS);
394 	zassert_equal(ret, test_msg_len,
395 		      "Should have sent %zd bytes but sent %d instead",
396 		      test_msg_len, ret);
397 
398 	z_free_fd(fd);
399 }
400 
ZTEST(net_websocket,test_recv_two_large_split_msg)401 ZTEST(net_websocket, test_recv_two_large_split_msg)
402 {
403 	static struct websocket_context ctx;
404 	int fd, ret;
405 
406 	memset(&ctx, 0, sizeof(ctx));
407 
408 	ctx.recv_buf.buf = temp_recv_buf;
409 	ctx.recv_buf.size = sizeof(temp_recv_buf);
410 
411 	test_msg_len = sizeof(lorem_ipsum) - 1;
412 
413 	fd = test_fd_alloc(&ctx);
414 	ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
415 				 WEBSOCKET_OPCODE_DATA_TEXT, false, true,
416 				 SYS_FOREVER_MS);
417 	zassert_equal(ret, test_msg_len,
418 		      "1st should have sent %zd bytes but sent %d instead",
419 		      test_msg_len, ret);
420 
421 	z_free_fd(fd);
422 }
423 
ZTEST(net_websocket,test_send_and_recv_empty_pong)424 ZTEST(net_websocket, test_send_and_recv_empty_pong)
425 {
426 	static struct websocket_context ctx;
427 	int fd, ret;
428 
429 	memset(&ctx, 0, sizeof(ctx));
430 
431 	ctx.recv_buf.buf = temp_recv_buf;
432 	ctx.recv_buf.size = sizeof(temp_recv_buf);
433 
434 	test_msg_len = 0;
435 
436 	fd = test_fd_alloc(&ctx);
437 	ret = websocket_send_msg(fd, NULL, test_msg_len, WEBSOCKET_OPCODE_PING,
438 				 true, true, SYS_FOREVER_MS);
439 	zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead",
440 		      test_msg_len, ret);
441 
442 	z_free_fd(fd);
443 }
444 
ZTEST(net_websocket,test_recv_in_small_buffer)445 ZTEST(net_websocket, test_recv_in_small_buffer)
446 {
447 	struct websocket_context ctx;
448 	uint32_t msg_type = -1;
449 	uint64_t remaining = -1;
450 	int total_read = 0;
451 	int ret;
452 	const size_t frame1_msg_size = sizeof(frame1_msg) - 1;
453 	const size_t recv_buf_size = 7;
454 
455 	memset(&ctx, 0, sizeof(ctx));
456 
457 	ctx.recv_buf.buf = temp_recv_buf;
458 	ctx.recv_buf.size = sizeof(temp_recv_buf);
459 
460 	memcpy(feed_buf, &frame1, sizeof(frame1));
461 
462 	/* Receive first part of message */
463 	ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf,
464 			    recv_buf_size);
465 	zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d",
466 		      recv_buf_size, ret);
467 	total_read += ret;
468 
469 	/* Receive second part of message */
470 	ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining,
471 			    &recv_buf[recv_buf_size], recv_buf_size);
472 	zassert_equal(ret, frame1_msg_size - recv_buf_size,
473 		      "Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size,
474 		      ret);
475 	total_read += ret;
476 
477 	/* Check receiving whole message */
478 	zassert_equal(total_read, frame1_msg_size, "Received not whole message");
479 	zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size,
480 			  "Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf);
481 }
482 
setup(void)483 static void *setup(void)
484 {
485 	k_thread_system_pool_assign(k_current_get());
486 	return NULL;
487 }
488 
489 ZTEST_SUITE(net_websocket, NULL, setup, NULL, NULL, NULL);
490