1 /*
2 * Copyright (c) 2020 Friedt Professional Engineering Services, Inc
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/logging/log.h>
8 #include <zephyr/net/net_core.h>
9 #include <zephyr/net/net_ip.h>
10 #include <zephyr/net/socket.h>
11 #include <zephyr/net/tls_credentials.h>
12 #include <zephyr/posix/sys/socket.h>
13 #include <zephyr/posix/arpa/inet.h>
14 #include <zephyr/posix/unistd.h>
15 #include <zephyr/sys/util.h>
16 #include <zephyr/ztest.h>
17
18 #include <mbedtls/x509.h>
19 #include <mbedtls/x509_crt.h>
20
/* Log module for this test; its level follows CONFIG_NET_SOCKETS_LOG_LEVEL. */
LOG_MODULE_REGISTER(tls_test, CONFIG_NET_SOCKETS_LOG_LEVEL);

/**
 * @brief The plaintext message exchanged between server and client.
 *
 * It is carried encrypted over the TLS connection. The answer to life,
 * the universe, and everything.
 *
 * See also <a href="https://en.wikipedia.org/wiki/42_(number)#The_Hitchhiker's_Guide_to_the_Galaxy">42</a>.
 */
#define SECRET "forty-two"

/**
 * @brief Size of @ref SECRET in bytes, excluding the terminating NUL.
 */
#define SECRET_SIZE (sizeof(SECRET) - 1)

/** @brief Stack size for the server thread, in bytes */
#define STACK_SIZE 8192

/** @brief IPv4 loopback address the client connects to */
#define MY_IPV4_ADDR "127.0.0.1"

/** @brief TCP port the server listens on and the client connects to */
#define PORT 4242

/** @brief Arbitrary timeout, in ms, for synchronizing with the server thread */
#define TIMEOUT 1000
47
/**
 * @brief Application-defined TLS credential (security tag) identifiers
 *
 * Server and client live in the same test application here, so the
 * credentials of both sides are registered together under these tags.
 *
 * A standalone server would normally need
 * - SERVER_CERTIFICATE_TAG (its certificate and private key)
 * - CA_CERTIFICATE_TAG (only when client authentication is required)
 *
 * A standalone client would normally need
 * - CA_CERTIFICATE_TAG (always, to verify the server)
 * - CLIENT_CERTIFICATE_TAG (its certificate and private key, only when
 *   client authentication is required)
 */
enum tls_tag {
	/** Certificate Authority certificate */
	CA_CERTIFICATE_TAG = 0,
	/** Server certificate together with the server private key */
	SERVER_CERTIFICATE_TAG = 1,
	/** Client certificate together with the client private key */
	CLIENT_CERTIFICATE_TAG = 2,
};
72
/**
 * @brief Synchronization object between the configuring thread and the
 * server thread.
 *
 * Initialized empty by test_configure_server(), given by
 * server_thread_fn() right before it calls accept(), and taken (with a
 * timeout) by test_configure_server().
 */
static struct k_sem server_sem;

/** @brief The server thread stack (@ref STACK_SIZE bytes) */
static K_THREAD_STACK_DEFINE(server_stack, STACK_SIZE);
/** @brief The server thread object, reused by every test case */
static struct k_thread server_thread;
80
#ifdef CONFIG_TLS_CREDENTIALS
/**
 * @brief The Certificate Authority (CA) Certificate
 *
 * The client needs the CA certificate to verify the server certificate;
 * every client tag list in this test includes CA_CERTIFICATE_TAG.
 *
 * Additionally, when the server's peer verification mode is
 * @ref TLS_PEER_VERIFY_OPTIONAL or @ref TLS_PEER_VERIFY_REQUIRED, the
 * server also needs the CA certificate in order to verify the client.
 * This type of configuration is often referred to as *mutual
 * authentication*.
 *
 * The bytes come from ca.inc (generated elsewhere, not visible here).
 */
static const unsigned char ca[] = {
#include "ca.inc"
};

/**
 * @brief The Server Certificate
 *
 * The server's certificate, which carries the server public key. It is
 * registered as TLS_CREDENTIAL_PUBLIC_CERTIFICATE under
 * SERVER_CERTIFICATE_TAG.
 */
static const unsigned char server[] = {
#include "server.inc"
};

/**
 * @brief The Server Private Key
 *
 * The private key matching @ref server, registered under
 * SERVER_CERTIFICATE_TAG.
 */
static const unsigned char server_privkey[] = {
#include "server_privkey.inc"
};

/**
 * @brief The Client Certificate
 *
 * The client's certificate, which carries the client public key. It is
 * only presented when the client selects CLIENT_CERTIFICATE_TAG.
 */
static const unsigned char client[] = {
#include "client.inc"
};

/**
 * @brief The Client Private Key
 *
 * The private key matching @ref client, registered under
 * CLIENT_CERTIFICATE_TAG.
 */
static const unsigned char client_privkey[] = {
#include "client_privkey.inc"
};
#else /* CONFIG_TLS_CREDENTIALS */
/*
 * Without the credential store the names expand to NULL so that setup()
 * still compiles; setup() skips registration via IS_ENABLED() in that case.
 */
#define ca NULL
#define server NULL
#define server_privkey NULL
#define client NULL
#define client_privkey NULL
#endif /* CONFIG_TLS_CREDENTIALS */
139
140 /**
141 * @brief The server thread function
142 *
143 * This function simply accepts a client connection and
144 * echoes the first @ref SECRET_SIZE bytes of the first
145 * packet. After that, the server is closed and connections
146 * are no longer accepted.
147 *
148 * @param arg0 a pointer to the int representing the server file descriptor
149 * @param arg1 ignored
150 * @param arg2 ignored
151 */
server_thread_fn(void * arg0,void * arg1,void * arg2)152 static void server_thread_fn(void *arg0, void *arg1, void *arg2)
153 {
154 const int server_fd = POINTER_TO_INT(arg0);
155 const int echo = POINTER_TO_INT(arg1);
156 const int expect_failure = POINTER_TO_INT(arg2);
157
158 int r;
159 int client_fd;
160 net_socklen_t addrlen;
161 char addrstr[INET_ADDRSTRLEN];
162 struct net_sockaddr_in sa;
163 char *addrstrp;
164
165 k_thread_name_set(k_current_get(), "server");
166
167 NET_DBG("Server thread running");
168
169 memset(&sa, 0, sizeof(sa));
170 addrlen = sizeof(sa);
171
172 NET_DBG("Accepting client connection..");
173 k_sem_give(&server_sem);
174 r = accept(server_fd, (struct net_sockaddr *)&sa, &addrlen);
175 if (expect_failure) {
176 zassert_equal(r, -1, "accept() should've failed");
177 return;
178 }
179 zassert_not_equal(r, -1, "accept() failed (%d)", r);
180 client_fd = r;
181
182 memset(addrstr, '\0', sizeof(addrstr));
183 addrstrp = (char *)inet_ntop(AF_INET, &sa.sin_addr,
184 addrstr, sizeof(addrstr));
185 zassert_not_equal(addrstrp, NULL, "inet_ntop() failed (%d)", errno);
186
187 NET_DBG("accepted connection from [%s]:%d as fd %d",
188 addrstr, net_ntohs(sa.sin_port), client_fd);
189
190 if (echo) {
191 NET_DBG("calling recv()");
192 r = recv(client_fd, addrstr, sizeof(addrstr), 0);
193 zassert_not_equal(r, -1, "recv() failed (%d)", errno);
194 zassert_equal(r, SECRET_SIZE, "expected: %zu actual: %d",
195 SECRET_SIZE, r);
196
197 NET_DBG("calling send()");
198 r = send(client_fd, SECRET, SECRET_SIZE, 0);
199 zassert_not_equal(r, -1, "send() failed (%d)", errno);
200 zassert_equal(r, SECRET_SIZE, "expected: %zu actual: %d",
201 SECRET_SIZE, r);
202 }
203
204 NET_DBG("closing client fd");
205 r = close(client_fd);
206 zassert_not_equal(r, -1, "close() failed on the server fd (%d)", errno);
207 }
208
/**
 * @brief Create a listening TLS 1.2 server socket and start the server thread.
 *
 * The socket is bound to all interfaces on @ref PORT. For
 * TLS_PEER_VERIFY_OPTIONAL / TLS_PEER_VERIFY_REQUIRED, the CA certificate
 * is added to the server's credentials and the verification mode is
 * applied. For TLS_PEER_VERIFY_NONE, only the server certificate is used.
 * The function returns once the server thread has signalled @ref server_sem.
 *
 * @param server_thread_id output; id of the spawned server thread
 * @param peer_verify      TLS_PEER_VERIFY_* mode applied to the server
 * @param echo             forwarded to server_thread_fn() as @p arg1
 * @param expect_failure   forwarded to server_thread_fn() as @p arg2
 *
 * @return the listening socket fd, or -1 for an unknown @p peer_verify
 */
static int test_configure_server(k_tid_t *server_thread_id, int peer_verify,
				 int echo, int expect_failure)
{
	/* Client is not verified: only the server certificate and key */
	static const sec_tag_t tags_server_only[] = {
		SERVER_CERTIFICATE_TAG,
	};
	/* Client is verified: CA certificate plus server certificate and key */
	static const sec_tag_t tags_with_ca[] = {
		CA_CERTIFICATE_TAG,
		SERVER_CERTIFICATE_TAG,
	};

	const int reuse = true;
	char ip_text[INET_ADDRSTRLEN];
	struct net_sockaddr_in local;
	const sec_tag_t *tags;
	size_t tags_len;
	char *text;
	int fd;
	int ret;

	k_sem_init(&server_sem, 0, 1);

	NET_DBG("Creating server socket");
	fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
	zassert_not_equal(fd, -1, "failed to create server socket (%d)", errno);

	ret = setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
	zassert_not_equal(ret, -1, "failed to set SO_REUSEADDR (%d)", errno);

	if (peer_verify == TLS_PEER_VERIFY_NONE) {
		tags = tags_server_only;
		tags_len = sizeof(tags_server_only);
	} else if ((peer_verify == TLS_PEER_VERIFY_OPTIONAL) ||
		   (peer_verify == TLS_PEER_VERIFY_REQUIRED)) {
		tags = tags_with_ca;
		tags_len = sizeof(tags_with_ca);

		ret = setsockopt(fd, SOL_TLS, TLS_PEER_VERIFY,
				 &peer_verify, sizeof(peer_verify));
		zassert_not_equal(ret, -1, "failed to set TLS_PEER_VERIFY (%d)",
				  errno);
	} else {
		zassert_true(false, "unrecognized TLS peer verify type %d",
			     peer_verify);
		return -1;
	}

	/* TLS_SEC_TAG_LIST takes the size of the tag array in bytes */
	ret = setsockopt(fd, SOL_TLS, TLS_SEC_TAG_LIST, tags, tags_len);
	zassert_not_equal(ret, -1, "failed to set TLS_SEC_TAG_LIST (%d)", errno);

	ret = setsockopt(fd, SOL_TLS, TLS_HOSTNAME, "localhost",
			 sizeof("localhost"));
	zassert_not_equal(ret, -1, "failed to set TLS_HOSTNAME (%d)", errno);

	/* Listen on every network interface */
	memset(&local, 0, sizeof(local));
	local.sin_family = AF_INET;
	local.sin_port = net_htons(PORT);
	local.sin_addr.s_addr = NET_INADDR_ANY;

	ret = bind(fd, (struct net_sockaddr *)&local, sizeof(local));
	zassert_not_equal(ret, -1, "failed to bind (%d)", errno);

	ret = listen(fd, 1);
	zassert_not_equal(ret, -1, "failed to listen (%d)", errno);

	memset(ip_text, '\0', sizeof(ip_text));
	text = (char *)inet_ntop(AF_INET, &local.sin_addr,
				 ip_text, sizeof(ip_text));
	zassert_not_equal(text, NULL, "inet_ntop() failed (%d)", errno);

	NET_DBG("listening on [%s]:%d as fd %d",
		ip_text, net_ntohs(local.sin_port), fd);

	NET_DBG("Creating server thread");
	*server_thread_id = k_thread_create(&server_thread, server_stack,
					    STACK_SIZE, server_thread_fn,
					    INT_TO_POINTER(fd),
					    INT_TO_POINTER(echo),
					    INT_TO_POINTER(expect_failure),
					    K_PRIO_PREEMPT(8), 0, K_NO_WAIT);

	/* Block until the server thread is about to call accept() */
	ret = k_sem_take(&server_sem, K_MSEC(TIMEOUT));
	zassert_equal(0, ret, "failed to synchronize with server thread (%d)", ret);

	return fd;
}
302
test_configure_client(struct net_sockaddr_in * sa,bool own_cert,const char * hostname)303 static int test_configure_client(struct net_sockaddr_in *sa, bool own_cert,
304 const char *hostname)
305 {
306 static const sec_tag_t client_tag_list_verify_none[] = {
307 CA_CERTIFICATE_TAG,
308 };
309
310 static const sec_tag_t client_tag_list_verify[] = {
311 CA_CERTIFICATE_TAG,
312 CLIENT_CERTIFICATE_TAG,
313 };
314
315 char addrstr[INET_ADDRSTRLEN];
316 const sec_tag_t *sec_tag_list;
317 size_t sec_tag_list_size;
318 char *addrstrp;
319 int client_fd;
320 int r;
321
322 k_thread_name_set(k_current_get(), "client");
323
324 NET_DBG("Creating client socket");
325 r = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
326 zassert_not_equal(r, -1, "failed to create client socket (%d)", errno);
327 client_fd = r;
328
329 if (own_cert) {
330 sec_tag_list = client_tag_list_verify;
331 sec_tag_list_size = sizeof(client_tag_list_verify);
332 } else {
333 sec_tag_list = client_tag_list_verify_none;
334 sec_tag_list_size = sizeof(client_tag_list_verify_none);
335 }
336
337 r = setsockopt(client_fd, SOL_TLS, TLS_SEC_TAG_LIST,
338 sec_tag_list, sec_tag_list_size);
339 zassert_not_equal(r, -1, "failed to set TLS_SEC_TAG_LIST (%d)", errno);
340
341 r = setsockopt(client_fd, SOL_TLS, TLS_HOSTNAME, hostname,
342 strlen(hostname) + 1);
343 zassert_not_equal(r, -1, "failed to set TLS_HOSTNAME (%d)", errno);
344
345 sa->sin_family = AF_INET;
346 sa->sin_port = net_htons(PORT);
347 r = inet_pton(AF_INET, MY_IPV4_ADDR, &sa->sin_addr.s_addr);
348 zassert_not_equal(-1, r, "inet_pton() failed (%d)", errno);
349 zassert_not_equal(0, r, "%s is not a valid IPv4 address", MY_IPV4_ADDR);
350 zassert_equal(1, r, "inet_pton() failed to convert %s", MY_IPV4_ADDR);
351
352 memset(addrstr, '\0', sizeof(addrstr));
353 addrstrp = (char *)inet_ntop(AF_INET, &sa->sin_addr,
354 addrstr, sizeof(addrstr));
355 zassert_not_equal(addrstrp, NULL, "inet_ntop() failed (%d)", errno);
356
357 NET_DBG("connecting to [%s]:%d with fd %d",
358 addrstr, net_ntohs(sa->sin_port), client_fd);
359
360 return client_fd;
361 }
/**
 * @brief Tear down a client/server pair created by the test helpers.
 *
 * Closes the client socket, then the listening server socket, waits for
 * the server thread to exit, and finally yields.
 *
 * @param client_fd        client socket from test_configure_client()
 * @param server_fd        listening socket from test_configure_server()
 * @param server_thread_id server thread id from test_configure_server()
 */
static void test_shutdown(int client_fd, int server_fd, k_tid_t server_thread_id)
{
	int r;

	NET_DBG("closing client fd");
	r = close(client_fd);
	zassert_not_equal(-1, r, "close() failed on the client fd (%d)", errno);

	NET_DBG("closing server fd");
	r = close(server_fd);
	zassert_not_equal(-1, r, "close() failed on the server fd (%d)", errno);

	/*
	 * Join the thread the caller handed us instead of silently ignoring
	 * the parameter and assuming the file-scope thread object.
	 */
	r = k_thread_join(server_thread_id, K_FOREVER);
	zassert_equal(0, r, "k_thread_join() failed (%d)", r);

	k_yield();
}
379
test_common(int peer_verify)380 static void test_common(int peer_verify)
381 {
382 k_tid_t server_thread_id;
383 struct net_sockaddr_in sa;
384 uint8_t rx_buf[16];
385 int server_fd;
386 int client_fd;
387 int r;
388
389 /*
390 * Server socket setup
391 */
392 server_fd = test_configure_server(&server_thread_id, peer_verify, true,
393 false);
394
395 /*
396 * Client socket setup
397 */
398 client_fd = test_configure_client(&sa, peer_verify != TLS_PEER_VERIFY_NONE,
399 "localhost");
400
401 /*
402 * The main part of the test
403 */
404
405 r = connect(client_fd, (struct net_sockaddr *)&sa, sizeof(sa));
406 zassert_not_equal(r, -1, "failed to connect (%d)", errno);
407
408 NET_DBG("Calling send()");
409 r = send(client_fd, SECRET, SECRET_SIZE, 0);
410 zassert_not_equal(r, -1, "send() failed (%d)", errno);
411 zassert_equal(SECRET_SIZE, r, "expected: %zu actual: %d", SECRET_SIZE, r);
412
413 NET_DBG("Calling recv()");
414 memset(rx_buf, 0, sizeof(rx_buf));
415 r = recv(client_fd, rx_buf, sizeof(rx_buf), 0);
416 zassert_not_equal(r, -1, "recv() failed (%d)", errno);
417 zassert_equal(SECRET_SIZE, r, "expected: %zu actual: %d", SECRET_SIZE, r);
418 zassert_mem_equal(SECRET, rx_buf, SECRET_SIZE,
419 "expected: %s actual: %s", SECRET, rx_buf);
420
421 /*
422 * Cleanup resources
423 */
424 test_shutdown(client_fd, server_fd, server_thread_id);
425 }
426
ZTEST(net_socket_tls_api_extension,test_tls_peer_verify_none)427 ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_none)
428 {
429 test_common(TLS_PEER_VERIFY_NONE);
430 }
431
ZTEST(net_socket_tls_api_extension,test_tls_peer_verify_optional)432 ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_optional)
433 {
434 test_common(TLS_PEER_VERIFY_OPTIONAL);
435 }
436
ZTEST(net_socket_tls_api_extension,test_tls_peer_verify_required)437 ZTEST(net_socket_tls_api_extension, test_tls_peer_verify_required)
438 {
439 test_common(TLS_PEER_VERIFY_REQUIRED);
440 }
441
test_tls_cert_verify_result_opt_common(uint32_t expect)442 static void test_tls_cert_verify_result_opt_common(uint32_t expect)
443 {
444 int server_fd, client_fd, ret;
445 k_tid_t server_thread_id;
446 struct net_sockaddr_in sa;
447 uint32_t optval;
448 net_socklen_t optlen = sizeof(optval);
449 const char *hostname = "localhost";
450 int peer_verify = TLS_PEER_VERIFY_OPTIONAL;
451
452 if (expect == MBEDTLS_X509_BADCERT_CN_MISMATCH) {
453 hostname = "dummy";
454 }
455
456 server_fd = test_configure_server(&server_thread_id, TLS_PEER_VERIFY_NONE,
457 false, false);
458 client_fd = test_configure_client(&sa, false, hostname);
459
460 ret = zsock_setsockopt(client_fd, SOL_TLS, TLS_PEER_VERIFY,
461 &peer_verify, sizeof(peer_verify));
462 zassert_ok(ret, "failed to set TLS_PEER_VERIFY (%d)", errno);
463
464 ret = zsock_connect(client_fd, (struct net_sockaddr *)&sa, sizeof(sa));
465 zassert_not_equal(ret, -1, "failed to connect (%d)", errno);
466
467 ret = zsock_getsockopt(client_fd, SOL_TLS, TLS_CERT_VERIFY_RESULT,
468 &optval, &optlen);
469 zassert_equal(ret, 0, "getsockopt failed (%d)", errno);
470 zassert_equal(optval, expect, "getsockopt got invalid verify result %d",
471 optval);
472
473 test_shutdown(client_fd, server_fd, server_thread_id);
474 }
475
ZTEST(net_socket_tls_api_extension,test_tls_cert_verify_result_opt_ok)476 ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_ok)
477 {
478 test_tls_cert_verify_result_opt_common(0);
479 }
480
ZTEST(net_socket_tls_api_extension,test_tls_cert_verify_result_opt_bad_cn)481 ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_bad_cn)
482 {
483 test_tls_cert_verify_result_opt_common(MBEDTLS_X509_BADCERT_CN_MISMATCH);
484 }
485
/** @brief State shared between a test case and cert_verify_cb(). */
struct test_cert_verify_ctx {
	/** Set to true by the callback once it has been invoked */
	bool cb_called;
	/** Value the callback returns; 0 accepts the certificate */
	int result;
};
490
/**
 * @brief Certificate verification callback installed through
 * TLS_CERT_VERIFY_CALLBACK.
 *
 * Records that it was called and imposes the verdict stored in the test
 * context. A zero result clears every verification flag. A non-zero result
 * marks the certificate as not trusted and is returned as the callback
 * error.
 *
 * @param ctx   pointer to a struct test_cert_verify_ctx
 * @param crt   certificate being verified (unused)
 * @param depth depth of @p crt in the chain (unused)
 * @param flags verification flags, updated in place
 *
 * @return the result stored in @p ctx
 */
static int cert_verify_cb(void *ctx, mbedtls_x509_crt *crt, int depth,
			  uint32_t *flags)
{
	struct test_cert_verify_ctx *state = ctx;

	(void)crt;
	(void)depth;

	state->cb_called = true;

	if (state->result != 0) {
		/* Reject: flag the certificate as untrusted */
		*flags |= MBEDTLS_X509_BADCERT_NOT_TRUSTED;
	} else {
		/* Accept: wipe whatever mbedTLS flagged */
		*flags = 0;
	}

	return state->result;
}
506
test_tls_cert_verify_cb_opt_common(int result)507 static void test_tls_cert_verify_cb_opt_common(int result)
508 {
509 int server_fd, client_fd, ret;
510 k_tid_t server_thread_id;
511 struct net_sockaddr_in sa;
512 struct test_cert_verify_ctx ctx = {
513 .cb_called = false,
514 .result = result,
515 };
516 struct tls_cert_verify_cb cb = {
517 .cb = cert_verify_cb,
518 .ctx = &ctx,
519 };
520
521 server_fd = test_configure_server(&server_thread_id, TLS_PEER_VERIFY_NONE,
522 false, result == 0 ? false : true);
523 client_fd = test_configure_client(&sa, false, "localhost");
524
525 ret = zsock_setsockopt(client_fd, SOL_TLS, TLS_CERT_VERIFY_CALLBACK,
526 &cb, sizeof(cb));
527 zassert_ok(ret, "failed to set TLS_CERT_VERIFY_CALLBACK (%d)", errno);
528
529 ret = zsock_connect(client_fd, (struct net_sockaddr *)&sa, sizeof(sa));
530 zassert_true(ctx.cb_called, "callback not called");
531 if (result == 0) {
532 zassert_equal(ret, 0, "failed to connect (%d)", errno);
533 } else {
534 zassert_equal(ret, -1, "connect() should fail");
535 zassert_equal(errno, ECONNABORTED, "invalid errno");
536 }
537
538 test_shutdown(client_fd, server_fd, server_thread_id);
539 }
540
ZTEST(net_socket_tls_api_extension,test_tls_cert_verify_cb_opt_ok)541 ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_cb_opt_ok)
542 {
543 test_tls_cert_verify_cb_opt_common(0);
544 }
545
ZTEST(net_socket_tls_api_extension,test_tls_cert_verify_cb_opt_bad_cert)546 ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_cb_opt_bad_cert)
547 {
548 test_tls_cert_verify_cb_opt_common(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
549 }
550
setup(void)551 static void *setup(void)
552 {
553 int r;
554
555 /*
556 * Load both client & server credentials
557 *
558 * Normally, this would be split into separate applications but
559 * for testing purposes, we just use separate threads.
560 *
561 * Also, it has to be done before tests are run, otherwise
562 * there are errors due to attempts to load too many certificates.
563 *
564 * The server would normally load
565 * - server public key
566 * - server private key
567 * - ca cert (only when client authentication is required)
568 *
569 * The client would normally load
570 * - ca cert (to verify the server)
571 * - client public key (only when client authentication is required)
572 * - client private key (only when client authentication is required)
573 */
574 if (IS_ENABLED(CONFIG_TLS_CREDENTIALS)) {
575 NET_DBG("Loading credentials");
576 r = tls_credential_add(CA_CERTIFICATE_TAG,
577 TLS_CREDENTIAL_CA_CERTIFICATE,
578 ca, sizeof(ca));
579 zassert_equal(r, 0, "failed to add CA Certificate (%d)", r);
580
581 r = tls_credential_add(SERVER_CERTIFICATE_TAG,
582 TLS_CREDENTIAL_PUBLIC_CERTIFICATE,
583 server, sizeof(server));
584 zassert_equal(r, 0, "failed to add Server Certificate (%d)", r);
585
586 r = tls_credential_add(SERVER_CERTIFICATE_TAG,
587 TLS_CREDENTIAL_PRIVATE_KEY,
588 server_privkey, sizeof(server_privkey));
589 zassert_equal(r, 0, "failed to add Server Private Key (%d)", r);
590
591 r = tls_credential_add(CLIENT_CERTIFICATE_TAG,
592 TLS_CREDENTIAL_PUBLIC_CERTIFICATE,
593 client, sizeof(client));
594 zassert_equal(r, 0, "failed to add Client Certificate (%d)", r);
595
596 r = tls_credential_add(CLIENT_CERTIFICATE_TAG,
597 TLS_CREDENTIAL_PRIVATE_KEY,
598 client_privkey, sizeof(client_privkey));
599 zassert_equal(r, 0, "failed to add Client Private Key (%d)", r);
600 }
601 return NULL;
602 }
603
/* Register the suite; setup() registers the TLS credentials once for all tests. */
ZTEST_SUITE(net_socket_tls_api_extension, NULL, setup, NULL, NULL, NULL);
605