1 /*
2  * Copyright (c) 2020 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_test, CONFIG_NET_WEBSOCKET_LOG_LEVEL);
9 
10 #include <zephyr/ztest_assert.h>
11 
12 #include <zephyr/misc/lorem_ipsum.h>
13 #include <zephyr/net/net_ip.h>
14 #include <zephyr/net/socket.h>
15 #include <zephyr/net/websocket.h>
16 #include <zephyr/sys/fdtable.h>
17 
18 #include "websocket_internal.h"
19 
/* Test payload: the Lorem Ipsum text provided by <zephyr/misc/lorem_ipsum.h>.
 * Originally generated by http://www.lipsum.com/
 * (2 paragraphs, 178 words, 1160 bytes of Lorem Ipsum).
 * NOTE(review): the size figures predate the shared LOREM_IPSUM macro;
 * confirm they still match.
 */
static const char lorem_ipsum[] = LOREM_IPSUM;

/* Base size of the receive-side scratch buffers below. */
#define MAX_RECV_BUF_LEN 256
/* Destination for decoded payload data. It is large enough for the whole
 * lorem_ipsum text (including its NUL) or MAX_RECV_BUF_LEN bytes, whichever
 * is larger.
 */
static uint8_t recv_buf[MAX(sizeof(lorem_ipsum), MAX_RECV_BUF_LEN)];

/* Buffers that hold raw websocket data must be bigger than the payload
 * buffer so that the websocket frame header fits into them as well.
 */
#define EXTRA_BUF_SPACE 30

/* Installed as the websocket context receive buffer (ctx.recv_buf.buf). */
static uint8_t temp_recv_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
/* Staging area for raw frame bytes that are fed to the receive path. */
static uint8_t feed_buf[MAX_RECV_BUF_LEN + EXTRA_BUF_SPACE];
/* Expected payload length. It is set by the send tests and checked by
 * verify_sent_and_received_msg().
 */
static size_t test_msg_len;
36 
test_fd_alloc(void * obj)37 static int test_fd_alloc(void *obj)
38 {
39 	int fd;
40 
41 	fd = zvfs_reserve_fd();
42 	zassert_not_equal(fd, -1, "Failed to allocate FD");
43 	zvfs_finalize_fd(fd, obj, NULL);
44 
45 	return fd;
46 }
47 
test_recv_buf(uint8_t * input_buf,size_t input_len,struct websocket_context * ctx,uint32_t * msg_type,uint64_t * remaining,uint8_t * recv_buffer,size_t recv_len)48 static int test_recv_buf(uint8_t *input_buf, size_t input_len,
49 			 struct websocket_context *ctx,
50 			 uint32_t *msg_type, uint64_t *remaining,
51 			 uint8_t *recv_buffer, size_t recv_len)
52 {
53 	static struct test_data test_data;
54 	int fd, ret;
55 
56 	test_data.ctx = ctx;
57 	test_data.input_buf = input_buf;
58 	test_data.input_len = input_len;
59 	test_data.input_pos = 0;
60 
61 	fd = test_fd_alloc(&test_data);
62 
63 	ret = websocket_recv_msg(fd, recv_buffer, recv_len,
64 				 msg_type, remaining, 0);
65 
66 	zvfs_free_fd(fd);
67 
68 	return ret;
69 }
70 
/* A single masked websocket frame: a 6-byte header followed by 12 payload bytes.
 *   0x81         FIN bit set, opcode text (1)
 *   0x8c         MASK bit set, payload length 12
 *   e1 7e 8e b9  masking key
 * The payload unmasks to "test message".
 */
static const unsigned char frame1[] = {
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc
};

/* Expected payload of frame1 after unmasking. */
static const unsigned char frame1_msg[] = {
	/* Null added for printing purposes */
	't', 'e', 's', 't', ' ', 'm', 'e', 's', 's', 'a', 'g', 'e', '\0'
};

/* frame2 is frame1 repeated twice, back to back. It is used to test a single
 * read that returns all of the first frame plus part of the second frame.
 */
static const unsigned char frame2[] = {
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc,
	0x81, 0x8c, 0xe1, 0x7e, 0x8e, 0xb9, 0x95, 0x1b,
	0xfd, 0xcd, 0xc1, 0x13, 0xeb, 0xca, 0x92, 0x1f,
	0xe9, 0xdc
};

/* Empty websocket frame with no mask:
 * 0x89 = FIN bit set, opcode ping (9); 0x00 = no mask, payload length 0.
 */
static const unsigned char ping[] = {0x89, 0x00};

/* Header length of frame1: the frame size minus the payload length
 * (frame1_msg without its trailing NUL), i.e. 6 bytes.
 */
#define FRAME1_HDR_SIZE (sizeof(frame1) - (sizeof(frame1_msg) - 1))
102 
test_recv(int count)103 static void test_recv(int count)
104 {
105 	struct websocket_context ctx;
106 	uint32_t msg_type = -1;
107 	uint64_t remaining = -1;
108 	int total_read = 0;
109 	int ret, i, left;
110 
111 	memset(&ctx, 0, sizeof(ctx));
112 
113 	ctx.recv_buf.buf = temp_recv_buf;
114 	ctx.recv_buf.size = sizeof(temp_recv_buf);
115 	ctx.recv_buf.count = 0;
116 
117 	memcpy(feed_buf, &frame1, sizeof(frame1));
118 
119 	NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
120 		sizeof(frame1), FRAME1_HDR_SIZE);
121 
122 	/* We feed the frame N byte(s) at a time */
123 	for (i = 0; i < sizeof(frame1) / count; i++) {
124 		ret = test_recv_buf(&feed_buf[i * count], count,
125 				    &ctx, &msg_type, &remaining,
126 				    recv_buf + total_read,
127 				    sizeof(recv_buf) - total_read);
128 		if (count < 7 && (i * count) < FRAME1_HDR_SIZE) {
129 			zassert_equal(ret, -EAGAIN,
130 				      "[%d] Header parse failed (ret %d)",
131 				      i * count, ret);
132 		} else {
133 			total_read += ret;
134 		}
135 	}
136 
137 	/* Read any remaining data */
138 	left = sizeof(frame1) % count;
139 	if (left > 0) {
140 		/* Some leftover bytes are still there */
141 		ret = test_recv_buf(&feed_buf[sizeof(frame1) - left], left,
142 				    &ctx, &msg_type, &remaining,
143 				    recv_buf + total_read,
144 				    sizeof(recv_buf) - total_read);
145 		zassert_true(ret <= (sizeof(recv_buf) - total_read),
146 			     "Invalid number of bytes read (%d)", ret);
147 		total_read += ret;
148 		zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
149 			      "Invalid amount of data read (%d)", ret);
150 
151 	} else if (total_read < (sizeof(frame1) - FRAME1_HDR_SIZE)) {
152 		/* We read the whole message earlier, but we have parsed
153 		 * only part of the message. Parse the reset of the message
154 		 * here.
155 		 */
156 		ret = test_recv_buf(&feed_buf[FRAME1_HDR_SIZE + total_read],
157 			    sizeof(frame1) - FRAME1_HDR_SIZE - total_read,
158 				    &ctx, &msg_type, &remaining,
159 				    recv_buf + total_read,
160 				    sizeof(recv_buf) - total_read);
161 		total_read += ret;
162 		zassert_equal(total_read, sizeof(frame1) - FRAME1_HDR_SIZE,
163 			      "Invalid amount of data read (%d)", ret);
164 	}
165 
166 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
167 			  "Invalid message, should be '%s' was '%s'",
168 			  frame1_msg, recv_buf);
169 
170 	zassert_equal(remaining, 0, "Msg not empty");
171 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
172 }
173 
ZTEST(net_websocket,test_recv_1_byte)174 ZTEST(net_websocket, test_recv_1_byte)
175 {
176 	test_recv(1);
177 }
178 
ZTEST(net_websocket,test_recv_2_byte)179 ZTEST(net_websocket, test_recv_2_byte)
180 {
181 	test_recv(2);
182 }
183 
ZTEST(net_websocket,test_recv_3_byte)184 ZTEST(net_websocket, test_recv_3_byte)
185 {
186 	test_recv(3);
187 }
188 
ZTEST(net_websocket,test_recv_6_byte)189 ZTEST(net_websocket, test_recv_6_byte)
190 {
191 	test_recv(6);
192 }
193 
ZTEST(net_websocket,test_recv_7_byte)194 ZTEST(net_websocket, test_recv_7_byte)
195 {
196 	test_recv(7);
197 }
198 
ZTEST(net_websocket,test_recv_8_byte)199 ZTEST(net_websocket, test_recv_8_byte)
200 {
201 	test_recv(8);
202 }
203 
ZTEST(net_websocket,test_recv_9_byte)204 ZTEST(net_websocket, test_recv_9_byte)
205 {
206 	test_recv(9);
207 }
208 
ZTEST(net_websocket,test_recv_10_byte)209 ZTEST(net_websocket, test_recv_10_byte)
210 {
211 	test_recv(10);
212 }
213 
ZTEST(net_websocket,test_recv_12_byte)214 ZTEST(net_websocket, test_recv_12_byte)
215 {
216 	test_recv(12);
217 }
218 
ZTEST(net_websocket,test_recv_whole_msg)219 ZTEST(net_websocket, test_recv_whole_msg)
220 {
221 	test_recv(sizeof(frame1));
222 }
223 
ZTEST(net_websocket,test_recv_empty_ping)224 ZTEST(net_websocket, test_recv_empty_ping)
225 {
226 	struct websocket_context ctx;
227 	int total_read = 0;
228 	uint32_t msg_type = -1;
229 	uint64_t remaining = -1;
230 
231 	memset(&ctx, 0, sizeof(ctx));
232 
233 	ctx.recv_buf.buf = temp_recv_buf;
234 	ctx.recv_buf.size = sizeof(temp_recv_buf);
235 	ctx.recv_buf.count = 0;
236 
237 	memcpy(feed_buf, &ping, sizeof(ping));
238 
239 	total_read = test_recv_buf(&feed_buf[0], sizeof(ping), &ctx, &msg_type, &remaining,
240 				   recv_buf, sizeof(recv_buf));
241 
242 	zassert_equal(total_read, 0, "Msg not empty (ret %d)", total_read);
243 	zassert_equal(msg_type & WEBSOCKET_FLAG_PING, WEBSOCKET_FLAG_PING, "Msg is not ping");
244 }
245 
test_recv_2(int count)246 static void test_recv_2(int count)
247 {
248 	struct websocket_context ctx;
249 	uint32_t msg_type = -1;
250 	uint64_t remaining = -1;
251 	int total_read = 0;
252 	int ret;
253 
254 	memset(&ctx, 0, sizeof(ctx));
255 
256 	ctx.recv_buf.buf = temp_recv_buf;
257 	ctx.recv_buf.size = sizeof(temp_recv_buf);
258 
259 	memcpy(feed_buf, &frame2, sizeof(frame2));
260 
261 	NET_DBG("Reading %d bytes at a time, frame %zd hdr %zd", count,
262 		sizeof(frame2), FRAME1_HDR_SIZE);
263 
264 	total_read = test_recv_buf(&feed_buf[0], count, &ctx, &msg_type,
265 				   &remaining, recv_buf, sizeof(recv_buf));
266 
267 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
268 			  "Invalid message, should be '%s' was '%s'",
269 			  frame1_msg, recv_buf);
270 
271 	zassert_equal(remaining, 0, "Msg not empty");
272 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
273 
274 	/* Then read again. Take in account that part of second frame
275 	 * have read from tx buffer to rx buffer.
276 	 */
277 	ret = test_recv_buf(&feed_buf[count], sizeof(frame2) - count, &ctx, &msg_type, &remaining,
278 			    recv_buf, sizeof(recv_buf));
279 
280 	zassert_mem_equal(recv_buf, frame1_msg, sizeof(frame1_msg) - 1,
281 			  "Invalid 2nd message, should be '%s' was '%s'", frame1_msg, recv_buf);
282 
283 	zassert_equal(remaining, 0, "Msg not empty");
284 	zassert_equal(msg_type & WEBSOCKET_FLAG_TEXT, WEBSOCKET_FLAG_TEXT, "Msg is not text");
285 }
286 
ZTEST(net_websocket,test_recv_two_msg)287 ZTEST(net_websocket, test_recv_two_msg)
288 {
289 	test_recv_2(sizeof(frame1) + FRAME1_HDR_SIZE / 2);
290 }
291 
verify_sent_and_received_msg(struct msghdr * msg,bool split_msg)292 int verify_sent_and_received_msg(struct msghdr *msg, bool split_msg)
293 {
294 	static struct websocket_context ctx;
295 	uint32_t msg_type = -1;
296 	uint64_t remaining = -1;
297 	size_t split_len = 0, total_read = 0;
298 	int ret;
299 
300 	memset(&ctx, 0, sizeof(ctx));
301 
302 	ctx.recv_buf.buf = temp_recv_buf;
303 	ctx.recv_buf.size = sizeof(temp_recv_buf);
304 
305 	/* Read first the header */
306 	ret = test_recv_buf(msg->msg_iov[0].iov_base,
307 			    msg->msg_iov[0].iov_len,
308 			    &ctx, &msg_type, &remaining,
309 			    recv_buf, sizeof(recv_buf));
310 	if (remaining > 0) {
311 		zassert_equal(ret, -EAGAIN, "Msg header not found");
312 	} else {
313 		zassert_equal(ret, 0, "Msg header read error (ret %d)", ret);
314 	}
315 
316 	/* Then the first split if it is enabled */
317 	if (split_msg) {
318 		split_len = msg->msg_iov[1].iov_len / 2;
319 
320 		ret = test_recv_buf(msg->msg_iov[1].iov_base,
321 				    split_len,
322 				    &ctx, &msg_type, &remaining,
323 				    recv_buf, sizeof(recv_buf));
324 		zassert_true(ret > 0, "Cannot read data (%d)", ret);
325 
326 		total_read = ret;
327 	}
328 
329 	/* Then the data */
330 	while (remaining > 0) {
331 		ret = test_recv_buf((uint8_t *)msg->msg_iov[1].iov_base +
332 								total_read,
333 				    msg->msg_iov[1].iov_len - total_read,
334 				    &ctx, &msg_type, &remaining,
335 				    recv_buf, sizeof(recv_buf));
336 		zassert_true(ret > 0, "Cannot read data (%d)", ret);
337 
338 		if (memcmp(recv_buf, lorem_ipsum + total_read, ret) != 0) {
339 			LOG_HEXDUMP_ERR(lorem_ipsum + total_read, ret,
340 					"Received message should be");
341 			LOG_HEXDUMP_ERR(recv_buf, ret, "but it was instead");
342 			zassert_true(false, "Invalid received message "
343 				     "after %d bytes", total_read);
344 		}
345 
346 		total_read += ret;
347 	}
348 
349 	zassert_equal(total_read, test_msg_len,
350 		      "Msg body not valid, received %d instead of %zd",
351 		      total_read, test_msg_len);
352 
353 	NET_DBG("Received %zd header and %zd body",
354 		msg->msg_iov[0].iov_len, total_read);
355 
356 	return msg->msg_iov[0].iov_len + total_read;
357 }
358 
ZTEST(net_websocket,test_send_and_recv_lorem_ipsum)359 ZTEST(net_websocket, test_send_and_recv_lorem_ipsum)
360 {
361 	static struct websocket_context ctx;
362 	int fd, ret;
363 
364 	memset(&ctx, 0, sizeof(ctx));
365 
366 	ctx.recv_buf.buf = temp_recv_buf;
367 	ctx.recv_buf.size = sizeof(temp_recv_buf);
368 
369 	test_msg_len = sizeof(lorem_ipsum) - 1;
370 
371 	fd = test_fd_alloc(&ctx);
372 	ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
373 				 WEBSOCKET_OPCODE_DATA_TEXT, true, true,
374 				 SYS_FOREVER_MS);
375 	zassert_equal(ret, test_msg_len,
376 		      "Should have sent %zd bytes but sent %d instead",
377 		      test_msg_len, ret);
378 
379 	zvfs_free_fd(fd);
380 }
381 
ZTEST(net_websocket,test_recv_two_large_split_msg)382 ZTEST(net_websocket, test_recv_two_large_split_msg)
383 {
384 	static struct websocket_context ctx;
385 	int fd, ret;
386 
387 	memset(&ctx, 0, sizeof(ctx));
388 
389 	ctx.recv_buf.buf = temp_recv_buf;
390 	ctx.recv_buf.size = sizeof(temp_recv_buf);
391 
392 	test_msg_len = sizeof(lorem_ipsum) - 1;
393 
394 	fd = test_fd_alloc(&ctx);
395 	ret = websocket_send_msg(fd, lorem_ipsum, test_msg_len,
396 				 WEBSOCKET_OPCODE_DATA_TEXT, false, true,
397 				 SYS_FOREVER_MS);
398 	zassert_equal(ret, test_msg_len,
399 		      "1st should have sent %zd bytes but sent %d instead",
400 		      test_msg_len, ret);
401 
402 	zvfs_free_fd(fd);
403 }
404 
ZTEST(net_websocket,test_send_and_recv_empty_pong)405 ZTEST(net_websocket, test_send_and_recv_empty_pong)
406 {
407 	static struct websocket_context ctx;
408 	int fd, ret;
409 
410 	memset(&ctx, 0, sizeof(ctx));
411 
412 	ctx.recv_buf.buf = temp_recv_buf;
413 	ctx.recv_buf.size = sizeof(temp_recv_buf);
414 
415 	test_msg_len = 0;
416 
417 	fd = test_fd_alloc(&ctx);
418 	ret = websocket_send_msg(fd, NULL, test_msg_len, WEBSOCKET_OPCODE_PING,
419 				 true, true, SYS_FOREVER_MS);
420 	zassert_equal(ret, test_msg_len, "Should have sent %zd bytes but sent %d instead",
421 		      test_msg_len, ret);
422 
423 	zvfs_free_fd(fd);
424 }
425 
ZTEST(net_websocket,test_recv_in_small_buffer)426 ZTEST(net_websocket, test_recv_in_small_buffer)
427 {
428 	struct websocket_context ctx;
429 	uint32_t msg_type = -1;
430 	uint64_t remaining = -1;
431 	int total_read = 0;
432 	int ret;
433 	const size_t frame1_msg_size = sizeof(frame1_msg) - 1;
434 	const size_t recv_buf_size = 7;
435 
436 	memset(&ctx, 0, sizeof(ctx));
437 
438 	ctx.recv_buf.buf = temp_recv_buf;
439 	ctx.recv_buf.size = sizeof(temp_recv_buf);
440 
441 	memcpy(feed_buf, &frame1, sizeof(frame1));
442 
443 	/* Receive first part of message */
444 	ret = test_recv_buf(&feed_buf[0], sizeof(frame1), &ctx, &msg_type, &remaining, recv_buf,
445 			    recv_buf_size);
446 	zassert_equal(ret, recv_buf_size, "Should have received %zd bytes but ret %d",
447 		      recv_buf_size, ret);
448 	total_read += ret;
449 
450 	/* Receive second part of message */
451 	ret = test_recv_buf(&feed_buf[sizeof(frame1)], 0, &ctx, &msg_type, &remaining,
452 			    &recv_buf[recv_buf_size], recv_buf_size);
453 	zassert_equal(ret, frame1_msg_size - recv_buf_size,
454 		      "Should have received %zd bytes but ret %d", frame1_msg_size - recv_buf_size,
455 		      ret);
456 	total_read += ret;
457 
458 	/* Check receiving whole message */
459 	zassert_equal(total_read, frame1_msg_size, "Received not whole message");
460 	zassert_mem_equal(recv_buf, frame1_msg, frame1_msg_size,
461 			  "Invalid message, should be '%s' was '%s'", frame1_msg, recv_buf);
462 }
463 
setup(void)464 static void *setup(void)
465 {
466 	k_thread_system_pool_assign(k_current_get());
467 	return NULL;
468 }
469 
470 ZTEST_SUITE(net_websocket, NULL, setup, NULL, NULL, NULL);
471