1 /*
2 * Copyright (c) 2018 Intel Corporation
3 * Copyright (c) 2018 Nordic Semiconductor ASA
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <stdbool.h>
9 #include <zephyr/posix/fcntl.h>
10
11 #include <zephyr/logging/log.h>
12 LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL);
13
14 #include <zephyr/init.h>
15 #include <zephyr/sys/util.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/random/random.h>
18 #include <zephyr/internal/syscall_handler.h>
19 #include <zephyr/sys/fdtable.h>
20
/* TODO: Remove all direct access to private fields.
 * According to the Mbed TLS migration guide:
 *
 * Direct access to fields of structures
 * (`struct` types) declared in public headers is no longer
 * supported. In Mbed TLS 3, the layout of structures is not
 * considered part of the stable API, and minor versions (3.1, 3.2,
 * etc.) may add, remove, rename, reorder or change the type of
 * structure fields.
 */
31 #if !defined(MBEDTLS_ALLOW_PRIVATE_ACCESS)
32 #define MBEDTLS_ALLOW_PRIVATE_ACCESS
33 #endif
34
35 #if defined(CONFIG_MBEDTLS)
36 #if !defined(CONFIG_MBEDTLS_CFG_FILE)
37 #include "mbedtls/config.h"
38 #else
39 #include CONFIG_MBEDTLS_CFG_FILE
40 #endif /* CONFIG_MBEDTLS_CFG_FILE */
41
42 #include <mbedtls/net_sockets.h>
43 #include <mbedtls/x509.h>
44 #include <mbedtls/x509_crt.h>
45 #include <mbedtls/ssl.h>
46 #include <mbedtls/ssl_cookie.h>
47 #include <mbedtls/error.h>
48 #include <mbedtls/platform.h>
49 #include <mbedtls/ssl_cache.h>
50 #endif /* CONFIG_MBEDTLS */
51
52 #include "sockets_internal.h"
53 #include "tls_internal.h"
54
55 #if defined(CONFIG_MBEDTLS_DEBUG)
56 #include <zephyr_mbedtls_priv.h>
57 #endif
58
59 #if defined(CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS)
60 #define ALPN_MAX_PROTOCOLS (CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS + 1)
61 #else
62 #define ALPN_MAX_PROTOCOLS 0
63 #endif /* CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS */
64
65 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
66 #define DTLS_SENDMSG_BUF_SIZE (CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE)
67 #else
68 #define DTLS_SENDMSG_BUF_SIZE 0
69 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
70
71 static const struct socket_op_vtable tls_sock_fd_op_vtable;
72
73 #ifndef MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED
74 #define MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE
75 #endif
76
/**
 * A list of secure tags that a TLS context should use.
 *
 * Only the first @c sec_tag_count entries of @c sec_tags are iterated when
 * credentials are loaded into mbedTLS.
 */
struct sec_tag_list {
	/** An array of secure tags referencing TLS credentials. */
	sec_tag_t sec_tags[CONFIG_NET_SOCKETS_TLS_MAX_CREDENTIALS];

	/** Number of configured secure tags (valid entries in sec_tags). */
	int sec_tag_count;
};
85
/**
 * Timer context for DTLS.
 *
 * Written by dtls_timing_set_delay() and read by dtls_timing_get_delay()
 * and dtls_get_remaining_timeout().
 */
struct dtls_timing_context {
	/** Uptime (k_uptime_get_32()) stored when the timer was last armed
	 * with a non-zero final delay.
	 */
	uint32_t snapshot;

	/** Intermediate delay value. For details, refer to mbedTLS API
	 * documentation (mbedtls_ssl_set_timer_t).
	 */
	uint32_t int_ms;

	/** Final delay value. For details, refer to mbedTLS API documentation
	 * (mbedtls_ssl_set_timer_t). Zero means the timer is cancelled.
	 */
	uint32_t fin_ms;
};
101
/**
 * TLS peer address/session ID mapping.
 *
 * One entry of the client-side session cache. An entry whose @c session
 * pointer is NULL is treated as free.
 */
struct tls_session_cache {
	/** Creation time (k_uptime_get() value taken when the entry was saved). */
	int64_t timestamp;

	/** Peer address. */
	struct sockaddr peer_addr;

	/** Session buffer: serialized mbedTLS session, allocated with
	 * mbedtls_calloc() and released with mbedtls_free().
	 */
	uint8_t *session;

	/** Session length (number of valid bytes in the session buffer). */
	size_t session_len;
};
116
117 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/** DTLS connection ID (CID) option storage.
 * NOTE(review): field meanings are inferred from names and the mbedTLS CID
 * macros; confirm against the socket-option handlers.
 */
struct tls_dtls_cid {
	/** Whether CID is enabled; false after tls_alloc(). */
	bool enabled;
	/** CID bytes; sized for the larger of the mbedTLS outgoing/incoming
	 * CID maximum lengths.
	 */
	unsigned char cid[MAX(MBEDTLS_SSL_CID_OUT_LEN_MAX,
			      MBEDTLS_SSL_CID_IN_LEN_MAX)];
	/** Number of bytes used in cid; zero after tls_alloc(). */
	size_t cid_len;
};
124 #endif
125
/**
 * TLS context information.
 *
 * Instances are taken from the static tls_contexts pool by tls_alloc() and
 * handed back by tls_release().
 */
__net_socket struct tls_context {
	/** Underlying TCP/UDP socket; -1 right after allocation. */
	int sock;

	/** Information whether TLS context is used (pool slot is allocated). */
	bool is_used : 1;

	/** Information whether TLS context was initialized. */
	bool is_initialized : 1;

	/** Information whether underlying socket is listening. */
	bool is_listening : 1;

	/** Information whether TLS handshake is currently in progress. */
	bool handshake_in_progress : 1;

	/** Session ended at the TLS/DTLS level. */
	bool session_closed : 1;

	/** Socket type. */
	enum net_sock_type type;

	/** Secure protocol version running on TLS context. */
	enum net_ip_protocol_secure tls_version;

	/** Socket flags passed to a socket call. */
	int flags;

	/** Indicates whether socket is in error state at TLS/DTLS level. */
	int error;

	/** Information whether TLS handshake is complete or not.
	 * A non-zero semaphore count means the handshake has completed.
	 */
	struct k_sem tls_established;

	/** TLS socket mutex lock, assigned through ctx_set_lock(). */
	struct k_mutex *lock;

	/** TLS specific option values. Copied as a whole by tls_clone(). */
	struct {
		/** Select which credentials to use with TLS. */
		struct sec_tag_list sec_tag_list;

		/** 0-terminated list of allowed ciphersuites (mbedTLS format).
		 */
		int ciphersuites[CONFIG_NET_SOCKETS_TLS_MAX_CIPHERSUITES + 1];

		/** Information if hostname was explicitly set on a socket. */
		bool is_hostname_set;

		/** Peer verification level; -1 after allocation. */
		int8_t verify_level;

		/** Indicating on whether DER certificates should not be copied
		 * to the heap.
		 */
		int8_t cert_nocopy;

		/** DTLS role, client by default. */
		int8_t role;

		/** NULL-terminated list of allowed application layer
		 * protocols.
		 */
		const char *alpn_list[ALPN_MAX_PROTOCOLS];

		/** Session cache enabled on a socket. */
		bool cache_enabled;

		/** Socket TX timeout; K_FOREVER after allocation. */
		k_timeout_t timeout_tx;

		/** Socket RX timeout; K_FOREVER after allocation. */
		k_timeout_t timeout_rx;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/** DTLS handshake timeout bounds; initialized to the mbedTLS
		 * defaults (MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN/MAX).
		 */
		uint32_t dtls_handshake_timeout_min;
		uint32_t dtls_handshake_timeout_max;

		/** DTLS connection ID settings. */
		struct tls_dtls_cid dtls_cid;

		/** Defaults to true in tls_alloc().
		 * NOTE(review): presumably selects whether connect() performs
		 * the DTLS handshake; confirm against the connect path.
		 */
		bool dtls_handshake_on_connect;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	} options;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	/** Context information for DTLS timing. */
	struct dtls_timing_context dtls_timing;

	/** mbedTLS cookie context for DTLS */
	mbedtls_ssl_cookie_ctx cookie;

	/** DTLS peer address. */
	struct sockaddr dtls_peer_addr;

	/** DTLS peer address length; zero while no peer address is stored. */
	socklen_t dtls_peer_addrlen;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(CONFIG_MBEDTLS)
	/** mbedTLS context. */
	mbedtls_ssl_context ssl;

	/** mbedTLS configuration. */
	mbedtls_ssl_config config;

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	/** mbedTLS structure for CA chain. */
	mbedtls_x509_crt ca_chain;

	/** mbedTLS structure for own certificate. */
	mbedtls_x509_crt own_cert;

	/** mbedTLS structure for own private key. */
	mbedtls_pk_context priv_key;
#endif /* MBEDTLS_X509_CRT_PARSE_C */

#endif /* CONFIG_MBEDTLS */
};
246
247
248 /* A global pool of TLS contexts. */
249 static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS];
250
251 static struct tls_session_cache client_cache[CONFIG_NET_SOCKETS_TLS_MAX_CLIENT_SESSION_COUNT];
252
253 #if defined(MBEDTLS_SSL_CACHE_C)
254 static mbedtls_ssl_cache_context server_cache;
255 #endif
256
257 /* A mutex for protecting TLS context allocation. */
258 static struct k_mutex context_lock;
259
260 /* Arbitrary delay value to wait if mbedTLS reports it cannot proceed for
261 * reasons other than TX/RX block.
262 */
263 #define TLS_WAIT_MS 100
264
tls_session_cache_reset(void)265 static void tls_session_cache_reset(void)
266 {
267 for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
268 if (client_cache[i].session != NULL) {
269 mbedtls_free(client_cache[i].session);
270 }
271 }
272
273 (void)memset(client_cache, 0, sizeof(client_cache));
274 }
275
net_socket_is_tls(void * obj)276 bool net_socket_is_tls(void *obj)
277 {
278 return PART_OF_ARRAY(tls_contexts, (struct tls_context *)obj);
279 }
280
/* Random-number callback handed to mbedTLS.
 *
 * Fills buf with len random bytes. With CONFIG_CSPRNG_ENABLED the result of
 * sys_csrand_get() is returned; otherwise sys_rand_get() is used and 0 is
 * returned. The ctx argument is ignored.
 */
static int tls_ctr_drbg_random(void *ctx, unsigned char *buf, size_t len)
{
	int status = 0;

	ARG_UNUSED(ctx);

#if defined(CONFIG_CSPRNG_ENABLED)
	status = sys_csrand_get(buf, len);
#else
	sys_rand_get(buf, len);
#endif

	return status;
}
293
294 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
295 /* mbedTLS-defined function for setting timer. */
dtls_timing_set_delay(void * data,uint32_t int_ms,uint32_t fin_ms)296 static void dtls_timing_set_delay(void *data, uint32_t int_ms, uint32_t fin_ms)
297 {
298 struct dtls_timing_context *ctx = data;
299
300 ctx->int_ms = int_ms;
301 ctx->fin_ms = fin_ms;
302
303 if (fin_ms != 0U) {
304 ctx->snapshot = k_uptime_get_32();
305 }
306 }
307
308 /* mbedTLS-defined function for getting timer status.
309 * The return values are specified by mbedTLS. The callback must return:
310 * -1 if cancelled (fin_ms == 0),
311 * 0 if none of the delays have passed,
312 * 1 if only the intermediate delay has passed,
313 * 2 if the final delay has passed.
314 */
dtls_timing_get_delay(void * data)315 static int dtls_timing_get_delay(void *data)
316 {
317 struct dtls_timing_context *timing = data;
318 unsigned long elapsed_ms;
319
320 NET_ASSERT(timing);
321
322 if (timing->fin_ms == 0U) {
323 return -1;
324 }
325
326 elapsed_ms = k_uptime_get_32() - timing->snapshot;
327
328 if (elapsed_ms >= timing->fin_ms) {
329 return 2;
330 }
331
332 if (elapsed_ms >= timing->int_ms) {
333 return 1;
334 }
335
336 return 0;
337 }
338
dtls_get_remaining_timeout(struct tls_context * ctx)339 static int dtls_get_remaining_timeout(struct tls_context *ctx)
340 {
341 struct dtls_timing_context *timing = &ctx->dtls_timing;
342 uint32_t elapsed_ms;
343
344 elapsed_ms = k_uptime_get_32() - timing->snapshot;
345
346 if (timing->fin_ms == 0U) {
347 return SYS_FOREVER_MS;
348 }
349
350 if (elapsed_ms >= timing->fin_ms) {
351 return 0;
352 }
353
354 return timing->fin_ms - elapsed_ms;
355 }
356 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
357
358 /* Initialize TLS internals. */
tls_init(void)359 static int tls_init(void)
360 {
361
362 #if !defined(CONFIG_ENTROPY_HAS_DRIVER)
363 NET_WARN("No entropy device on the system, "
364 "TLS communication is insecure!");
365 #endif
366
367 (void)memset(tls_contexts, 0, sizeof(tls_contexts));
368 (void)memset(client_cache, 0, sizeof(client_cache));
369
370 k_mutex_init(&context_lock);
371
372 #if defined(MBEDTLS_SSL_CACHE_C)
373 mbedtls_ssl_cache_init(&server_cache);
374 #endif
375
376 return 0;
377 }
378
379 SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT);
380
is_handshake_complete(struct tls_context * ctx)381 static inline bool is_handshake_complete(struct tls_context *ctx)
382 {
383 return k_sem_count_get(&ctx->tls_established) != 0;
384 }
385
386 /*
387 * Copied from include/mbedtls/ssl_internal.h
388 *
389 * Maximum length we can advertise as our max content length for
390 * RFC 6066 max_fragment_length extension negotiation purposes
391 * (the lesser of both sizes, if they are unequal.)
392 */
393 #define MBEDTLS_TLS_EXT_ADV_CONTENT_LEN ( \
394 (MBEDTLS_SSL_IN_CONTENT_LEN > MBEDTLS_SSL_OUT_CONTENT_LEN) \
395 ? (MBEDTLS_SSL_OUT_CONTENT_LEN) \
396 : (MBEDTLS_SSL_IN_CONTENT_LEN) \
397 )
398
399 #if defined(CONFIG_NET_SOCKETS_TLS_SET_MAX_FRAGMENT_LENGTH) && \
400 defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) && \
401 (MBEDTLS_TLS_EXT_ADV_CONTENT_LEN < 16384)
402
403 BUILD_ASSERT(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN >= 512,
404 "Too small content length!");
405
tls_mfl_code_from_content_len(size_t len)406 static inline unsigned char tls_mfl_code_from_content_len(size_t len)
407 {
408 if (len >= 4096) {
409 return MBEDTLS_SSL_MAX_FRAG_LEN_4096;
410 } else if (len >= 2048) {
411 return MBEDTLS_SSL_MAX_FRAG_LEN_2048;
412 } else if (len >= 1024) {
413 return MBEDTLS_SSL_MAX_FRAG_LEN_1024;
414 } else if (len >= 512) {
415 return MBEDTLS_SSL_MAX_FRAG_LEN_512;
416 } else {
417 return MBEDTLS_SSL_MAX_FRAG_LEN_INVALID;
418 }
419 }
420
/* Configure the maximum fragment length on an mbedTLS configuration.
 *
 * The length starts at the advertisable content length and, for datagram
 * sockets with DTLS enabled, is capped at
 * CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH before being converted to an
 * mbedTLS MFL code.
 */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	size_t content_len = MBEDTLS_TLS_EXT_ADV_CONTENT_LEN;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (type == SOCK_DGRAM) {
		if (content_len > CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH) {
			content_len = CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH;
		}
	}
#endif

	mbedtls_ssl_conf_max_frag_len(config,
				      tls_mfl_code_from_content_len(content_len));
}
435 #else
/* No-op variant used when max fragment length negotiation is not configured. */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	ARG_UNUSED(config);
	ARG_UNUSED(type);
}
437 #endif
438
439 /* Allocate TLS context. */
tls_alloc(void)440 static struct tls_context *tls_alloc(void)
441 {
442 int i;
443 struct tls_context *tls = NULL;
444
445 k_mutex_lock(&context_lock, K_FOREVER);
446
447 for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) {
448 if (!tls_contexts[i].is_used) {
449 tls = &tls_contexts[i];
450 (void)memset(tls, 0, sizeof(*tls));
451 tls->is_used = true;
452 tls->options.verify_level = -1;
453 tls->options.timeout_tx = K_FOREVER;
454 tls->options.timeout_rx = K_FOREVER;
455 tls->sock = -1;
456
457 NET_DBG("Allocated TLS context, %p", tls);
458 break;
459 }
460 }
461
462 k_mutex_unlock(&context_lock);
463
464 if (tls) {
465 k_sem_init(&tls->tls_established, 0, 1);
466
467 mbedtls_ssl_init(&tls->ssl);
468 mbedtls_ssl_config_init(&tls->config);
469 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
470 mbedtls_ssl_cookie_init(&tls->cookie);
471 tls->options.dtls_handshake_timeout_min =
472 MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN;
473 tls->options.dtls_handshake_timeout_max =
474 MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX;
475 tls->options.dtls_cid.cid_len = 0;
476 tls->options.dtls_cid.enabled = false;
477 tls->options.dtls_handshake_on_connect = true;
478 #endif
479 #if defined(MBEDTLS_X509_CRT_PARSE_C)
480 mbedtls_x509_crt_init(&tls->ca_chain);
481 mbedtls_x509_crt_init(&tls->own_cert);
482 mbedtls_pk_init(&tls->priv_key);
483 #endif
484
485 #if defined(CONFIG_MBEDTLS_DEBUG)
486 mbedtls_ssl_conf_dbg(&tls->config, zephyr_mbedtls_debug, NULL);
487 #endif
488 } else {
489 NET_WARN("Failed to allocate TLS context");
490 }
491
492 return tls;
493 }
494
495 /* Allocate new TLS context and copy the content from the source context. */
tls_clone(struct tls_context * source_tls)496 static struct tls_context *tls_clone(struct tls_context *source_tls)
497 {
498 struct tls_context *target_tls;
499
500 target_tls = tls_alloc();
501 if (!target_tls) {
502 return NULL;
503 }
504
505 target_tls->tls_version = source_tls->tls_version;
506 target_tls->type = source_tls->type;
507
508 memcpy(&target_tls->options, &source_tls->options,
509 sizeof(target_tls->options));
510
511 #if defined(MBEDTLS_X509_CRT_PARSE_C)
512 if (target_tls->options.is_hostname_set) {
513 mbedtls_ssl_set_hostname(&target_tls->ssl,
514 source_tls->ssl.hostname);
515 }
516 #endif
517
518 return target_tls;
519 }
520
521 /* Release TLS context. */
tls_release(struct tls_context * tls)522 static int tls_release(struct tls_context *tls)
523 {
524 if (!PART_OF_ARRAY(tls_contexts, tls)) {
525 NET_ERR("Invalid TLS context");
526 return -EBADF;
527 }
528
529 if (!tls->is_used) {
530 NET_ERR("Deallocating unused TLS context");
531 return -EBADF;
532 }
533
534 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
535 mbedtls_ssl_cookie_free(&tls->cookie);
536 #endif
537 mbedtls_ssl_config_free(&tls->config);
538 mbedtls_ssl_free(&tls->ssl);
539 #if defined(MBEDTLS_X509_CRT_PARSE_C)
540 mbedtls_x509_crt_free(&tls->ca_chain);
541 mbedtls_x509_crt_free(&tls->own_cert);
542 mbedtls_pk_free(&tls->priv_key);
543 #endif
544
545 tls->is_used = false;
546
547 return 0;
548 }
549
peer_addr_cmp(const struct sockaddr * addr,const struct sockaddr * peer_addr)550 static bool peer_addr_cmp(const struct sockaddr *addr,
551 const struct sockaddr *peer_addr)
552 {
553 if (addr->sa_family != peer_addr->sa_family) {
554 return false;
555 }
556
557 if (IS_ENABLED(CONFIG_NET_IPV6) && peer_addr->sa_family == AF_INET6) {
558 struct sockaddr_in6 *addr1 = net_sin6(peer_addr);
559 struct sockaddr_in6 *addr2 = net_sin6(addr);
560
561 return (addr1->sin6_port == addr2->sin6_port) &&
562 net_ipv6_addr_cmp(&addr1->sin6_addr, &addr2->sin6_addr);
563 } else if (IS_ENABLED(CONFIG_NET_IPV4) && peer_addr->sa_family == AF_INET) {
564 struct sockaddr_in *addr1 = net_sin(peer_addr);
565 struct sockaddr_in *addr2 = net_sin(addr);
566
567 return (addr1->sin_port == addr2->sin_port) &&
568 net_ipv4_addr_cmp(&addr1->sin_addr, &addr2->sin_addr);
569 }
570
571 return false;
572 }
573
/* Serialize an mbedTLS session and store it in the client session cache.
 *
 * Slot selection, in order of preference:
 *   1. an occupied entry already holding a session for peer_addr,
 *   2. a free entry (session == NULL),
 *   3. the occupied entry with the smallest (oldest) timestamp.
 *
 * Any session previously held by the chosen entry is freed before the new
 * one is stored.
 *
 * Returns 0 on success, -ENOMEM when no entry or buffer is available or the
 * serialization fails, -EINVAL when the required buffer size cannot be
 * determined.
 */
static int tls_session_save(const struct sockaddr *peer_addr,
			    mbedtls_ssl_session *session)
{
	struct tls_session_cache *entry = NULL;
	size_t session_len = 0;
	int ret;

	for (size_t i = 0; i < ARRAY_SIZE(client_cache); i++) {
		struct tls_session_cache *slot = &client_cache[i];

		if (slot->session == NULL) {
			/* Free entry: preferred over evicting an occupied one. */
			if (entry == NULL || entry->session != NULL) {
				entry = slot;
			}

			continue;
		}

		if (peer_addr_cmp(&slot->peer_addr, peer_addr)) {
			/* Reuse old entry for given address. */
			entry = slot;
			break;
		}

		/* Remember the oldest entry and reuse if needed. A free entry
		 * that was already picked is never replaced here.
		 */
		if (entry == NULL ||
		    (entry->session != NULL &&
		     slot->timestamp < entry->timestamp)) {
			entry = slot;
		}
	}

	if (entry == NULL) {
		/* Only possible with an empty cache array. */
		return -ENOMEM;
	}

	/* Allocate session and save */

	if (entry->session != NULL) {
		mbedtls_free(entry->session);
		entry->session = NULL;
	}

	/* Size query: with a zero-length buffer mbedTLS is expected to report
	 * "buffer too small" and the required size. Any other error means
	 * session_len cannot be trusted.
	 */
	ret = mbedtls_ssl_session_save(session, NULL, 0, &session_len);
	if ((ret != 0 && ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) ||
	    session_len == 0) {
		NET_ERR("Failed to query session size, err: -0x%x.", -ret);
		return -EINVAL;
	}

	entry->session = mbedtls_calloc(1, session_len);
	if (entry->session == NULL) {
		NET_ERR("Failed to allocate session buffer.");
		return -ENOMEM;
	}

	ret = mbedtls_ssl_session_save(session, entry->session, session_len,
				       &session_len);
	if (ret < 0) {
		NET_ERR("Failed to serialize session, err: -0x%x.", -ret);
		mbedtls_free(entry->session);
		entry->session = NULL;
		return -ENOMEM;
	}

	entry->session_len = session_len;
	entry->timestamp = k_uptime_get();
	memcpy(&entry->peer_addr, peer_addr, sizeof(*peer_addr));

	return 0;
}
633
/* Load the cached session for peer_addr into an mbedTLS session object.
 *
 * Only the first occupied cache entry whose address matches is considered.
 * If deserialization fails, the stored session is discarded.
 *
 * Returns 0 on success, -ENOENT when no entry matches, -EIO when the stored
 * data could not be loaded.
 */
static int tls_session_get(const struct sockaddr *peer_addr,
			   mbedtls_ssl_session *session)
{
	for (size_t idx = 0; idx < ARRAY_SIZE(client_cache); idx++) {
		struct tls_session_cache *slot = &client_cache[idx];
		int ret;

		if (slot->session == NULL ||
		    !peer_addr_cmp(&slot->peer_addr, peer_addr)) {
			continue;
		}

		ret = mbedtls_ssl_session_load(session, slot->session,
					       slot->session_len);
		if (ret >= 0) {
			return 0;
		}

		/* Discard corrupted session data. */
		mbedtls_free(slot->session);
		slot->session = NULL;
		NET_ERR("Failed to load TLS session %d", ret);

		return -EIO;
	}

	return -ENOENT;
}
664
/* Store the current session of a TLS context in the client session cache,
 * keyed by the peer address.
 *
 * Does nothing unless session caching is enabled on the context. Failures
 * are logged and otherwise ignored.
 */
static void tls_session_store(struct tls_context *context,
			      const struct sockaddr *addr,
			      socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Clamp the copy so a caller-supplied length larger than
	 * struct sockaddr cannot overflow the local buffer.
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = mbedtls_ssl_get_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to obtain session for %p", context);
		goto exit;
	}

	ret = tls_session_save(&peer_addr, &session);
	if (ret < 0) {
		NET_ERR("Failed to save session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
694
/* Look up a cached session for the peer address and install it on the TLS
 * context so that mbedTLS can attempt session resumption.
 *
 * Does nothing unless session caching is enabled on the context. A missing
 * cache entry is only reported at debug level.
 */
static void tls_session_restore(struct tls_context *context,
				const struct sockaddr *addr,
				socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Clamp the copy so a caller-supplied length larger than
	 * struct sockaddr cannot overflow the local buffer.
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = tls_session_get(&peer_addr, &session);
	if (ret < 0) {
		NET_DBG("Session not found for %p", context);
		goto exit;
	}

	ret = mbedtls_ssl_set_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to set session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
724
/* Drop all cached TLS sessions.
 *
 * Resets the client session cache and, when the mbedTLS session cache is
 * compiled in, frees and re-initializes the server-side cache.
 */
static void tls_session_purge(void)
{
	tls_session_cache_reset();

#if defined(MBEDTLS_SSL_CACHE_C)
	mbedtls_ssl_cache_free(&server_cache);
	mbedtls_ssl_cache_init(&server_cache);
#endif
}
734
/* Return timeout minus the time elapsed since start, using 32-bit uptime
 * values. The subtraction is done in unsigned arithmetic and the result is
 * converted to int on return.
 */
static inline int time_left(uint32_t start, uint32_t timeout)
{
	uint32_t now = k_uptime_get_32();
	uint32_t remaining = timeout - (now - start);

	return (int)remaining;
}
741
wait(int sock,int timeout,int event)742 static int wait(int sock, int timeout, int event)
743 {
744 struct zsock_pollfd fds = {
745 .fd = sock,
746 .events = event,
747 };
748 int ret;
749
750 ret = zsock_poll(&fds, 1, timeout);
751 if (ret < 0) {
752 return ret;
753 }
754
755 if (ret == 1) {
756 if (fds.revents & ZSOCK_POLLNVAL) {
757 return -EBADF;
758 }
759
760 if (fds.revents & ZSOCK_POLLERR) {
761 int optval;
762 socklen_t optlen = sizeof(optval);
763
764 if (zsock_getsockopt(fds.fd, SOL_SOCKET, SO_ERROR,
765 &optval, &optlen) == 0) {
766 NET_ERR("TLS underlying socket poll error %d",
767 -optval);
768 return -optval;
769 }
770
771 return -EIO;
772 }
773 }
774
775 return 0;
776 }
777
wait_for_reason(int sock,int timeout,int reason)778 static int wait_for_reason(int sock, int timeout, int reason)
779 {
780 if (reason == MBEDTLS_ERR_SSL_WANT_READ) {
781 return wait(sock, timeout, ZSOCK_POLLIN);
782 }
783
784 if (reason == MBEDTLS_ERR_SSL_WANT_WRITE) {
785 return wait(sock, timeout, ZSOCK_POLLOUT);
786 }
787
788 /* Any other reason - no way to monitor, just wait for some time. */
789 k_msleep(TLS_WAIT_MS);
790
791 return 0;
792 }
793
is_blocking(int sock,int flags)794 static bool is_blocking(int sock, int flags)
795 {
796 int sock_flags = zsock_fcntl(sock, F_GETFL, 0);
797
798 if (sock_flags == -1) {
799 return false;
800 }
801
802 return !((flags & ZSOCK_MSG_DONTWAIT) || (sock_flags & O_NONBLOCK));
803 }
804
/* Convert a kernel timeout to milliseconds.
 *
 * K_NO_WAIT maps to 0, K_FOREVER to SYS_FOREVER_MS; any other value is
 * converted from ticks to milliseconds, rounding down.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
815
/* Attach a mutex to the TLS context; only the pointer is stored. */
static void ctx_set_lock(struct tls_context *ctx, struct k_mutex *lock)
{
	ctx->lock = lock;
}
820
821 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Check whether peer_addr/addrlen matches the DTLS peer address stored in
 * the context: the lengths must be equal and the addresses must compare
 * equal per peer_addr_cmp().
 */
static bool dtls_is_peer_addr_valid(struct tls_context *context,
				    const struct sockaddr *peer_addr,
				    socklen_t addrlen)
{
	return (context->dtls_peer_addrlen == addrlen) &&
	       peer_addr_cmp(&context->dtls_peer_addr, peer_addr);
}
832
/* Remember the DTLS peer address in the context.
 *
 * An address longer than the storage available in the context is silently
 * ignored and the stored address is left unchanged.
 */
static void dtls_peer_address_set(struct tls_context *context,
				  const struct sockaddr *peer_addr,
				  socklen_t addrlen)
{
	if (addrlen > sizeof(context->dtls_peer_addr)) {
		return;
	}

	context->dtls_peer_addrlen = addrlen;
	memcpy(&context->dtls_peer_addr, peer_addr, addrlen);
}
842
/* Copy the stored DTLS peer address to the caller.
 *
 * On entry *addrlen is the capacity of peer_addr; on return it holds the
 * number of bytes copied, which is the smaller of the stored address length
 * and the capacity.
 */
static void dtls_peer_address_get(struct tls_context *context,
				  struct sockaddr *peer_addr,
				  socklen_t *addrlen)
{
	socklen_t copy_len = *addrlen;

	if (context->dtls_peer_addrlen < copy_len) {
		copy_len = context->dtls_peer_addrlen;
	}

	memcpy(peer_addr, &context->dtls_peer_addr, copy_len);
	*addrlen = copy_len;
}
852
dtls_tx(void * ctx,const unsigned char * buf,size_t len)853 static int dtls_tx(void *ctx, const unsigned char *buf, size_t len)
854 {
855 struct tls_context *tls_ctx = ctx;
856 ssize_t sent;
857
858 sent = zsock_sendto(tls_ctx->sock, buf, len, ZSOCK_MSG_DONTWAIT,
859 &tls_ctx->dtls_peer_addr,
860 tls_ctx->dtls_peer_addrlen);
861 if (sent < 0) {
862 if (errno == EAGAIN) {
863 return MBEDTLS_ERR_SSL_WANT_WRITE;
864 }
865
866 return MBEDTLS_ERR_NET_SEND_FAILED;
867 }
868
869 return sent;
870 }
871
dtls_rx(void * ctx,unsigned char * buf,size_t len)872 static int dtls_rx(void *ctx, unsigned char *buf, size_t len)
873 {
874 struct tls_context *tls_ctx = ctx;
875 socklen_t addrlen = sizeof(struct sockaddr);
876 struct sockaddr addr;
877 int err;
878 ssize_t received;
879
880 received = zsock_recvfrom(tls_ctx->sock, buf, len,
881 ZSOCK_MSG_DONTWAIT, &addr, &addrlen);
882 if (received < 0) {
883 if (errno == EAGAIN) {
884 return MBEDTLS_ERR_SSL_WANT_READ;
885 }
886
887 return MBEDTLS_ERR_NET_RECV_FAILED;
888 }
889
890 if (tls_ctx->dtls_peer_addrlen == 0) {
891 /* Only allow to store peer address for DTLS servers. */
892 if (tls_ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
893 dtls_peer_address_set(tls_ctx, &addr, addrlen);
894
895 err = mbedtls_ssl_set_client_transport_id(
896 &tls_ctx->ssl,
897 (const unsigned char *)&addr, addrlen);
898 if (err < 0) {
899 return err;
900 }
901 } else {
902 /* For clients it's incorrect to receive when
903 * no peer has been set up.
904 */
905 return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED;
906 }
907 } else if (!dtls_is_peer_addr_valid(tls_ctx, &addr, addrlen)) {
908 return MBEDTLS_ERR_SSL_WANT_READ;
909 }
910
911 return received;
912 }
913 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
914
tls_tx(void * ctx,const unsigned char * buf,size_t len)915 static int tls_tx(void *ctx, const unsigned char *buf, size_t len)
916 {
917 struct tls_context *tls_ctx = ctx;
918 ssize_t sent;
919
920 sent = zsock_sendto(tls_ctx->sock, buf, len,
921 ZSOCK_MSG_DONTWAIT, NULL, 0);
922 if (sent < 0) {
923 if (errno == EAGAIN) {
924 return MBEDTLS_ERR_SSL_WANT_WRITE;
925 }
926
927 return MBEDTLS_ERR_NET_SEND_FAILED;
928 }
929
930 return sent;
931 }
932
tls_rx(void * ctx,unsigned char * buf,size_t len)933 static int tls_rx(void *ctx, unsigned char *buf, size_t len)
934 {
935 struct tls_context *tls_ctx = ctx;
936 ssize_t received;
937
938 received = zsock_recvfrom(tls_ctx->sock, buf, len,
939 ZSOCK_MSG_DONTWAIT, NULL, 0);
940 if (received < 0) {
941 if (errno == EAGAIN) {
942 return MBEDTLS_ERR_SSL_WANT_READ;
943 }
944
945 return MBEDTLS_ERR_NET_RECV_FAILED;
946 }
947
948 return received;
949 }
950
951 #if defined(MBEDTLS_X509_CRT_PARSE_C)
/* Heuristically detect a PEM-encoded certificate.
 *
 * The buffer is treated as PEM when it is non-empty, its last byte is a NUL
 * terminator and it contains the "-----BEGIN CERTIFICATE-----" marker.
 */
static bool crt_is_pem(const unsigned char *buf, size_t buflen)
{
	if (buflen == 0) {
		return false;
	}

	/* Without a terminator the buffer must not be scanned as a string. */
	if (buf[buflen - 1] != '\0') {
		return false;
	}

	return strstr((const char *)buf, "-----BEGIN CERTIFICATE-----") != NULL;
}
957 #endif
958
/* Parse a CA certificate credential into the context's CA chain.
 *
 * The certificate is parsed with copying when no-copy mode is off or the
 * buffer looks like PEM; otherwise it is parsed as DER without copying.
 *
 * Returns 0 on success, -EINVAL if parsing fails, -ENOTSUP when X.509
 * certificate parsing is not compiled into mbedTLS.
 */
static int tls_add_ca_certificate(struct tls_context *tls,
				  struct tls_credential *ca_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	const bool copy_needed =
		(tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE) ||
		crt_is_pem(ca_cert->buf, ca_cert->len);
	int err;

	if (copy_needed) {
		err = mbedtls_x509_crt_parse(&tls->ca_chain, ca_cert->buf,
					     ca_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->ca_chain,
							ca_cert->buf,
							ca_cert->len);
	}

	if (err == 0) {
		return 0;
	}

	NET_ERR("Failed to parse CA certificate, err: -0x%x", -err);

	return -EINVAL;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(ca_cert);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
985
/* Attach the context's parsed CA chain (without a CRL) to its mbedTLS
 * configuration and select the default X.509 certificate profile.
 * Does nothing when X.509 certificate parsing is not compiled in.
 */
static void tls_set_ca_chain(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	mbedtls_ssl_config *conf = &tls->config;

	mbedtls_ssl_conf_ca_chain(conf, &tls->ca_chain, NULL);
	mbedtls_ssl_conf_cert_profile(conf, &mbedtls_x509_crt_profile_default);
#else
	ARG_UNUSED(tls);
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
994
/* Parse the device's own certificate credential into the context.
 *
 * The certificate is parsed with copying when no-copy mode is off or the
 * buffer looks like PEM; otherwise it is parsed as DER without copying.
 *
 * Returns 0 on success, -EINVAL if parsing fails, -ENOTSUP when X.509
 * certificate parsing is not compiled into mbedTLS.
 */
static int tls_add_own_cert(struct tls_context *tls,
			    struct tls_credential *own_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	const bool copy_needed =
		(tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE) ||
		crt_is_pem(own_cert->buf, own_cert->len);
	int err;

	if (copy_needed) {
		err = mbedtls_x509_crt_parse(&tls->own_cert,
					     own_cert->buf, own_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->own_cert,
							own_cert->buf,
							own_cert->len);
	}

	return (err == 0) ? 0 : -EINVAL;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(own_cert);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1020
/* Register the context's own certificate and private key with its mbedTLS
 * configuration.
 *
 * Returns 0 on success, -ENOMEM if mbedTLS rejects the pair, -ENOTSUP when
 * X.509 certificate parsing is not compiled into mbedTLS.
 */
static int tls_set_own_cert(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int rc = mbedtls_ssl_conf_own_cert(&tls->config, &tls->own_cert,
					   &tls->priv_key);

	return (rc == 0) ? 0 : -ENOMEM;
#else
	ARG_UNUSED(tls);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1035
/* Parse a private key credential (no password) into the context.
 *
 * Returns 0 on success, -EINVAL if the key cannot be parsed, -ENOTSUP when
 * X.509 certificate parsing is not compiled into mbedTLS.
 */
static int tls_set_private_key(struct tls_context *tls,
			       struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int rc = mbedtls_pk_parse_key(&tls->priv_key, priv_key->buf,
				      priv_key->len, NULL, 0,
				      tls_ctr_drbg_random, NULL);

	return (rc == 0) ? 0 : -EINVAL;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(priv_key);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1054
/* Configure a pre-shared key and its identity on the mbedTLS configuration.
 *
 * Returns 0 on success, -EINVAL if mbedTLS rejects the PSK, -ENOTSUP when no
 * PSK-based handshake is compiled into mbedTLS.
 */
static int tls_set_psk(struct tls_context *tls,
		       struct tls_credential *psk,
		       struct tls_credential *psk_id)
{
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	int rc = mbedtls_ssl_conf_psk(&tls->config,
				      psk->buf, psk->len,
				      (const unsigned char *)psk_id->buf,
				      psk_id->len);

	return (rc == 0) ? 0 : -EINVAL;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(psk);
	ARG_UNUSED(psk_id);

	return -ENOTSUP;
#endif
}
1073
tls_set_credential(struct tls_context * tls,struct tls_credential * cred)1074 static int tls_set_credential(struct tls_context *tls,
1075 struct tls_credential *cred)
1076 {
1077 switch (cred->type) {
1078 case TLS_CREDENTIAL_CA_CERTIFICATE:
1079 return tls_add_ca_certificate(tls, cred);
1080
1081 case TLS_CREDENTIAL_SERVER_CERTIFICATE:
1082 return tls_add_own_cert(tls, cred);
1083
1084 case TLS_CREDENTIAL_PRIVATE_KEY:
1085 return tls_set_private_key(tls, cred);
1086 break;
1087
1088 case TLS_CREDENTIAL_PSK:
1089 {
1090 struct tls_credential *psk_id =
1091 credential_get(cred->tag, TLS_CREDENTIAL_PSK_ID);
1092 if (!psk_id) {
1093 return -ENOENT;
1094 }
1095
1096 return tls_set_psk(tls, cred, psk_id);
1097 }
1098
1099 case TLS_CREDENTIAL_PSK_ID:
1100 /* Ignore PSK ID - it will be used together
1101 * with PSK
1102 */
1103 break;
1104
1105 default:
1106 return -EINVAL;
1107 }
1108
1109 return 0;
1110 }
1111
tls_mbedtls_set_credentials(struct tls_context * tls)1112 static int tls_mbedtls_set_credentials(struct tls_context *tls)
1113 {
1114 struct tls_credential *cred;
1115 sec_tag_t tag;
1116 int i, err = 0;
1117 bool tag_found, ca_cert_present = false, own_cert_present = false;
1118
1119 credentials_lock();
1120
1121 for (i = 0; i < tls->options.sec_tag_list.sec_tag_count; i++) {
1122 tag = tls->options.sec_tag_list.sec_tags[i];
1123 cred = NULL;
1124 tag_found = false;
1125
1126 while ((cred = credential_next_get(tag, cred)) != NULL) {
1127 tag_found = true;
1128
1129 err = tls_set_credential(tls, cred);
1130 if (err != 0) {
1131 goto exit;
1132 }
1133
1134 if (cred->type == TLS_CREDENTIAL_CA_CERTIFICATE) {
1135 ca_cert_present = true;
1136 } else if (cred->type == TLS_CREDENTIAL_SERVER_CERTIFICATE) {
1137 own_cert_present = true;
1138 }
1139 }
1140
1141 if (!tag_found) {
1142 err = -ENOENT;
1143 goto exit;
1144 }
1145 }
1146
1147 exit:
1148 credentials_unlock();
1149
1150 if (err == 0) {
1151 if (ca_cert_present) {
1152 tls_set_ca_chain(tls);
1153 }
1154 if (own_cert_present) {
1155 err = tls_set_own_cert(tls);
1156 }
1157 }
1158
1159 return err;
1160 }
1161
tls_mbedtls_reset(struct tls_context * context)1162 static int tls_mbedtls_reset(struct tls_context *context)
1163 {
1164 int ret;
1165
1166 ret = mbedtls_ssl_session_reset(&context->ssl);
1167 if (ret != 0) {
1168 return ret;
1169 }
1170
1171 k_sem_reset(&context->tls_established);
1172
1173 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1174 /* Server role: reset the address so that a new
1175 * client can connect w/o a need to reopen a socket
1176 * Client role: keep peer addr so socket can continue to be used
1177 * even on handshake timeout
1178 */
1179 if (context->options.role == MBEDTLS_SSL_IS_SERVER) {
1180 (void)memset(&context->dtls_peer_addr, 0,
1181 sizeof(context->dtls_peer_addr));
1182 context->dtls_peer_addrlen = 0;
1183 }
1184 #endif
1185
1186 return 0;
1187 }
1188
/* Work out how long (in ms) the handshake loop may block in one wait.
 *
 * Starts from the remaining caller timeout; for DTLS sockets the value is
 * additionally capped by dtls_get_remaining_timeout() unless that reports
 * SYS_FOREVER_MS.
 */
static int tls_handshake_wait_ms(struct tls_context *context,
				 k_timeout_t *remaining)
{
	int wait_ms = timeout_to_ms(remaining);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (context->type == SOCK_DGRAM) {
		int dtls_ms = dtls_get_remaining_timeout(context);

		if (dtls_ms != SYS_FOREVER_MS) {
			if (wait_ms == SYS_FOREVER_MS) {
				wait_ms = dtls_ms;
			} else {
				wait_ms = MIN(dtls_ms, wait_ms);
			}
		}
	}
#else
	ARG_UNUSED(context);
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	return wait_ms;
}

/* Drive the mbedTLS handshake to completion, blocking up to @timeout.
 *
 * Returns 0 once the handshake finished (the tls_established semaphore is
 * given), -EAGAIN when the timeout budget is exhausted, the
 * wait_for_reason() result when waiting fails, -ETIMEDOUT on an mbedTLS
 * handshake timeout and -ECONNABORTED for any other handshake or reset
 * failure. context->error is latched for the last two categories.
 * handshake_in_progress is set for the duration of the call.
 */
static int tls_mbedtls_handshake(struct tls_context *context,
				 k_timeout_t timeout)
{
	k_timepoint_t deadline;
	int ret;

	context->handshake_in_progress = true;

	deadline = sys_timepoint_calc(timeout);

	while ((ret = mbedtls_ssl_handshake(&context->ssl)) != 0) {
		const bool retryable =
			(ret == MBEDTLS_ERR_SSL_WANT_READ) ||
			(ret == MBEDTLS_ERR_SSL_WANT_WRITE) ||
			(ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS) ||
			(ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);

		if (retryable) {
			/* Refresh the remaining budget before blocking. */
			timeout = sys_timepoint_timeout(deadline);
			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
				ret = -EAGAIN;
				break;
			}

			ret = wait_for_reason(context->sock,
					      tls_handshake_wait_ms(context,
								    &timeout),
					      ret);
			if (ret != 0) {
				break;
			}

			continue;
		}

		if (ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED) {
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					ret = -EAGAIN;
					break;
				}

				continue;
			}
		} else {
			/* MbedTLS API documentation requires the session to
			 * be reset both on timeout and on any other error.
			 */
			const bool timed_out = (ret == MBEDTLS_ERR_SSL_TIMEOUT);

			if (!timed_out) {
				NET_ERR("TLS handshake error: -0x%x", -ret);
			}

			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				if (timed_out) {
					NET_ERR("TLS handshake timeout");
					context->error = ETIMEDOUT;
					ret = -ETIMEDOUT;
				} else {
					context->error = ECONNABORTED;
					ret = -ECONNABORTED;
				}

				break;
			}
		}

		/* Only reached when tls_mbedtls_reset() failed; bail out
		 * instead of spinning forever.
		 */
		NET_ERR("TLS reset error: -0x%x", -ret);
		context->error = ECONNABORTED;
		ret = -ECONNABORTED;
		break;
	}

	if (ret == 0) {
		k_sem_give(&context->tls_established);
	}

	context->handshake_in_progress = false;

	return ret;
}
1287
tls_mbedtls_init(struct tls_context * context,bool is_server)1288 static int tls_mbedtls_init(struct tls_context *context, bool is_server)
1289 {
1290 int role, type, ret;
1291
1292 role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
1293
1294 type = (context->type == SOCK_STREAM) ?
1295 MBEDTLS_SSL_TRANSPORT_STREAM :
1296 MBEDTLS_SSL_TRANSPORT_DATAGRAM;
1297
1298 if (type == MBEDTLS_SSL_TRANSPORT_STREAM) {
1299 mbedtls_ssl_set_bio(&context->ssl, context,
1300 tls_tx, tls_rx, NULL);
1301 } else {
1302 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1303 mbedtls_ssl_set_bio(&context->ssl, context,
1304 dtls_tx, dtls_rx, NULL);
1305 #else
1306 return -ENOTSUP;
1307 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1308 }
1309
1310 ret = mbedtls_ssl_config_defaults(&context->config, role, type,
1311 MBEDTLS_SSL_PRESET_DEFAULT);
1312 if (ret != 0) {
1313 /* According to mbedTLS API documentation,
1314 * mbedtls_ssl_config_defaults can fail due to memory
1315 * allocation failure
1316 */
1317 return -ENOMEM;
1318 }
1319 tls_set_max_frag_len(&context->config, context->type);
1320
1321 #if defined(MBEDTLS_SSL_RENEGOTIATION)
1322 mbedtls_ssl_conf_legacy_renegotiation(&context->config,
1323 MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE);
1324 mbedtls_ssl_conf_renegotiation(&context->config,
1325 MBEDTLS_SSL_RENEGOTIATION_ENABLED);
1326 #endif
1327
1328 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1329 if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1330 /* DTLS requires timer callbacks to operate */
1331 mbedtls_ssl_set_timer_cb(&context->ssl,
1332 &context->dtls_timing,
1333 dtls_timing_set_delay,
1334 dtls_timing_get_delay);
1335 mbedtls_ssl_conf_handshake_timeout(&context->config,
1336 context->options.dtls_handshake_timeout_min,
1337 context->options.dtls_handshake_timeout_max);
1338
1339 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1340 if (context->options.dtls_cid.enabled) {
1341 ret = mbedtls_ssl_conf_cid(
1342 &context->config,
1343 context->options.dtls_cid.cid_len,
1344 MBEDTLS_SSL_UNEXPECTED_CID_IGNORE);
1345 if (ret != 0) {
1346 return -EINVAL;
1347 }
1348 }
1349 #endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1350
1351 /* Configure cookie for DTLS server */
1352 if (role == MBEDTLS_SSL_IS_SERVER) {
1353 ret = mbedtls_ssl_cookie_setup(&context->cookie,
1354 tls_ctr_drbg_random,
1355 NULL);
1356 if (ret != 0) {
1357 return -ENOMEM;
1358 }
1359
1360 mbedtls_ssl_conf_dtls_cookies(&context->config,
1361 mbedtls_ssl_cookie_write,
1362 mbedtls_ssl_cookie_check,
1363 &context->cookie);
1364
1365 mbedtls_ssl_conf_read_timeout(
1366 &context->config,
1367 CONFIG_NET_SOCKETS_DTLS_TIMEOUT);
1368 }
1369 }
1370 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1371
1372 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1373 /* For TLS clients, set hostname to empty string to enforce it's
1374 * verification - only if hostname option was not set. Otherwise
1375 * depend on user configuration.
1376 */
1377 if (!is_server && !context->options.is_hostname_set) {
1378 mbedtls_ssl_set_hostname(&context->ssl, "");
1379 }
1380 #endif
1381
1382 /* If verification level was specified explicitly, set it. Otherwise,
1383 * use mbedTLS default values (required for client, none for server)
1384 */
1385 if (context->options.verify_level != -1) {
1386 mbedtls_ssl_conf_authmode(&context->config,
1387 context->options.verify_level);
1388 }
1389
1390 mbedtls_ssl_conf_rng(&context->config,
1391 tls_ctr_drbg_random,
1392 NULL);
1393
1394 ret = tls_mbedtls_set_credentials(context);
1395 if (ret != 0) {
1396 return ret;
1397 }
1398
1399 if (context->options.ciphersuites[0] != 0) {
1400 /* Specific ciphersuites configured, so use them */
1401 NET_DBG("Using user-specified ciphersuites");
1402 mbedtls_ssl_conf_ciphersuites(&context->config,
1403 context->options.ciphersuites);
1404 }
1405
1406 #if defined(CONFIG_MBEDTLS_SSL_ALPN)
1407 if (ALPN_MAX_PROTOCOLS && context->options.alpn_list[0] != NULL) {
1408 ret = mbedtls_ssl_conf_alpn_protocols(&context->config,
1409 context->options.alpn_list);
1410 if (ret != 0) {
1411 return -EINVAL;
1412 }
1413 }
1414 #endif /* CONFIG_MBEDTLS_SSL_ALPN */
1415
1416 #if defined(MBEDTLS_SSL_CACHE_C)
1417 if (is_server && context->options.cache_enabled) {
1418 mbedtls_ssl_conf_session_cache(&context->config, &server_cache,
1419 mbedtls_ssl_cache_get,
1420 mbedtls_ssl_cache_set);
1421 }
1422 #endif
1423
1424 #if defined(MBEDTLS_SSL_EARLY_DATA)
1425 mbedtls_ssl_conf_early_data(&context->config, MBEDTLS_SSL_EARLY_DATA_ENABLED);
1426 #endif
1427
1428 ret = mbedtls_ssl_setup(&context->ssl,
1429 &context->config);
1430 if (ret != 0) {
1431 /* According to mbedTLS API documentation,
1432 * mbedtls_ssl_setup can fail due to memory allocation failure
1433 */
1434 return -ENOMEM;
1435 }
1436
1437 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) && defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1438 if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1439 if (context->options.dtls_cid.enabled) {
1440 ret = mbedtls_ssl_set_cid(&context->ssl, MBEDTLS_SSL_CID_ENABLED,
1441 context->options.dtls_cid.cid,
1442 context->options.dtls_cid.cid_len);
1443 if (ret != 0) {
1444 return -EINVAL;
1445 }
1446 }
1447 }
1448 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS && CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1449
1450 context->is_initialized = true;
1451
1452 return 0;
1453 }
1454
/* Replace the context's security tag list with the array in @optval.
 *
 * @optlen must be a whole number of sec_tag_t entries and must fit into
 * the context's tag array; otherwise, or when @optval is NULL, -EINVAL is
 * returned.
 */
static int tls_opt_sec_tag_list_set(struct tls_context *context,
				    const void *optval, socklen_t optlen)
{
	int tag_cnt;

	if (optval == NULL || (optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	tag_cnt = optlen / sizeof(sec_tag_t);
	if (tag_cnt > ARRAY_SIZE(context->options.sec_tag_list.sec_tags)) {
		return -EINVAL;
	}

	context->options.sec_tag_list.sec_tag_count = tag_cnt;
	memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);

	return 0;
}
1479
/* Report the (D)TLS protocol value stored in the context as an int.
 *
 * *optlen must equal sizeof(int), otherwise -EINVAL is returned.
 */
static int sock_opt_protocol_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*out = (int)context->tls_version;

	return 0;
}
1493
/* Copy the configured security tags into @optval.
 *
 * *optlen must be a non-zero multiple of sizeof(sec_tag_t). At most
 * *optlen bytes are copied, and *optlen is updated to the number of bytes
 * actually written.
 */
static int tls_opt_sec_tag_list_get(struct tls_context *context,
				    void *optval, socklen_t *optlen)
{
	size_t stored_bytes;
	int copy_len;

	if (*optlen == 0 || (*optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	stored_bytes = context->options.sec_tag_list.sec_tag_count *
		       sizeof(sec_tag_t);
	copy_len = (stored_bytes < *optlen) ? stored_bytes : *optlen;

	memcpy(optval, context->options.sec_tag_list.sec_tags, copy_len);
	*optlen = copy_len;

	return 0;
}
1511
/* Set the hostname passed to mbedtls_ssl_set_hostname() and remember that
 * it was set.
 *
 * Returns -EINVAL if mbedTLS rejects the hostname, -ENOPROTOOPT when X.509
 * certificate parsing is not compiled in. @optlen is not used.
 */
static int tls_opt_hostname_set(struct tls_context *context,
				const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_set_hostname(&context->ssl, optval) != 0) {
		return -EINVAL;
	}

	context->options.is_hostname_set = true;

	return 0;
#else
	ARG_UNUSED(context);
	ARG_UNUSED(optval);

	return -ENOPROTOOPT;
#endif
}
1529
/* Store a user-selected ciphersuite list and push it to the mbedTLS config.
 *
 * @optval is an array of int ciphersuite IDs of @optlen bytes. The list is
 * stored 0-terminated, so one array slot is reserved for the terminator.
 * Returns -EINVAL for a NULL pointer, a length that is not a multiple of
 * sizeof(int), or a list that does not fit.
 */
static int tls_opt_ciphersuite_list_set(struct tls_context *context,
					const void *optval, socklen_t optlen)
{
	int *stored = context->options.ciphersuites;
	int count;

	if (optval == NULL || (optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	count = optlen / sizeof(int);

	/* + 1 for 0-termination. */
	if (count + 1 > ARRAY_SIZE(context->options.ciphersuites)) {
		return -EINVAL;
	}

	memcpy(stored, optval, optlen);
	stored[count] = 0;

	mbedtls_ssl_conf_ciphersuites(&context->config, stored);

	return 0;
}
1557
/* Return the ciphersuite IDs in effect for this context.
 *
 * If no list was configured, the list from mbedtls_ssl_list_ciphersuites()
 * is reported instead. Copying stops at the 0 terminator or when the
 * caller's buffer (*optlen bytes, a non-zero multiple of sizeof(int)) is
 * full. *optlen is updated to the number of bytes written.
 */
static int tls_opt_ciphersuite_list_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const int *src;
	int *dst = optval;
	int capacity;
	int n;

	if (*optlen == 0 || (*optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	if (context->options.ciphersuites[0] != 0) {
		src = context->options.ciphersuites;
	} else {
		/* No specific ciphersuites configured, return all available. */
		src = mbedtls_ssl_list_ciphersuites();
	}

	capacity = *optlen / sizeof(int);

	for (n = 0; src[n] != 0; ) {
		dst[n] = src[n];
		n++;

		if (n == capacity) {
			break;
		}
	}

	*optlen = n * sizeof(int);

	return 0;
}
1589
/* Report the ID of the ciphersuite mbedTLS names for this session.
 *
 * Returns -EINVAL unless *optlen == sizeof(int), and -ENOTCONN when
 * mbedtls_ssl_get_ciphersuite() has no name to report.
 */
static int tls_opt_ciphersuite_used_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const char *suite_name;
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	suite_name = mbedtls_ssl_get_ciphersuite(&context->ssl);
	if (suite_name == NULL) {
		return -ENOTCONN;
	}

	*out = mbedtls_ssl_get_ciphersuite_id(suite_name);

	return 0;
}
1608
/* Store the ALPN protocol name pointers supplied in @optval.
 *
 * Only the pointers are copied, not the strings they point to. The stored
 * list is NULL-terminated, so one slot is reserved for the terminator.
 * Returns -EINVAL when ALPN support is sized to zero, @optval is NULL,
 * @optlen is not a multiple of the pointer size, or the list does not fit.
 */
static int tls_opt_alpn_list_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int count;

	if (!ALPN_MAX_PROTOCOLS || optval == NULL ||
	    (optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	count = optlen / sizeof(const char *);

	/* + 1 for NULL-termination. */
	if (count + 1 > ARRAY_SIZE(context->options.alpn_list)) {
		return -EINVAL;
	}

	memcpy(context->options.alpn_list, optval, optlen);
	context->options.alpn_list[count] = NULL;

	return 0;
}
1637
1638 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Read back the configured DTLS handshake timeout bound.
 *
 * @is_max selects the maximum value, otherwise the minimum is returned.
 * *optlen must equal sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_get(struct tls_context *context,
					      void *optval, socklen_t *optlen,
					      bool is_max)
{
	uint32_t *out = optval;

	if (*optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	*out = is_max ? context->options.dtls_handshake_timeout_max
		      : context->options.dtls_handshake_timeout_min;

	return 0;
}
1657
/* Update the minimum or maximum DTLS handshake timeout.
 *
 * The value is stored in the context options (picked up when mbedTLS is
 * initialised later) and both bounds are pushed to the mbedTLS config
 * right away so that an already initialised context sees the change.
 * Returns -EINVAL for a NULL @optval or an @optlen other than
 * sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_set(struct tls_context *context,
					      const void *optval,
					      socklen_t optlen, bool is_max)
{
	const uint32_t *requested = optval;

	if (requested == NULL || optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	if (is_max) {
		context->options.dtls_handshake_timeout_max = *requested;
	} else {
		context->options.dtls_handshake_timeout_min = *requested;
	}

	mbedtls_ssl_conf_handshake_timeout(
		&context->config,
		context->options.dtls_handshake_timeout_min,
		context->options.dtls_handshake_timeout_max);

	return 0;
}
1690
/* Select the DTLS Connection ID mode for this context.
 *
 * @optval points to an int holding one of TLS_DTLS_CID_DISABLED,
 * TLS_DTLS_CID_SUPPORTED or TLS_DTLS_CID_ENABLED. For ENABLED, a random
 * own CID of MBEDTLS_SSL_CID_OUT_LEN_MAX bytes is generated unless a CID
 * is already stored. Returns -EINVAL for bad arguments or an unknown mode
 * and -ENOPROTOOPT when CID support is not compiled in.
 */
static int tls_opt_dtls_connection_id_set(struct tls_context *context,
					  const void *optval, socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	int mode;

	if (optlen != sizeof(int) || optval == NULL) {
		return -EINVAL;
	}

	mode = *(const int *)optval;

	if (mode == TLS_DTLS_CID_DISABLED) {
		context->options.dtls_cid.enabled = false;
		context->options.dtls_cid.cid_len = 0;
	} else if (mode == TLS_DTLS_CID_SUPPORTED) {
		context->options.dtls_cid.enabled = true;
		context->options.dtls_cid.cid_len = 0;
	} else if (mode == TLS_DTLS_CID_ENABLED) {
		context->options.dtls_cid.enabled = true;

		if (context->options.dtls_cid.cid_len == 0) {
			/* generate random self cid */
#if defined(CONFIG_CSPRNG_ENABLED)
			sys_csrand_get(context->options.dtls_cid.cid,
				       MBEDTLS_SSL_CID_OUT_LEN_MAX);
#else
			sys_rand_get(context->options.dtls_cid.cid,
				     MBEDTLS_SSL_CID_OUT_LEN_MAX);
#endif
			context->options.dtls_cid.cid_len =
				MBEDTLS_SSL_CID_OUT_LEN_MAX;
		}
	} else {
		return -EINVAL;
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1739
/* Store the caller-supplied own DTLS Connection ID value.
 *
 * @optval / @optlen describe the CID bytes; a zero-length CID is allowed
 * (and may then be passed with a NULL @optval). Returns -EINVAL when data
 * is announced but @optval is NULL or when @optlen exceeds
 * MBEDTLS_SSL_CID_IN_LEN_MAX, and -ENOPROTOOPT when CID support is not
 * compiled in.
 */
static int tls_opt_dtls_connection_id_value_set(struct tls_context *context,
						const void *optval,
						socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (optlen > 0 && optval == NULL) {
		return -EINVAL;
	}

	if (optlen > MBEDTLS_SSL_CID_IN_LEN_MAX) {
		return -EINVAL;
	}

	context->options.dtls_cid.cid_len = optlen;

	/* memcpy() must not be handed a NULL source, not even for a
	 * zero-byte copy, so skip it for the empty CID.
	 */
	if (optlen > 0) {
		memcpy(context->options.dtls_cid.cid, optval, optlen);
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1761
/* Copy the stored own DTLS Connection ID into @optval.
 *
 * Returns -EINVAL when the caller's buffer (*optlen bytes) is smaller than
 * the stored CID; otherwise *optlen is set to the CID length. Returns
 * -ENOPROTOOPT when CID support is not compiled in.
 */
static int tls_opt_dtls_connection_id_value_get(struct tls_context *context,
						void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (context->options.dtls_cid.cid_len > *optlen) {
		return -EINVAL;
	}

	*optlen = context->options.dtls_cid.cid_len;
	memcpy(optval, context->options.dtls_cid.cid, *optlen);

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1779
/* Copy the peer's negotiated DTLS Connection ID into @optval.
 *
 * mbedtls_ssl_get_peer_cid() writes into a buffer of
 * MBEDTLS_SSL_CID_OUT_LEN_MAX bytes without knowing the caller's buffer
 * size, so the CID is fetched into a local buffer first and only copied
 * out once it is known to fit.
 *
 * Returns 0 on success with *optlen set to the CID length (0 when mbedTLS
 * reports the CID as disabled), -ENOTCONN when the context has not been
 * initialised, -EAGAIN when mbedTLS cannot report the CID (the sibling
 * status getter treats this as "handshake not complete"), -EINVAL when
 * the caller's buffer is too small, and -ENOPROTOOPT when CID support is
 * not compiled in.
 */
static int tls_opt_dtls_peer_connection_id_value_get(struct tls_context *context,
						     void *optval,
						     socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	unsigned char peer_cid[MBEDTLS_SSL_CID_OUT_LEN_MAX];
	size_t peer_cid_len = 0;
	int enabled = MBEDTLS_SSL_CID_DISABLED;
	int ret;

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       peer_cid, &peer_cid_len);
	if (ret != 0) {
		/* Do not leak the raw mbedTLS error code to the caller. */
		*optlen = 0;
		return -EAGAIN;
	}

	if (enabled != MBEDTLS_SSL_CID_ENABLED) {
		*optlen = 0;
		return 0;
	}

	if (*optlen < peer_cid_len) {
		return -EINVAL;
	}

	if (peer_cid_len > 0) {
		memcpy(optval, peer_cid, peer_cid_len);
	}

	*optlen = peer_cid_len;

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1801
/* Report the DTLS Connection ID status of the session as an int.
 *
 * The result is one of TLS_DTLS_CID_STATUS_DISABLED, _BIDIRECTIONAL,
 * _DOWNLINK (only an own CID is configured) or _UPLINK (only the peer
 * supplied a CID).
 *
 * Returns -EINVAL unless *optlen == sizeof(int), -ENOTCONN when the
 * context has not been initialised, -EAGAIN when mbedTLS cannot report
 * the peer CID yet, and -ENOPROTOOPT when CID support is not compiled in.
 */
static int tls_opt_dtls_connection_id_status_get(struct tls_context *context,
						 void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	/* Zero-initialised: mbedtls_ssl_get_peer_cid() leaves the CID and
	 * its length untouched when it reports the CID as disabled, and the
	 * length is inspected below.
	 */
	struct tls_dtls_cid cid = { 0 };
	int ret;
	int val;
	int enabled = MBEDTLS_SSL_CID_DISABLED;
	bool have_self_cid;
	bool have_peer_cid;

	if (sizeof(int) != *optlen) {
		return -EINVAL;
	}

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       cid.cid,
				       &cid.cid_len);
	if (ret) {
		/* Handshake is not complete */
		return -EAGAIN;
	}

	cid.enabled = (enabled == MBEDTLS_SSL_CID_ENABLED);
	have_self_cid = (context->options.dtls_cid.cid_len != 0);
	have_peer_cid = cid.enabled && (cid.cid_len != 0);

	if (!context->options.dtls_cid.enabled) {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	} else if (have_self_cid && have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_BIDIRECTIONAL;
	} else if (have_self_cid) {
		val = TLS_DTLS_CID_STATUS_DOWNLINK;
	} else if (have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_UPLINK;
	} else {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	}

	*((int *)optval) = val;
	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1851
/* Store the "DTLS handshake on connect" flag from an int option value.
 *
 * Any non-zero int turns the flag on. Returns -EINVAL for a NULL @optval
 * or an @optlen other than sizeof(int).
 */
static int tls_opt_dtls_handshake_on_connect_set(struct tls_context *context,
						 const void *optval,
						 socklen_t optlen)
{
	const int *requested = optval;

	if (requested == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.dtls_handshake_on_connect = (bool)*requested;

	return 0;
}
1870
/* Read back the "DTLS handshake on connect" flag as an int.
 *
 * *optlen must equal sizeof(int), otherwise -EINVAL is returned.
 */
static int tls_opt_dtls_handshake_on_connect_get(struct tls_context *context,
						 void *optval,
						 socklen_t *optlen)
{
	int *out = optval;

	if (sizeof(int) != *optlen) {
		return -EINVAL;
	}

	*out = context->options.dtls_handshake_on_connect;

	return 0;
}
1883 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1884
/* Copy the stored ALPN protocol name pointers into @optval.
 *
 * Copying stops at the NULL terminator or when the caller's buffer
 * (*optlen bytes, a non-zero multiple of the pointer size) is full.
 * *optlen is updated to the number of bytes written. Returns -EINVAL when
 * ALPN support is sized to zero or *optlen is malformed.
 */
static int tls_opt_alpn_list_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	const char **stored = context->options.alpn_list;
	const char **out = optval;
	int capacity;
	int n;

	if (!ALPN_MAX_PROTOCOLS) {
		return -EINVAL;
	}

	if (*optlen == 0 || (*optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	capacity = *optlen / sizeof(const char *);

	for (n = 0; stored[n] != NULL; ) {
		out[n] = stored[n];
		n++;

		if (n == capacity) {
			break;
		}
	}

	*optlen = n * sizeof(const char *);

	return 0;
}
1913
/* Enable or disable session caching for this context.
 *
 * Caching is turned on only when the int in @optval equals
 * TLS_SESSION_CACHE_ENABLED; any other value turns it off. Returns -EINVAL
 * for a NULL @optval or an @optlen other than sizeof(int).
 */
static int tls_opt_session_cache_set(struct tls_context *context,
				     const void *optval, socklen_t optlen)
{
	const int *requested = optval;

	if (requested == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.cache_enabled =
		(*requested == TLS_SESSION_CACHE_ENABLED);

	return 0;
}
1931
/* Report whether session caching is enabled, as TLS_SESSION_CACHE_ENABLED
 * or TLS_SESSION_CACHE_DISABLED.
 *
 * *optlen must equal sizeof(int), otherwise -EINVAL is returned.
 */
static int tls_opt_session_cache_get(struct tls_context *context,
				     void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	if (context->options.cache_enabled) {
		*out = TLS_SESSION_CACHE_ENABLED;
	} else {
		*out = TLS_SESSION_CACHE_DISABLED;
	}

	return 0;
}
1947
/* Socket option handler that calls tls_session_purge().
 *
 * None of the arguments are used; the handler always returns 0.
 */
static int tls_opt_session_cache_purge_set(struct tls_context *context,
					   const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(context);

	tls_session_purge();

	return 0;
}
1959
/* Store the requested peer verification level.
 *
 * Accepts MBEDTLS_SSL_VERIFY_NONE, _OPTIONAL or _REQUIRED passed as an
 * int. Returns -EINVAL for a NULL @optval, a wrong @optlen, or any other
 * value.
 */
static int tls_opt_peer_verify_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int level;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	level = *(const int *)optval;

	switch (level) {
	case MBEDTLS_SSL_VERIFY_NONE:
	case MBEDTLS_SSL_VERIFY_OPTIONAL:
	case MBEDTLS_SSL_VERIFY_REQUIRED:
		context->options.verify_level = level;
		return 0;

	default:
		return -EINVAL;
	}
}
1985
/* Store the certificate no-copy option.
 *
 * Accepts TLS_CERT_NOCOPY_NONE or TLS_CERT_NOCOPY_OPTIONAL passed as an
 * int. Returns -EINVAL for a NULL @optval, a wrong @optlen, or any other
 * value.
 */
static int tls_opt_cert_nocopy_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int mode;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	mode = *(const int *)optval;

	switch (mode) {
	case TLS_CERT_NOCOPY_NONE:
	case TLS_CERT_NOCOPY_OPTIONAL:
		context->options.cert_nocopy = mode;
		return 0;

	default:
		return -EINVAL;
	}
}
2010
/* Store the DTLS role of this context.
 *
 * Accepts MBEDTLS_SSL_IS_CLIENT or MBEDTLS_SSL_IS_SERVER passed as an int.
 * Returns -EINVAL for a NULL @optval, a wrong @optlen, or any other value.
 */
static int tls_opt_dtls_role_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int requested;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	requested = *(const int *)optval;

	switch (requested) {
	case MBEDTLS_SSL_IS_CLIENT:
	case MBEDTLS_SSL_IS_SERVER:
		context->options.role = requested;
		return 0;

	default:
		return -EINVAL;
	}
}
2034
protocol_check(int family,int type,int * proto)2035 static int protocol_check(int family, int type, int *proto)
2036 {
2037 if (family != AF_INET && family != AF_INET6) {
2038 return -EAFNOSUPPORT;
2039 }
2040
2041 if (*proto >= IPPROTO_TLS_1_0 && *proto <= IPPROTO_TLS_1_3) {
2042 if (type != SOCK_STREAM) {
2043 return -EPROTOTYPE;
2044 }
2045
2046 *proto = IPPROTO_TCP;
2047 } else if (*proto >= IPPROTO_DTLS_1_0 && *proto <= IPPROTO_DTLS_1_2) {
2048 if (!IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS)) {
2049 return -EPROTONOSUPPORT;
2050 }
2051
2052 if (type != SOCK_DGRAM) {
2053 return -EPROTOTYPE;
2054 }
2055
2056 *proto = IPPROTO_UDP;
2057 } else {
2058 return -EPROTONOSUPPORT;
2059 }
2060
2061 return 0;
2062 }
2063
ztls_socket(int family,int type,int proto)2064 static int ztls_socket(int family, int type, int proto)
2065 {
2066 enum net_ip_protocol_secure tls_proto = proto;
2067 int fd = zvfs_reserve_fd();
2068 int sock = -1;
2069 int ret;
2070 struct tls_context *ctx;
2071
2072 if (fd < 0) {
2073 return -1;
2074 }
2075
2076 ret = protocol_check(family, type, &proto);
2077 if (ret < 0) {
2078 errno = -ret;
2079 goto free_fd;
2080 }
2081
2082 ctx = tls_alloc();
2083 if (ctx == NULL) {
2084 errno = ENOMEM;
2085 goto free_fd;
2086 }
2087
2088 sock = zsock_socket(family, type, proto);
2089 if (sock < 0) {
2090 goto release_tls;
2091 }
2092
2093 ctx->tls_version = tls_proto;
2094 ctx->type = (proto == IPPROTO_TCP) ? SOCK_STREAM : SOCK_DGRAM;
2095 ctx->sock = sock;
2096
2097 zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
2098 ZVFS_MODE_IFSOCK);
2099
2100 return fd;
2101
2102 release_tls:
2103 (void)tls_release(ctx);
2104
2105 free_fd:
2106 zvfs_free_fd(fd);
2107
2108 return -1;
2109 }
2110
ztls_close_ctx(struct tls_context * ctx,int sock)2111 int ztls_close_ctx(struct tls_context *ctx, int sock)
2112 {
2113 int ret, err = 0;
2114
2115 /* Try to send close notification. */
2116 ctx->flags = 0;
2117
2118 (void)mbedtls_ssl_close_notify(&ctx->ssl);
2119
2120 err = tls_release(ctx);
2121 ret = zsock_close(ctx->sock);
2122
2123 if (ret == 0) {
2124 (void)sock_obj_core_dealloc(sock);
2125 }
2126
2127 /* In case close fails, we propagate errno value set by close.
2128 * In case close succeeds, but tls_release fails, set errno
2129 * according to tls_release return value.
2130 */
2131 if (ret == 0 && err < 0) {
2132 errno = -err;
2133 ret = -1;
2134 }
2135
2136 return ret;
2137 }
2138
/* Connect the underlying socket and, where applicable, run the TLS
 * handshake as a client.
 *
 * The transport connect is always performed in blocking mode: O_NONBLOCK
 * is cleared for the duration of zsock_connect() and restored afterwards,
 * whether or not the connect succeeded. The handshake is run for stream
 * sockets, and for datagram sockets when the dtls_handshake_on_connect
 * option is set.
 *
 * Returns 0 on success, or -1 with errno set. EIO is reported when the
 * socket flags cannot be read; a handshake that runs out of time on a
 * blocking socket is reported as ETIMEDOUT rather than EAGAIN.
 */
int ztls_connect_ctx(struct tls_context *ctx, const struct sockaddr *addr,
		     socklen_t addrlen)
{
	int ret;
	int sock_flags;
	bool is_non_block;

	sock_flags = zsock_fcntl(ctx->sock, F_GETFL, 0);
	if (sock_flags < 0) {
		ret = -EIO;
		goto error;
	}

	is_non_block = sock_flags & O_NONBLOCK;
	if (is_non_block) {
		(void)zsock_fcntl(ctx->sock, F_SETFL,
				  sock_flags & ~O_NONBLOCK);
	}

	ret = zsock_connect(ctx->sock, addr, addrlen);

	if (is_non_block) {
		/* Restore the caller's flags on every path, keeping the
		 * errno reported by a failed connect intact.
		 */
		int connect_errno = errno;

		(void)zsock_fcntl(ctx->sock, F_SETFL, sock_flags);
		errno = connect_errno;
	}

	if (ret < 0) {
		return ret;
	}

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (ctx->type == SOCK_DGRAM) {
		dtls_peer_address_set(ctx, addr, addrlen);
	}
#endif

	if (ctx->type == SOCK_STREAM
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	    || (ctx->type == SOCK_DGRAM && ctx->options.dtls_handshake_on_connect)
#endif
	    ) {
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}

		/* Do not use any socket flags during the handshake. */
		ctx->flags = 0;

		tls_session_restore(ctx, addr, addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket
		 * even for non-blocking socket.
		 */
		ret = tls_mbedtls_handshake(
			ctx, K_MSEC(CONFIG_NET_SOCKETS_CONNECT_TIMEOUT));
		if (ret < 0) {
			if ((ret == -EAGAIN) && !is_non_block) {
				ret = -ETIMEDOUT;
			}

			goto error;
		}

		tls_session_store(ctx, addr, addrlen);
	}

	return 0;

error:
	errno = -ret;
	return -1;
}
2209
/* Accept a connection on a listening TLS socket and run the server-side
 * handshake on it.
 *
 * The parent's lock is dropped while zsock_accept() runs and re-taken
 * afterwards. A child context is cloned from @parent, attached to a newly
 * reserved file descriptor and handshaken with a timeout of
 * CONFIG_NET_SOCKETS_CONNECT_TIMEOUT ms.
 *
 * Returns the new file descriptor, or -1 on failure. errno is set on all
 * failure paths except a failed descriptor reservation; on those paths
 * the child context, the accepted socket and the reserved descriptor are
 * released again.
 */
int ztls_accept_ctx(struct tls_context *parent, struct sockaddr *addr,
		    socklen_t *addrlen)
{
	struct tls_context *conn = NULL;
	int fd, sock, ret, err;

	fd = zvfs_reserve_fd();
	if (fd < 0) {
		return -1;
	}

	k_mutex_unlock(parent->lock);
	sock = zsock_accept(parent->sock, addr, addrlen);
	k_mutex_lock(parent->lock, K_FOREVER);
	if (sock < 0) {
		ret = -errno;
		goto cleanup;
	}

	conn = tls_clone(parent);
	if (conn == NULL) {
		ret = -ENOMEM;
		goto cleanup;
	}

	zvfs_finalize_typed_fd(fd, conn,
			       (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
			       ZVFS_MODE_IFSOCK);

	conn->sock = sock;

	ret = tls_mbedtls_init(conn, true);
	if (ret < 0) {
		goto cleanup;
	}

	/* Do not use any socket flags during the handshake. */
	conn->flags = 0;

	/* TODO For simplicity, TLS handshake blocks the socket even for
	 * non-blocking socket.
	 */
	ret = tls_mbedtls_handshake(
		conn, K_MSEC(CONFIG_NET_SOCKETS_CONNECT_TIMEOUT));
	if (ret >= 0) {
		return fd;
	}

cleanup:
	if (conn != NULL) {
		err = tls_release(conn);
		__ASSERT(err == 0, "TLS context release failed");
	}

	if (sock >= 0) {
		err = zsock_close(sock);
		__ASSERT(err == 0, "Child socket close failed");
	}

	zvfs_free_fd(fd);

	errno = -ret;
	return -1;
}
2276
send_tls(struct tls_context * ctx,const void * buf,size_t len,int flags)2277 static ssize_t send_tls(struct tls_context *ctx, const void *buf,
2278 size_t len, int flags)
2279 {
2280 const bool is_block = is_blocking(ctx->sock, flags);
2281 k_timeout_t timeout;
2282 k_timepoint_t end;
2283 int ret;
2284
2285 if (ctx->error != 0) {
2286 errno = ctx->error;
2287 return -1;
2288 }
2289
2290 if (ctx->session_closed) {
2291 errno = ECONNABORTED;
2292 return -1;
2293 }
2294
2295 if (!is_block) {
2296 timeout = K_NO_WAIT;
2297 } else {
2298 timeout = ctx->options.timeout_tx;
2299 }
2300
2301 end = sys_timepoint_calc(timeout);
2302
2303 do {
2304 ret = mbedtls_ssl_write(&ctx->ssl, buf, len);
2305 if (ret >= 0) {
2306 return ret;
2307 }
2308
2309 if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2310 ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2311 ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2312 ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2313 int timeout_ms;
2314
2315 if (!is_block) {
2316 errno = EAGAIN;
2317 break;
2318 }
2319
2320 /* Blocking timeout. */
2321 timeout = sys_timepoint_timeout(end);
2322 if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2323 errno = EAGAIN;
2324 break;
2325 }
2326
2327 /* Block. */
2328 timeout_ms = timeout_to_ms(&timeout);
2329 ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2330 if (ret != 0) {
2331 errno = -ret;
2332 break;
2333 }
2334 } else {
2335 NET_ERR("TLS send error: -%x", -ret);
2336
2337 /* MbedTLS API documentation requires session to
2338 * be reset in other error cases
2339 */
2340 ret = tls_mbedtls_reset(ctx);
2341 if (ret != 0) {
2342 ctx->error = ENOMEM;
2343 errno = ENOMEM;
2344 } else {
2345 ctx->error = ECONNABORTED;
2346 errno = ECONNABORTED;
2347 }
2348
2349 break;
2350 }
2351 } while (true);
2352
2353 return -1;
2354 }
2355
2356 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* DTLS client send path.
 *
 * Validates/records the peer address, lazily initializes mbed TLS and
 * performs the handshake if it has not completed yet, then hands the payload
 * to send_tls(). Returns the send_tls() result, or -1 with errno set when
 * the address check, initialization or handshake fails.
 */
static ssize_t sendto_dtls_client(struct tls_context *ctx, const void *buf,
				  size_t len, int flags,
				  const struct sockaddr *dest_addr,
				  socklen_t addrlen)
{
	int rc;

	if (dest_addr == NULL) {
		/* Without an explicit destination a stored peer address
		 * is mandatory.
		 */
		if (ctx->dtls_peer_addrlen == 0) {
			errno = EDESTADDRREQ;
			return -1;
		}
	} else if (ctx->dtls_peer_addrlen == 0) {
		/* First destination seen - remember it as the peer. */
		dtls_peer_address_set(ctx, dest_addr, addrlen);
	} else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen)) {
		/* Destination differs from the stored peer address. */
		errno = EISCONN;
		return -1;
	}

	if (!ctx->is_initialized) {
		rc = tls_mbedtls_init(ctx, false);
		if (rc < 0) {
			errno = -rc;
			return -1;
		}
	}

	if (is_handshake_complete(ctx)) {
		return send_tls(ctx, buf, len, flags);
	}

	tls_session_restore(ctx, &ctx->dtls_peer_addr,
			    ctx->dtls_peer_addrlen);

	/* TODO For simplicity, TLS handshake blocks the socket even for
	 * non-blocking socket.
	 * DTLS handshake timeout/retransmissions are limited by
	 * mbed TLS, so K_FOREVER is fine here, the function will not
	 * block indefinitely.
	 */
	rc = tls_mbedtls_handshake(ctx, K_FOREVER);
	if (rc < 0) {
		errno = -rc;
		return -1;
	}

	/* Client socket ready to use again. */
	ctx->error = 0;

	tls_session_store(ctx, &ctx->dtls_peer_addr,
			  ctx->dtls_peer_addrlen);

	return send_tls(ctx, buf, len, flags);
}
2416
/* DTLS server send path.
 *
 * A server can only send once a DTLS connection is established, and only
 * to the peer it is connected with. Returns the send_tls() result, or -1
 * with errno set to ENOTCONN (no completed handshake) or EISCONN
 * (destination does not match the connected peer).
 */
static ssize_t sendto_dtls_server(struct tls_context *ctx, const void *buf,
				  size_t len, int flags,
				  const struct sockaddr *dest_addr,
				  socklen_t addrlen)
{
	int reject = 0;

	if (!is_handshake_complete(ctx)) {
		/* No established DTLS connection to send data on. */
		reject = ENOTCONN;
	} else if (dest_addr != NULL &&
		   !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen)) {
		/* Destination is not the peer we have a connection with. */
		reject = EISCONN;
	}

	if (reject != 0) {
		errno = reject;
		return -1;
	}

	return send_tls(ctx, buf, len, flags);
}
2439 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2440
/* Send entry point for a TLS context.
 *
 * Records the call flags in the context and dispatches to the stream (TLS)
 * sender or, for non-stream sockets, to the DTLS server/client sender
 * selected by the configured role. Without DTLS support non-stream sockets
 * fail with errno set to ENOTSUP.
 */
ssize_t ztls_sendto_ctx(struct tls_context *ctx, const void *buf, size_t len,
			int flags, const struct sockaddr *dest_addr,
			socklen_t addrlen)
{
	ctx->flags = flags;

	if (ctx->type != SOCK_STREAM) {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/* DTLS */
		if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
			return sendto_dtls_server(ctx, buf, len, flags,
						  dest_addr, addrlen);
		}

		return sendto_dtls_client(ctx, buf, len, flags,
					  dest_addr, addrlen);
#else
		errno = ENOTSUP;
		return -1;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	/* TLS */
	return send_tls(ctx, buf, len, flags);
}
2465
dtls_sendmsg_merge_and_send(struct tls_context * ctx,const struct msghdr * msg,int flags)2466 static ssize_t dtls_sendmsg_merge_and_send(struct tls_context *ctx,
2467 const struct msghdr *msg,
2468 int flags)
2469 {
2470 static K_MUTEX_DEFINE(sendmsg_lock);
2471 static uint8_t sendmsg_buf[DTLS_SENDMSG_BUF_SIZE];
2472 ssize_t len = 0;
2473
2474 k_mutex_lock(&sendmsg_lock, K_FOREVER);
2475
2476 for (int i = 0; i < msg->msg_iovlen; i++) {
2477 struct iovec *vec = msg->msg_iov + i;
2478
2479 if (vec->iov_len > 0) {
2480 if (len + vec->iov_len > sizeof(sendmsg_buf)) {
2481 k_mutex_unlock(&sendmsg_lock);
2482 errno = EMSGSIZE;
2483 return -1;
2484 }
2485
2486 memcpy(sendmsg_buf + len, vec->iov_base, vec->iov_len);
2487 len += vec->iov_len;
2488 }
2489 }
2490
2491 if (len > 0) {
2492 len = ztls_sendto_ctx(ctx, sendmsg_buf, len, flags,
2493 msg->msg_name, msg->msg_namelen);
2494 }
2495
2496 k_mutex_unlock(&sendmsg_lock);
2497
2498 return len;
2499 }
2500
tls_sendmsg_loop_and_send(struct tls_context * ctx,const struct msghdr * msg,int flags)2501 static ssize_t tls_sendmsg_loop_and_send(struct tls_context *ctx,
2502 const struct msghdr *msg,
2503 int flags)
2504 {
2505 ssize_t len = 0;
2506 ssize_t ret;
2507
2508 for (int i = 0; i < msg->msg_iovlen; i++) {
2509 struct iovec *vec = msg->msg_iov + i;
2510 size_t sent = 0;
2511
2512 if (vec->iov_len == 0) {
2513 continue;
2514 }
2515
2516 while (sent < vec->iov_len) {
2517 uint8_t *ptr = (uint8_t *)vec->iov_base + sent;
2518
2519 ret = ztls_sendto_ctx(ctx, ptr, vec->iov_len - sent,
2520 flags, msg->msg_name,
2521 msg->msg_namelen);
2522 if (ret < 0) {
2523 return ret;
2524 }
2525 sent += ret;
2526 }
2527 len += sent;
2528 }
2529
2530 return len;
2531 }
2532
ztls_sendmsg_ctx(struct tls_context * ctx,const struct msghdr * msg,int flags)2533 ssize_t ztls_sendmsg_ctx(struct tls_context *ctx, const struct msghdr *msg,
2534 int flags)
2535 {
2536 if (msg == NULL) {
2537 errno = EINVAL;
2538 return -1;
2539 }
2540
2541 if (IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS) &&
2542 ctx->type == SOCK_DGRAM) {
2543 if (DTLS_SENDMSG_BUF_SIZE > 0) {
2544 /* With one buffer only, there's no need to use
2545 * intermediate buffer.
2546 */
2547 if (msghdr_non_empty_iov_count(msg) == 1) {
2548 goto send_loop;
2549 }
2550
2551 return dtls_sendmsg_merge_and_send(ctx, msg, flags);
2552 }
2553
2554 /*
2555 * Current mbedTLS API (i.e. mbedtls_ssl_write()) allows only to send a single
2556 * contiguous buffer. This means that gather write using sendmsg() can only be
2557 * handled correctly if there is a single non-empty buffer in msg->msg_iov.
2558 */
2559 if (msghdr_non_empty_iov_count(msg) > 1) {
2560 errno = EMSGSIZE;
2561 return -1;
2562 }
2563 }
2564
2565 send_loop:
2566 return tls_sendmsg_loop_and_send(ctx, msg, flags);
2567 }
2568
/* Receive decrypted application data from a TLS (stream) session.
 *
 * Keeps calling mbedtls_ssl_read() until at least one byte has been stored
 * in buf or, when ZSOCK_MSG_WAITALL is set in flags, until max_len bytes
 * have been stored.
 *
 * Returns the number of bytes stored in buf. This is 0 if the session was
 * already marked closed, and may be fewer bytes than requested (including
 * 0) if the peer closes the session or mbedtls_ssl_read() returns 0.
 * Returns -1 with errno set on failure:
 *  - the error latched in ctx->error, if any,
 *  - EAGAIN if no progress is possible on a non-blocking call or the
 *    receive timeout elapsed,
 *  - the negated wait_for_reason() result if waiting failed,
 *  - EIO for any other mbed TLS read error.
 */
static ssize_t recv_tls(struct tls_context *ctx, void *buf,
			size_t max_len, int flags)
{
	size_t recv_len = 0;
	const bool waitall = flags & ZSOCK_MSG_WAITALL;
	const bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;
	int ret;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	/* A closed session reads as end-of-stream rather than an error. */
	if (ctx->session_closed) {
		return 0;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline shared by all retries below. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t read_len = max_len - recv_len;

		ret = mbedtls_ssl_read(&ctx->ssl, (uint8_t *)buf + recv_len,
				       read_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
				/* Peer notified that it's closing the
				 * connection.
				 */
				ctx->session_closed = true;
				break;
			}

			if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
				/* Client reconnect on the same socket is not
				 * supported. See mbedtls_ssl_read API
				 * documentation.
				 */
				ctx->session_closed = true;
				break;
			}

			/* Return codes that only mean "try again later". */
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET) {
				int timeout_ms;

				if (!is_block) {
					ret = -EAGAIN;
					goto err;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					ret = -EAGAIN;
					goto err;
				}

				timeout_ms = timeout_to_ms(&timeout);

				/* Block. The context lock is released for the
				 * duration of the wait and re-taken afterwards.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. Note that 'continue' jumps to
					 * the loop condition at the bottom.
					 */
					continue;
				}
				/* Waiting failed: fall through to err with the
				 * wait_for_reason() result in ret.
				 */
			} else {
				NET_ERR("TLS recv error: -%x", -ret);
				ret = -EIO;
			}

			/* Common failure exit; ret holds a negative errno. */
err:
			errno = -ret;
			return -1;
		}

		/* Zero-length read: stop and report what we have so far. */
		if (ret == 0) {
			break;
		}

		recv_len += ret;
	} while ((recv_len == 0) || (waitall && (recv_len < max_len)));

	return recv_len;
}
2668
2669 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Read one DTLS datagram; shared by the client and server receive paths.
 *
 * Stores up to max_len bytes in buf, optionally reports the peer address
 * through src_addr/addrlen (only when both are non-NULL), and discards
 * whatever part of the datagram did not fit into buf.
 *
 * Return value:
 *  - >= 0: number of bytes stored in buf; with ZSOCK_MSG_TRUNC set in flags
 *    the discarded byte count is added to it,
 *  - a negative mbed TLS error code (NOT an errno value) on failure - the
 *    caller is expected to translate it. A retryable code (WANT_READ etc.)
 *    is returned as-is for a non-blocking call or once the receive timeout
 *    elapsed; MBEDTLS_ERR_SSL_INTERNAL_ERROR is returned if waiting on the
 *    socket or flushing the rest of the datagram fails,
 *  - -1 with errno set to ctx->error if an error is latched on the context.
 */
static ssize_t recvfrom_dtls_common(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct sockaddr *src_addr,
				    socklen_t *addrlen)
{
	int ret;
	bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline shared by all retries below. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t remaining;

		ret = mbedtls_ssl_read(&ctx->ssl, buf, max_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
				int timeout_dtls, timeout_sock, timeout_ms;

				if (!is_block) {
					return ret;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					return ret;
				}

				/* Wait for the shorter of the DTLS timer and
				 * the socket timeout. If either one is
				 * SYS_FOREVER_MS, MAX() is used instead of
				 * MIN() to pick the wait time.
				 */
				timeout_dtls = dtls_get_remaining_timeout(ctx);
				timeout_sock = timeout_to_ms(&timeout);
				if (timeout_dtls == SYS_FOREVER_MS ||
				    timeout_sock == SYS_FOREVER_MS) {
					timeout_ms = MAX(timeout_dtls, timeout_sock);
				} else {
					timeout_ms = MIN(timeout_dtls, timeout_sock);
				}

				/* Block. The context lock is released for the
				 * duration of the wait and re-taken afterwards.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				} else {
					return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				}
			} else {
				/* Any other mbed TLS error goes to the caller
				 * untranslated.
				 */
				return ret;
			}
		}

		if (src_addr && addrlen) {
			dtls_peer_address_get(ctx, src_addr, addrlen);
		}

		/* mbedtls_ssl_get_bytes_avail() indicate the data length
		 * remaining in the current datagram.
		 */
		remaining = mbedtls_ssl_get_bytes_avail(&ctx->ssl);

		/* No more data in the datagram, or dummy read. */
		if ((remaining == 0) || (max_len == 0)) {
			return ret;
		}

		/* With MSG_TRUNC report the full datagram length, not just
		 * the part that fit into buf.
		 */
		if (flags & ZSOCK_MSG_TRUNC) {
			ret += remaining;
		}

		/* Drain the rest of the datagram one byte at a time so the
		 * next read starts at a fresh datagram.
		 */
		for (int i = 0; i < remaining; i++) {
			uint8_t byte;
			int err;

			err = mbedtls_ssl_read(&ctx->ssl, &byte, sizeof(byte));
			if (err <= 0) {
				NET_ERR("Error while flushing the rest of the"
					" datagram, err %d", err);
				ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				break;
			}
		}

		break;
	} while (true);


	return ret;
}
2776
/* DTLS client receive path.
 *
 * Requires a completed handshake, reads one datagram through
 * recvfrom_dtls_common() and translates mbed TLS error codes into errno
 * values.
 *
 * Returns the number of bytes received, or -1 with errno set:
 *  - ENOTCONN if the handshake is not complete or the peer sent a close
 *    notification (the session is reset),
 *  - ETIMEDOUT on a DTLS timeout,
 *  - EAGAIN if no data is available yet,
 *  - ECONNABORTED for any other mbed TLS error (the session is reset),
 *  - ENOMEM if resetting the session failed.
 * Except for ENOTCONN-before-handshake and EAGAIN, the errno value is also
 * latched in ctx->error.
 */
static ssize_t recvfrom_dtls_client(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct sockaddr *src_addr,
				    socklen_t *addrlen)
{
	int ret;

	if (!is_handshake_complete(ctx)) {
		ret = -ENOTCONN;
		goto error;
	}

	ret = recvfrom_dtls_common(ctx, buf, max_len, flags, src_addr, addrlen);
	if (ret >= 0) {
		return ret;
	}

	switch (ret) {
	case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
		/* Peer notified that it's closing the connection. */
		ret = tls_mbedtls_reset(ctx);
		if (ret == 0) {
			ctx->error = ENOTCONN;
			ret = -ENOTCONN;
		} else {
			ctx->error = ENOMEM;
			ret = -ENOMEM;
		}
		break;

	case MBEDTLS_ERR_SSL_TIMEOUT:
		(void)mbedtls_ssl_close_notify(&ctx->ssl);
		ctx->error = ETIMEDOUT;
		ret = -ETIMEDOUT;
		break;

	case MBEDTLS_ERR_SSL_WANT_READ:
	case MBEDTLS_ERR_SSL_WANT_WRITE:
	case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
	case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
		ret = -EAGAIN;
		break;

	default:
		NET_ERR("DTLS client recv error: -%x", -ret);

		/* MbedTLS API documentation requires session to
		 * be reset in other error cases
		 */
		ret = tls_mbedtls_reset(ctx);
		if (ret != 0) {
			/* Report ENOMEM through ret: errno is assigned from
			 * ret at the error label below, so setting errno
			 * directly here would be overwritten with the raw
			 * reset return value.
			 */
			ctx->error = ENOMEM;
			ret = -ENOMEM;
		} else {
			ctx->error = ECONNABORTED;
			ret = -ECONNABORTED;
		}

		break;
	}

error:
	errno = -ret;
	return -1;
}
2842
/* DTLS server receive path.
 *
 * Lazily initializes mbed TLS, performs the handshake when none is
 * complete, then reads one datagram through recvfrom_dtls_common(). When
 * the session ends (timeout, close notification, client reconnect) or a
 * handshake fails, the session is reset and the whole sequence is repeated
 * so that a new client can connect without the socket being closed.
 *
 * Returns the number of bytes received, or -1 with errno set:
 *  - EAGAIN if the handshake or the read would block,
 *  - ECONNABORTED for an unexpected mbed TLS read error (session reset),
 *  - ENOMEM if resetting the session failed,
 *  - the negated tls_mbedtls_init() result if initialization failed.
 */
static ssize_t recvfrom_dtls_server(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct sockaddr *src_addr,
				    socklen_t *addrlen)
{
	int ret;
	bool repeat;
	k_timeout_t timeout;

	if (!ctx->is_initialized) {
		ret = tls_mbedtls_init(ctx, true);
		if (ret < 0) {
			goto error;
		}
	}

	if (is_blocking(ctx->sock, flags)) {
		timeout = ctx->options.timeout_rx;
	} else {
		timeout = K_NO_WAIT;
	}

	/* Loop to enable DTLS reconnection for servers without closing
	 * a socket.
	 */
	do {
		repeat = false;

		if (!is_handshake_complete(ctx)) {
			ret = tls_mbedtls_handshake(ctx, timeout);
			if (ret < 0) {
				/* In case of EAGAIN, just exit. */
				if (ret == -EAGAIN) {
					break;
				}

				/* Failed handshake: reset and wait for the
				 * next one. 'continue' jumps to the loop
				 * condition, so the loop ends here when the
				 * reset fails.
				 */
				ret = tls_mbedtls_reset(ctx);
				if (ret == 0) {
					repeat = true;
				} else {
					ret = -ENOMEM;
				}

				continue;
			}

			/* Server socket ready to use again. */
			ctx->error = 0;
		}

		ret = recvfrom_dtls_common(ctx, buf, max_len, flags,
					   src_addr, addrlen);
		if (ret >= 0) {
			return ret;
		}

		switch (ret) {
		case MBEDTLS_ERR_SSL_TIMEOUT:
			(void)mbedtls_ssl_close_notify(&ctx->ssl);
			__fallthrough;
			/* fallthrough */

		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
		case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
			/* Session is over - reset it and serve the next
			 * client.
			 */
			ret = tls_mbedtls_reset(ctx);
			if (ret == 0) {
				repeat = true;
			} else {
				ctx->error = ENOMEM;
				ret = -ENOMEM;
			}
			break;

		case MBEDTLS_ERR_SSL_WANT_READ:
		case MBEDTLS_ERR_SSL_WANT_WRITE:
		case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
		case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
			ret = -EAGAIN;
			break;

		default:
			NET_ERR("DTLS server recv error: -%x", -ret);

			ret = tls_mbedtls_reset(ctx);
			if (ret != 0) {
				/* Report ENOMEM through ret: errno is
				 * assigned from ret at the error label
				 * below, so setting errno directly here
				 * would be overwritten with the raw reset
				 * return value.
				 */
				ctx->error = ENOMEM;
				ret = -ENOMEM;
			} else {
				ctx->error = ECONNABORTED;
				ret = -ECONNABORTED;
			}

			break;
		}
	} while (repeat);

error:
	errno = -ret;
	return -1;
}
2943 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2944
/* Receive entry point for a TLS context.
 *
 * Rejects ZSOCK_MSG_PEEK with ENOTSUP, records the call flags in the
 * context and dispatches to the stream (TLS) receiver or, for non-stream
 * sockets, to the DTLS server/client receiver selected by the configured
 * role. Without DTLS support non-stream sockets fail with ENOTSUP.
 */
ssize_t ztls_recvfrom_ctx(struct tls_context *ctx, void *buf, size_t max_len,
			  int flags, struct sockaddr *src_addr,
			  socklen_t *addrlen)
{
	if ((flags & ZSOCK_MSG_PEEK) != 0) {
		/* TODO mbedTLS does not support 'peeking' This could be
		 * bypassed by having intermediate buffer for peeking
		 */
		errno = ENOTSUP;
		return -1;
	}

	ctx->flags = flags;

	if (ctx->type != SOCK_STREAM) {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/* DTLS */
		if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
			return recvfrom_dtls_server(ctx, buf, max_len, flags,
						    src_addr, addrlen);
		}

		return recvfrom_dtls_client(ctx, buf, max_len, flags,
					    src_addr, addrlen);
#else
		errno = ENOTSUP;
		return -1;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	/* TLS */
	return recv_tls(ctx, buf, max_len, flags);
}
2978
ztls_poll_prepare_pollin(struct tls_context * ctx)2979 static int ztls_poll_prepare_pollin(struct tls_context *ctx)
2980 {
2981 /* If there already is mbedTLS data to read, there is no
2982 * need to set the k_poll_event object. Return EALREADY
2983 * so we won't block in the k_poll.
2984 */
2985 if (!ctx->is_listening) {
2986 if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
2987 return -EALREADY;
2988 }
2989 }
2990
2991 return 0;
2992 }
2993
ztls_poll_prepare_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)2994 static int ztls_poll_prepare_ctx(struct tls_context *ctx,
2995 struct zsock_pollfd *pfd,
2996 struct k_poll_event **pev,
2997 struct k_poll_event *pev_end)
2998 {
2999 const struct fd_op_vtable *vtable;
3000 struct k_mutex *lock;
3001 void *obj;
3002 int ret;
3003 short events = pfd->events;
3004
3005 /* DTLS client should wait for the handshake to complete before
3006 * it actually starts to poll for data.
3007 */
3008 if ((pfd->events & ZSOCK_POLLIN) && (ctx->type == SOCK_DGRAM) &&
3009 (ctx->options.role == MBEDTLS_SSL_IS_CLIENT) &&
3010 !is_handshake_complete(ctx)) {
3011 (*pev)->obj = &ctx->tls_established;
3012 (*pev)->type = K_POLL_TYPE_SEM_AVAILABLE;
3013 (*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
3014 (*pev)->state = K_POLL_STATE_NOT_READY;
3015 (*pev)++;
3016
3017 /* Since k_poll_event is configured by the TLS layer in this
3018 * case, do not forward ZSOCK_POLLIN to the underlying socket.
3019 */
3020 pfd->events &= ~ZSOCK_POLLIN;
3021 }
3022
3023 obj = zvfs_get_fd_obj_and_vtable(
3024 ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
3025 if (obj == NULL) {
3026 ret = -EBADF;
3027 goto exit;
3028 }
3029
3030 (void)k_mutex_lock(lock, K_FOREVER);
3031
3032 ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
3033 pfd, pev, pev_end);
3034 if (ret != 0) {
3035 goto exit;
3036 }
3037
3038 if (pfd->events & ZSOCK_POLLIN) {
3039 ret = ztls_poll_prepare_pollin(ctx);
3040 }
3041
3042 exit:
3043 /* Restore original events. */
3044 pfd->events = events;
3045
3046 k_mutex_unlock(lock);
3047
3048 return ret;
3049 }
3050
3051 #include <zephyr/net/net_core.h>
3052
ztls_socket_data_check(struct tls_context * ctx)3053 static int ztls_socket_data_check(struct tls_context *ctx)
3054 {
3055 int ret;
3056
3057 if (ctx->type == SOCK_STREAM) {
3058 if (!ctx->is_initialized) {
3059 return -ENOTCONN;
3060 }
3061 }
3062 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3063 else {
3064 if (!ctx->is_initialized) {
3065 bool is_server = ctx->options.role == MBEDTLS_SSL_IS_SERVER;
3066
3067 ret = tls_mbedtls_init(ctx, is_server);
3068 if (ret < 0) {
3069 return -ENOMEM;
3070 }
3071 }
3072
3073 if (!is_handshake_complete(ctx)) {
3074 ret = tls_mbedtls_handshake(ctx, K_NO_WAIT);
3075 if (ret < 0) {
3076 if (ret == -EAGAIN) {
3077 return 0;
3078 }
3079
3080 ret = tls_mbedtls_reset(ctx);
3081 if (ret != 0) {
3082 return -ENOMEM;
3083 }
3084
3085 return 0;
3086 }
3087
3088 /* Socket ready to use again. */
3089 ctx->error = 0;
3090
3091 return 0;
3092 }
3093 }
3094 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3095
3096 ctx->flags = ZSOCK_MSG_DONTWAIT;
3097
3098 ret = mbedtls_ssl_read(&ctx->ssl, NULL, 0);
3099 if (ret < 0) {
3100 if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
3101 /* Don't reset the context for STREAM socket - the
3102 * application needs to reopen the socket anyway, and
3103 * resetting the context would result in an error instead
3104 * of 0 in a consecutive recv() call.
3105 */
3106 if (ctx->type == SOCK_DGRAM) {
3107 ret = tls_mbedtls_reset(ctx);
3108 if (ret != 0) {
3109 return -ENOMEM;
3110 }
3111 } else {
3112 ctx->session_closed = true;
3113 }
3114
3115 return -ENOTCONN;
3116 }
3117
3118 if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
3119 ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
3120 return 0;
3121 }
3122
3123 NET_ERR("TLS data check error: -%x", -ret);
3124
3125 /* MbedTLS API documentation requires session to
3126 * be reset in other error cases
3127 */
3128 if (tls_mbedtls_reset(ctx) != 0) {
3129 return -ENOMEM;
3130 }
3131
3132 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3133 if (ret == MBEDTLS_ERR_SSL_TIMEOUT && ctx->type == SOCK_DGRAM) {
3134 /* DTLS timeout interpreted as closing of connection. */
3135 return -ENOTCONN;
3136 }
3137 #endif
3138 return -ECONNABORTED;
3139 }
3140
3141 return mbedtls_ssl_get_bytes_avail(&ctx->ssl);
3142 }
3143
ztls_poll_update_pollin(int fd,struct tls_context * ctx,struct zsock_pollfd * pfd)3144 static int ztls_poll_update_pollin(int fd, struct tls_context *ctx,
3145 struct zsock_pollfd *pfd)
3146 {
3147 int ret;
3148
3149 if (!ctx->is_listening) {
3150 /* Already had TLS data to read on socket. */
3151 if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
3152 pfd->revents |= ZSOCK_POLLIN;
3153 goto next;
3154 }
3155 }
3156
3157 if (ctx->type == SOCK_STREAM) {
3158 if (!(pfd->revents & ZSOCK_POLLIN)) {
3159 /* No new data on a socket. */
3160 goto next;
3161 }
3162
3163 if (ctx->is_listening) {
3164 goto next;
3165 }
3166 }
3167 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3168 else {
3169 /* Perform data check without incoming data for completed DTLS connections.
3170 * This allows the connections to timeout with CONFIG_NET_SOCKETS_DTLS_TIMEOUT.
3171 */
3172 if (!is_handshake_complete(ctx) && !(pfd->revents & ZSOCK_POLLIN)) {
3173 goto next;
3174 }
3175 }
3176 #endif
3177 ret = ztls_socket_data_check(ctx);
3178 if (ret == -ENOTCONN || (pfd->revents & ZSOCK_POLLHUP)) {
3179 /* Datagram does not return 0 on consecutive recv, but an error
3180 * code, hence clear POLLIN.
3181 */
3182 if (ctx->type == SOCK_DGRAM) {
3183 pfd->revents &= ~ZSOCK_POLLIN;
3184 }
3185 pfd->revents |= ZSOCK_POLLHUP;
3186 goto next;
3187 } else if (ret < 0) {
3188 ctx->error = -ret;
3189 pfd->revents |= ZSOCK_POLLERR;
3190 goto next;
3191 } else if (ret == 0) {
3192 goto again;
3193 }
3194
3195 next:
3196 return 0;
3197
3198 again:
3199 /* Received encrypted data, but still not enough
3200 * to decrypt it and return data through socket,
3201 * ask for retry if no other events are set.
3202 */
3203 pfd->revents &= ~ZSOCK_POLLIN;
3204
3205 return -EAGAIN;
3206 }
3207
/* Update poll results for a TLS context after k_poll() returned.
 *
 * Handles the DTLS-client case where the context was polling on the
 * tls_established semaphore, forwards ZFD_IOCTL_POLL_UPDATE to the
 * underlying socket, and refines ZSOCK_POLLIN through
 * ztls_poll_update_pollin().
 *
 * Returns 0 when pfd->revents is final, -EAGAIN when poll() should make
 * another iteration, -EBADF if the underlying socket cannot be looked up,
 * or the non-zero result of an underlying socket ioctl.
 */
static int ztls_poll_update_ctx(struct tls_context *ctx,
				struct zsock_pollfd *pfd,
				struct k_poll_event **pev)
{
	const struct fd_op_vtable *vtable;
	struct k_mutex *lock;
	void *obj;
	int ret;
	short events = pfd->events;

	obj = zvfs_get_fd_obj_and_vtable(
		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
	if (obj == NULL) {
		return -EBADF;
	}

	(void)k_mutex_lock(lock, K_FOREVER);

	/* Check if the socket was waiting for the handshake to complete. */
	if ((pfd->events & ZSOCK_POLLIN) &&
	    ((*pev)->obj == &ctx->tls_established)) {
		/* In case handshake is complete, reconfigure the k_poll_event
		 * to monitor the underlying socket now.
		 */
		if ((*pev)->state != K_POLL_STATE_NOT_READY) {
			/* Only this single event slot may be reconfigured,
			 * hence the "*pev + 1" end marker.
			 */
			ret = zvfs_fdtable_call_ioctl(vtable, obj,
						      ZFD_IOCTL_POLL_PREPARE,
						      pfd, pev, *pev + 1);
			if (ret != 0 && ret != -EALREADY) {
				goto out;
			}

			/* Return -EAGAIN to signal to poll() that it should
			 * make another iteration with the event reconfigured
			 * above (if needed).
			 */
			ret = -EAGAIN;
			/* pfd->events has not been modified on this path, so
			 * the restore step at 'exit' is skipped.
			 */
			goto out;
		}

		/* Handshake still not ready - skip ZSOCK_POLLIN verification
		 * for the underlying socket.
		 */
		(*pev)++;
		pfd->events &= ~ZSOCK_POLLIN;
	}

	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_UPDATE,
				      pfd, pev);
	if (ret != 0) {
		goto exit;
	}

	if (pfd->events & ZSOCK_POLLIN) {
		ret = ztls_poll_update_pollin(pfd->fd, ctx, pfd);
		if (ret == -EAGAIN && pfd->revents == 0) {
			/* No decrypted data yet and nothing else to report:
			 * mark the preceding event entry as not ready again
			 * (presumably the one the underlying socket's update
			 * just advanced past - confirm against the socket
			 * implementation) and propagate -EAGAIN.
			 */
			(*pev - 1)->state = K_POLL_STATE_NOT_READY;
			goto exit;
		} else {
			/* Other events are pending, or the data check is
			 * conclusive - report success.
			 */
			ret = 0;
		}
	}
exit:
	/* Restore original events. */
	pfd->events = events;

out:
	k_mutex_unlock(lock);

	return ret;
}
3279
3280 /* Return true if needed to retry rightoff or false otherwise. */
poll_offload_dtls_client_retry(struct tls_context * ctx,struct zsock_pollfd * pfd)3281 static bool poll_offload_dtls_client_retry(struct tls_context *ctx,
3282 struct zsock_pollfd *pfd)
3283 {
3284 /* DTLS client should wait for the handshake to complete before it
3285 * reports that data is ready.
3286 */
3287 if ((ctx->type != SOCK_DGRAM) ||
3288 (ctx->options.role != MBEDTLS_SSL_IS_CLIENT)) {
3289 return false;
3290 }
3291
3292 if (ctx->handshake_in_progress) {
3293 /* Add some sleep to allow lower priority threads to proceed
3294 * with handshake.
3295 */
3296 k_msleep(10);
3297
3298 pfd->revents &= ~ZSOCK_POLLIN;
3299 return true;
3300 } else if (!is_handshake_complete(ctx)) {
3301 uint8_t byte;
3302 int ret;
3303
3304 /* Handshake didn't start yet - just drop the incoming data -
3305 * it's the client who should initiate the handshake.
3306 */
3307 ret = zsock_recv(ctx->sock, &byte, sizeof(byte),
3308 ZSOCK_MSG_DONTWAIT);
3309 if (ret < 0) {
3310 pfd->revents |= ZSOCK_POLLERR;
3311 }
3312
3313 pfd->revents &= ~ZSOCK_POLLIN;
3314 return true;
3315 }
3316
3317 /* Handshake complete, just proceed. */
3318 return false;
3319 }
3320
ztls_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)3321 static int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int timeout)
3322 {
3323 int fd_backup[CONFIG_ZVFS_POLL_MAX];
3324 const struct fd_op_vtable *vtable;
3325 void *ctx;
3326 int ret = 0;
3327 int result;
3328 int i;
3329 bool retry;
3330 int remaining;
3331 uint32_t entry = k_uptime_get_32();
3332
3333 /* Overwrite TLS file descriptors with underlying ones. */
3334 for (i = 0; i < nfds; i++) {
3335 fd_backup[i] = fds[i].fd;
3336
3337 ctx = zvfs_get_fd_obj(fds[i].fd,
3338 (const struct fd_op_vtable *)
3339 &tls_sock_fd_op_vtable,
3340 0);
3341 if (ctx == NULL) {
3342 continue;
3343 }
3344
3345 if (fds[i].events & ZSOCK_POLLIN) {
3346 ret = ztls_poll_prepare_pollin(ctx);
3347 /* In case data is already available in mbedtls,
3348 * do not wait in poll.
3349 */
3350 if (ret == -EALREADY) {
3351 timeout = 0;
3352 }
3353 }
3354
3355 fds[i].fd = ((struct tls_context *)ctx)->sock;
3356 }
3357
3358 /* Get offloaded sockets vtable. */
3359 ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
3360 (const struct fd_op_vtable **)&vtable,
3361 NULL);
3362 if (ctx == NULL) {
3363 errno = EINVAL;
3364 goto exit;
3365 }
3366
3367 remaining = timeout;
3368
3369 do {
3370 for (i = 0; i < nfds; i++) {
3371 fds[i].revents = 0;
3372 }
3373
3374 ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
3375 fds, nfds, remaining);
3376 if (ret < 0) {
3377 goto exit;
3378 }
3379
3380 retry = false;
3381 ret = 0;
3382
3383 for (i = 0; i < nfds; i++) {
3384 ctx = zvfs_get_fd_obj(fd_backup[i],
3385 (const struct fd_op_vtable *)
3386 &tls_sock_fd_op_vtable,
3387 0);
3388 if (ctx != NULL) {
3389 if (fds[i].events & ZSOCK_POLLIN) {
3390 if (poll_offload_dtls_client_retry(
3391 ctx, &fds[i])) {
3392 retry = true;
3393 continue;
3394 }
3395
3396 result = ztls_poll_update_pollin(
3397 fd_backup[i], ctx, &fds[i]);
3398 if (result == -EAGAIN) {
3399 retry = true;
3400 }
3401 }
3402 }
3403
3404 if (fds[i].revents != 0) {
3405 ret++;
3406 }
3407 }
3408
3409 if (retry) {
3410 if (ret > 0 || timeout == 0) {
3411 goto exit;
3412 }
3413
3414 if (timeout > 0) {
3415 remaining = time_left(entry, timeout);
3416 if (remaining <= 0) {
3417 goto exit;
3418 }
3419 }
3420 }
3421 } while (retry);
3422
3423 exit:
3424 /* Restore original fds. */
3425 for (i = 0; i < nfds; i++) {
3426 fds[i].fd = fd_backup[i];
3427 }
3428
3429 return ret;
3430 }
3431
/* getsockopt() implementation for a TLS context.
 *
 * SO_PROTOCOL and a TLS-level SO_ERROR are answered here, every other
 * non-SOL_TLS option is forwarded to the underlying socket, and SOL_TLS
 * options are dispatched to the matching tls_opt_*_get() helper.
 *
 * Returns 0 on success (or the helper/underlying socket result where it is
 * passed through), -1 with errno set on failure: EINVAL for NULL arguments
 * or a wrong SO_ERROR length, ENOPROTOOPT for an unknown or write-only TLS
 * option, otherwise the error reported by the helper.
 */
int ztls_getsockopt_ctx(struct tls_context *ctx, int level, int optname,
			void *optval, socklen_t *optlen)
{
	int rc;

	if (optval == NULL || optlen == NULL) {
		errno = EINVAL;
		return -1;
	}

	if (level == SOL_SOCKET) {
		if (optname == SO_PROTOCOL) {
			/* Protocol type is overridden during socket
			 * creation. Its value is restored here to return
			 * current value.
			 */
			rc = sock_opt_protocol_get(ctx, optval, optlen);
			if (rc < 0) {
				errno = -rc;
				return -1;
			}

			return rc;
		}

		/* In case error was set on a socket at the TLS layer (for
		 * example due to receiving TLS alert), handle SO_ERROR
		 * here, and report that error. Otherwise, forward the
		 * SO_ERROR option request to the underlying TCP/UDP socket
		 * to handle.
		 */
		if (optname == SO_ERROR && ctx->error != 0) {
			if (*optlen != sizeof(int)) {
				errno = EINVAL;
				return -1;
			}

			*(int *)optval = ctx->error;

			return 0;
		}
	}

	if (level != SOL_TLS) {
		return zsock_getsockopt(ctx->sock, level, optname,
					optval, optlen);
	}

	switch (optname) {
	case TLS_SEC_TAG_LIST:
		rc = tls_opt_sec_tag_list_get(ctx, optval, optlen);
		break;

	case TLS_CIPHERSUITE_LIST:
		rc = tls_opt_ciphersuite_list_get(ctx, optval, optlen);
		break;

	case TLS_CIPHERSUITE_USED:
		rc = tls_opt_ciphersuite_used_get(ctx, optval, optlen);
		break;

	case TLS_ALPN_LIST:
		rc = tls_opt_alpn_list_get(ctx, optval, optlen);
		break;

	case TLS_SESSION_CACHE:
		rc = tls_opt_session_cache_get(ctx, optval, optlen);
		break;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
		rc = tls_opt_dtls_handshake_timeout_get(ctx, optval,
							optlen, false);
		break;

	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
		rc = tls_opt_dtls_handshake_timeout_get(ctx, optval,
							optlen, true);
		break;

	case TLS_DTLS_CID_STATUS:
		rc = tls_opt_dtls_connection_id_status_get(ctx, optval,
							   optlen);
		break;

	case TLS_DTLS_CID_VALUE:
		rc = tls_opt_dtls_connection_id_value_get(ctx, optval, optlen);
		break;

	case TLS_DTLS_PEER_CID_VALUE:
		rc = tls_opt_dtls_peer_connection_id_value_get(ctx, optval,
							       optlen);
		break;

	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
		rc = tls_opt_dtls_handshake_on_connect_get(ctx, optval, optlen);
		break;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	default:
		/* Unknown or write-only option. */
		rc = -ENOPROTOOPT;
		break;
	}

	if (rc >= 0) {
		return 0;
	}

	errno = -rc;
	return -1;
}
3539
/* Convert a socket timeout option value into a kernel timeout.
 *
 * optval is interpreted as a struct zsock_timeval (this helper is used for
 * SO_SNDTIMEO and SO_RCVTIMEO). A timeval with both fields set to zero
 * selects K_FOREVER; any other value is converted to microseconds and
 * wrapped with K_USEC().
 *
 * *timeout is only written on success.
 *
 * Returns 0 on success, or -EINVAL if optval is NULL or optlen does not
 * match sizeof(struct zsock_timeval).
 */
static int set_timeout_opt(k_timeout_t *timeout, const void *optval,
			   socklen_t optlen)
{
	const struct zsock_timeval *tval = optval;

	/* Reject a missing buffer up front, the fields are dereferenced below. */
	if (tval == NULL) {
		return -EINVAL;
	}

	if (optlen != sizeof(struct zsock_timeval)) {
		return -EINVAL;
	}

	if (tval->tv_sec == 0 && tval->tv_usec == 0) {
		*timeout = K_FOREVER;
	} else {
		*timeout = K_USEC(tval->tv_sec * 1000000ULL + tval->tv_usec);
	}

	return 0;
}
3557
/* Route a SOL_TLS level option to its dedicated setter.
 *
 * Returns whatever the selected setter returns, 0 for TLS_NATIVE, or
 * -ENOPROTOOPT for option names that have no case below.
 */
static int tls_level_setsockopt(struct tls_context *ctx, int optname,
				const void *optval, socklen_t optlen)
{
	switch (optname) {
	case TLS_SEC_TAG_LIST:
		return tls_opt_sec_tag_list_set(ctx, optval, optlen);

	case TLS_HOSTNAME:
		return tls_opt_hostname_set(ctx, optval, optlen);

	case TLS_CIPHERSUITE_LIST:
		return tls_opt_ciphersuite_list_set(ctx, optval, optlen);

	case TLS_PEER_VERIFY:
		return tls_opt_peer_verify_set(ctx, optval, optlen);

	case TLS_CERT_NOCOPY:
		return tls_opt_cert_nocopy_set(ctx, optval, optlen);

	case TLS_DTLS_ROLE:
		return tls_opt_dtls_role_set(ctx, optval, optlen);

	case TLS_ALPN_LIST:
		return tls_opt_alpn_list_set(ctx, optval, optlen);

	case TLS_SESSION_CACHE:
		return tls_opt_session_cache_set(ctx, optval, optlen);

	case TLS_SESSION_CACHE_PURGE:
		return tls_opt_session_cache_purge_set(ctx, optval, optlen);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
		return tls_opt_dtls_handshake_timeout_set(ctx, optval, optlen,
							  false);

	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
		return tls_opt_dtls_handshake_timeout_set(ctx, optval, optlen,
							  true);

	case TLS_DTLS_CID:
		return tls_opt_dtls_connection_id_set(ctx, optval, optlen);

	case TLS_DTLS_CID_VALUE:
		return tls_opt_dtls_connection_id_value_set(ctx, optval, optlen);

	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
		return tls_opt_dtls_handshake_on_connect_set(ctx, optval, optlen);
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	case TLS_NATIVE:
		/* Option handled at the socket dispatcher level. */
		return 0;

	default:
		/* Unknown or read-only option. */
		return -ENOPROTOOPT;
	}
}

/* Set a socket option on a TLS socket context.
 *
 * SO_SNDTIMEO and SO_RCVTIMEO are stored in the TLS context, because the
 * underlying socket is used in non-blocking mode and the timeout therefore
 * has to be implemented at the TLS socket level. Any other option whose
 * level is not SOL_TLS is passed to zsock_setsockopt() on the underlying
 * socket, and its result is returned unchanged. SOL_TLS options are handled
 * by the individual setters.
 *
 * For the options handled here, returns 0 on success, or -1 with errno set
 * to the negated error code reported by the setter.
 */
int ztls_setsockopt_ctx(struct tls_context *ctx, int level, int optname,
			const void *optval, socklen_t optlen)
{
	int result;

	if ((level == SOL_SOCKET) && (optname == SO_SNDTIMEO)) {
		result = set_timeout_opt(&ctx->options.timeout_tx, optval,
					 optlen);
	} else if ((level == SOL_SOCKET) && (optname == SO_RCVTIMEO)) {
		result = set_timeout_opt(&ctx->options.timeout_rx, optval,
					 optlen);
	} else if (level != SOL_TLS) {
		return zsock_setsockopt(ctx->sock, level, optname,
					optval, optlen);
	} else {
		result = tls_level_setsockopt(ctx, optname, optval, optlen);
	}

	if (result >= 0) {
		return 0;
	}

	errno = -result;

	return -1;
}
3662
#if defined(CONFIG_NET_TEST)
/* Accessor built only with CONFIG_NET_TEST: look up the TLS context behind
 * a file descriptor and return a pointer to its mbedtls_ssl_context.
 *
 * The descriptor is resolved with zvfs_get_fd_obj() against the TLS socket
 * vtable, passing EBADF as the error code argument. NULL is returned when
 * that lookup yields no object.
 */
mbedtls_ssl_context *ztls_get_mbedtls_ssl_context(int fd)
{
	const struct fd_op_vtable *tls_vtable =
		(const struct fd_op_vtable *)&tls_sock_fd_op_vtable;
	struct tls_context *tls = zvfs_get_fd_obj(fd, tls_vtable, EBADF);

	return (tls != NULL) ? &tls->ssl : NULL;
}
#endif /* CONFIG_NET_TEST */
3677
/* read() entry of the TLS socket vtable: a receive with no flags and
 * without a source address buffer.
 */
static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count)
{
	struct tls_context *tls = obj;

	return ztls_recvfrom_ctx(tls, buffer, count, 0, NULL, NULL);
}
3682
/* write() entry of the TLS socket vtable: a send with no flags and no
 * destination address.
 */
static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer,
				    size_t count)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buffer, count, 0, NULL, 0);
}
3688
tls_sock_ioctl_vmeth(void * obj,unsigned int request,va_list args)3689 static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
3690 {
3691 struct tls_context *ctx = obj;
3692
3693 switch (request) {
3694 /* fcntl() commands */
3695 case F_GETFL:
3696 case F_SETFL: {
3697 const struct fd_op_vtable *vtable;
3698 struct k_mutex *lock;
3699 void *fd_obj;
3700 int ret;
3701
3702 fd_obj = zvfs_get_fd_obj_and_vtable(ctx->sock,
3703 (const struct fd_op_vtable **)&vtable, &lock);
3704 if (fd_obj == NULL) {
3705 errno = EBADF;
3706 return -1;
3707 }
3708
3709 (void)k_mutex_lock(lock, K_FOREVER);
3710
3711 /* Pass the call to the core socket implementation. */
3712 ret = vtable->ioctl(fd_obj, request, args);
3713
3714 k_mutex_unlock(lock);
3715
3716 return ret;
3717 }
3718
3719 case ZFD_IOCTL_SET_LOCK: {
3720 struct k_mutex *lock;
3721
3722 lock = va_arg(args, struct k_mutex *);
3723
3724 ctx_set_lock(obj, lock);
3725
3726 return 0;
3727 }
3728
3729 case ZFD_IOCTL_POLL_PREPARE: {
3730 struct zsock_pollfd *pfd;
3731 struct k_poll_event **pev;
3732 struct k_poll_event *pev_end;
3733
3734 pfd = va_arg(args, struct zsock_pollfd *);
3735 pev = va_arg(args, struct k_poll_event **);
3736 pev_end = va_arg(args, struct k_poll_event *);
3737
3738 return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end);
3739 }
3740
3741 case ZFD_IOCTL_POLL_UPDATE: {
3742 struct zsock_pollfd *pfd;
3743 struct k_poll_event **pev;
3744
3745 pfd = va_arg(args, struct zsock_pollfd *);
3746 pev = va_arg(args, struct k_poll_event **);
3747
3748 return ztls_poll_update_ctx(obj, pfd, pev);
3749 }
3750
3751 case ZFD_IOCTL_POLL_OFFLOAD: {
3752 struct zsock_pollfd *fds;
3753 int nfds;
3754 int timeout;
3755
3756 fds = va_arg(args, struct zsock_pollfd *);
3757 nfds = va_arg(args, int);
3758 timeout = va_arg(args, int);
3759
3760 return ztls_poll_offload(fds, nfds, timeout);
3761 }
3762
3763 default:
3764 errno = EOPNOTSUPP;
3765 return -1;
3766 }
3767 }
3768
tls_sock_shutdown_vmeth(void * obj,int how)3769 static int tls_sock_shutdown_vmeth(void *obj, int how)
3770 {
3771 struct tls_context *ctx = obj;
3772
3773 return zsock_shutdown(ctx->sock, how);
3774 }
3775
/* bind() entry of the TLS socket vtable: binds the socket stored in the
 * TLS context.
 */
static int tls_sock_bind_vmeth(void *obj, const struct sockaddr *addr,
			       socklen_t addrlen)
{
	const struct tls_context *tls = obj;

	return zsock_bind(tls->sock, addr, addrlen);
}
3783
/* connect() entry of the TLS socket vtable; delegates to
 * ztls_connect_ctx().
 */
static int tls_sock_connect_vmeth(void *obj, const struct sockaddr *addr,
				  socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_connect_ctx(tls, addr, addrlen);
}
3789
tls_sock_listen_vmeth(void * obj,int backlog)3790 static int tls_sock_listen_vmeth(void *obj, int backlog)
3791 {
3792 struct tls_context *ctx = obj;
3793
3794 ctx->is_listening = true;
3795
3796 return zsock_listen(ctx->sock, backlog);
3797 }
3798
/* accept() entry of the TLS socket vtable; delegates to ztls_accept_ctx(). */
static int tls_sock_accept_vmeth(void *obj, struct sockaddr *addr,
				 socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_accept_ctx(tls, addr, addrlen);
}
3804
/* sendto() entry of the TLS socket vtable; all arguments are handed to
 * ztls_sendto_ctx() unchanged.
 */
static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len,
				     int flags,
				     const struct sockaddr *dest_addr,
				     socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buf, len, flags, dest_addr, addrlen);
}
3812
/* sendmsg() entry of the TLS socket vtable; delegates to
 * ztls_sendmsg_ctx().
 */
static ssize_t tls_sock_sendmsg_vmeth(void *obj, const struct msghdr *msg,
				      int flags)
{
	struct tls_context *tls = obj;

	return ztls_sendmsg_ctx(tls, msg, flags);
}
3818
/* recvfrom() entry of the TLS socket vtable; all arguments are handed to
 * ztls_recvfrom_ctx() unchanged.
 */
static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len,
				       int flags, struct sockaddr *src_addr,
				       socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_recvfrom_ctx(tls, buf, max_len, flags, src_addr, addrlen);
}
3826
/* getsockopt() entry of the TLS socket vtable; delegates to
 * ztls_getsockopt_ctx().
 */
static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname,
				     void *optval, socklen_t *optlen)
{
	struct tls_context *tls = obj;

	return ztls_getsockopt_ctx(tls, level, optname, optval, optlen);
}
3832
/* setsockopt() entry of the TLS socket vtable; delegates to
 * ztls_setsockopt_ctx().
 */
static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname,
				     const void *optval, socklen_t optlen)
{
	struct tls_context *tls = obj;

	return ztls_setsockopt_ctx(tls, level, optname, optval, optlen);
}
3838
/* close2() entry of the TLS socket vtable; hands the context and the
 * descriptor number to ztls_close_ctx().
 */
static int tls_sock_close2_vmeth(void *obj, int sock)
{
	struct tls_context *tls = obj;

	return ztls_close_ctx(tls, sock);
}
3843
/* getpeername() entry of the TLS socket vtable: queried on the socket
 * stored in the TLS context.
 */
static int tls_sock_getpeername_vmeth(void *obj, struct sockaddr *addr,
				      socklen_t *addrlen)
{
	const struct tls_context *tls = obj;

	return zsock_getpeername(tls->sock, addr, addrlen);
}
3851
/* getsockname() entry of the TLS socket vtable: queried on the socket
 * stored in the TLS context.
 */
static int tls_sock_getsockname_vmeth(void *obj, struct sockaddr *addr,
				      socklen_t *addrlen)
{
	const struct tls_context *tls = obj;

	return zsock_getsockname(tls->sock, addr, addrlen);
}
3859
3860 static const struct socket_op_vtable tls_sock_fd_op_vtable = {
3861 .fd_vtable = {
3862 .read = tls_sock_read_vmeth,
3863 .write = tls_sock_write_vmeth,
3864 .close2 = tls_sock_close2_vmeth,
3865 .ioctl = tls_sock_ioctl_vmeth,
3866 },
3867 .shutdown = tls_sock_shutdown_vmeth,
3868 .bind = tls_sock_bind_vmeth,
3869 .connect = tls_sock_connect_vmeth,
3870 .listen = tls_sock_listen_vmeth,
3871 .accept = tls_sock_accept_vmeth,
3872 .sendto = tls_sock_sendto_vmeth,
3873 .sendmsg = tls_sock_sendmsg_vmeth,
3874 .recvfrom = tls_sock_recvfrom_vmeth,
3875 .getsockopt = tls_sock_getsockopt_vmeth,
3876 .setsockopt = tls_sock_setsockopt_vmeth,
3877 .getpeername = tls_sock_getpeername_vmeth,
3878 .getsockname = tls_sock_getsockname_vmeth,
3879 };
3880
/* Tell whether the family/type/proto combination is accepted by
 * protocol_check(), i.e. whether that check returns 0.
 */
static bool tls_is_supported(int family, int type, int proto)
{
	return protocol_check(family, type, &proto) == 0;
}
3889
/* TLS sockets and regular sockets fall under the same address families, so
 * TLS has to be processed first in order to capture the socket calls which
 * create sockets for secure protocols. Every other call for AF_INET/AF_INET6
 * will be forwarded to the regular socket implementation.
 *
 * The build-time check below enforces that ordering by requiring the TLS
 * priority value to be lower than the default socket priority value.
 */
BUILD_ASSERT(CONFIG_NET_SOCKETS_TLS_PRIORITY < CONFIG_NET_SOCKETS_PRIORITY_DEFAULT,
	     "CONFIG_NET_SOCKETS_TLS_PRIORITY have to be smaller than CONFIG_NET_SOCKETS_PRIORITY_DEFAULT");

/* Register the TLS socket implementation under AF_UNSPEC, with
 * tls_is_supported() and ztls_socket() as its two callbacks.
 */
NET_SOCKET_REGISTER(tls, CONFIG_NET_SOCKETS_TLS_PRIORITY, AF_UNSPEC,
		    tls_is_supported, ztls_socket);
3900