1 /*
2  * Copyright (c) 2018 Intel Corporation
3  * Copyright (c) 2018 Nordic Semiconductor ASA
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <stdbool.h>
9 
10 #include <zephyr/logging/log.h>
11 LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL);
12 
13 #include <zephyr/init.h>
14 #include <zephyr/sys/util.h>
15 #include <zephyr/net/socket.h>
16 #include <zephyr/random/random.h>
17 #include <zephyr/internal/syscall_handler.h>
18 #include <zephyr/sys/fdtable.h>
19 
20 /* TODO: Remove all direct access to private fields.
21  * According with Mbed TLS migration guide:
22  *
23  * Direct access to fields of structures
24  * (`struct` types) declared in public headers is no longer
25  * supported. In Mbed TLS 3, the layout of structures is not
26  * considered part of the stable API, and minor versions (3.1, 3.2,
27  * etc.) may add, remove, rename, reorder or change the type of
28  * structure fields.
29  */
30 #if !defined(MBEDTLS_ALLOW_PRIVATE_ACCESS)
31 #define MBEDTLS_ALLOW_PRIVATE_ACCESS
32 #endif
33 
34 #if defined(CONFIG_MBEDTLS)
35 #if !defined(CONFIG_MBEDTLS_CFG_FILE)
36 #include "mbedtls/config.h"
37 #else
38 #include CONFIG_MBEDTLS_CFG_FILE
39 #endif /* CONFIG_MBEDTLS_CFG_FILE */
40 
41 #include <mbedtls/net_sockets.h>
42 #include <mbedtls/x509.h>
43 #include <mbedtls/x509_crt.h>
44 #include <mbedtls/ssl.h>
45 #include <mbedtls/ssl_cookie.h>
46 #include <mbedtls/error.h>
47 #include <mbedtls/platform.h>
48 #include <mbedtls/ssl_cache.h>
49 #endif /* CONFIG_MBEDTLS */
50 
51 #include "sockets_internal.h"
52 #include "tls_internal.h"
53 
54 #if defined(CONFIG_MBEDTLS_DEBUG)
55 #include <zephyr_mbedtls_priv.h>
56 #endif
57 
/* Capacity of options.alpn_list. The extra slot holds the NULL terminator
 * of the list.
 */
#if defined(CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS)
#define ALPN_MAX_PROTOCOLS (CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS + 1)
#else
#define ALPN_MAX_PROTOCOLS 0
#endif /* CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS */

/* Buffer size for DTLS sendmsg() handling. It is 0 when DTLS is disabled.
 * NOTE(review): the user is not visible in this chunk; presumably it is the
 * sendmsg implementation later in the file.
 */
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
#define DTLS_SENDMSG_BUF_SIZE (CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE)
#else
#define DTLS_SENDMSG_BUF_SIZE 0
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

/* Forward declaration of the TLS socket operation table. The initialized
 * definition presumably appears later in this file.
 */
static const struct socket_op_vtable tls_sock_fd_op_vtable;

/* Fallback for mbedTLS builds that do not define this error code. dtls_rx()
 * returns it when a DTLS client receives data before a peer is configured.
 */
#ifndef MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED
#define MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE
#endif
75 
/** A list of secure tags that TLS context should use. */
struct sec_tag_list {
	/** An array of secure tags referencing TLS credentials. Only the
	 *  first sec_tag_count entries are meaningful.
	 */
	sec_tag_t sec_tags[CONFIG_NET_SOCKETS_TLS_MAX_CREDENTIALS];

	/** Number of configured secure tags. */
	int sec_tag_count;
};
84 
/** Timer context for DTLS.
 *
 * Backing state for the dtls_timing_set_delay() / dtls_timing_get_delay()
 * callbacks handed to mbedTLS. All values are in milliseconds.
 */
struct dtls_timing_context {
	/** Current time, stored during timer set. It is a k_uptime_get_32()
	 *  value and is only refreshed when a non-zero final delay is set.
	 */
	uint32_t snapshot;

	/** Intermediate delay value. For details, refer to mbedTLS API
	 *  documentation (mbedtls_ssl_set_timer_t).
	 */
	uint32_t int_ms;

	/** Final delay value. For details, refer to mbedTLS API documentation
	 *  (mbedtls_ssl_set_timer_t). Zero means the timer is cancelled.
	 */
	uint32_t fin_ms;
};
100 
/** TLS peer address/session ID mapping.
 *
 * One slot of the client session cache. A slot is free while session is
 * NULL.
 */
struct tls_session_cache {
	/** Creation time (k_uptime_get() value, in ms, when stored). */
	int64_t timestamp;

	/** Peer address. */
	struct net_sockaddr peer_addr;

	/** Session buffer: a serialized mbedTLS session allocated with
	 *  mbedtls_calloc(). NULL when the slot is unused.
	 */
	uint8_t *session;

	/** Session length (number of serialized bytes in session). */
	size_t session_len;
};
115 
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/** DTLS Connection ID settings of a socket. */
struct tls_dtls_cid {
	/** Whether the Connection ID extension is enabled. */
	bool enabled;
	/** Connection ID bytes. Sized for the larger of the mbedTLS inbound
	 *  and outbound CID length limits.
	 */
	unsigned char cid[MAX(MBEDTLS_SSL_CID_OUT_LEN_MAX,
			      MBEDTLS_SSL_CID_IN_LEN_MAX)];
	/** Number of valid bytes in cid. */
	size_t cid_len;
};
#endif
124 
/** TLS context information.
 *
 * One entry of the global tls_contexts pool. Defaults are applied by
 * tls_alloc().
 */
__net_socket struct tls_context {
	/** Underlying TCP/UDP socket. -1 until one is assigned. */
	int sock;

	/** Information whether TLS context is used. Set under context_lock
	 *  in tls_alloc().
	 */
	bool is_used : 1;

	/** Information whether TLS context was initialized. */
	bool is_initialized : 1;

	/** Information whether underlying socket is listening. */
	bool is_listening : 1;

	/** Information whether TLS handshake is currently in progress. */
	bool handshake_in_progress : 1;

	/** Session ended at the TLS/DTLS level. */
	bool session_closed : 1;

	/** Socket type. */
	enum net_sock_type type;

	/** Secure protocol version running on TLS context. */
	enum net_ip_protocol_secure tls_version;

	/** Socket flags passed to a socket call. */
	int flags;

	/** Indicates whether socket is in error state at TLS/DTLS level. */
	int error;

	/** Information whether TLS handshake is complete or not. A binary
	 *  semaphore (limit 1) whose non-zero count means "complete".
	 */
	struct k_sem tls_established;

	/** TLS socket mutex lock (set via ctx_set_lock()). */
	struct k_mutex *lock;

	/** TLS specific option values. */
	struct {
		/** Select which credentials to use with TLS. */
		struct sec_tag_list sec_tag_list;

		/** 0-terminated list of allowed ciphersuites (mbedTLS format).
		 */
		int ciphersuites[CONFIG_NET_SOCKETS_TLS_MAX_CIPHERSUITES + 1];

		/** Information if hostname was explicitly set on a socket. */
		bool is_hostname_set;

		/** Peer verification level. Defaults to -1
		 *  (NOTE(review): presumably "not explicitly configured").
		 */
		int8_t verify_level;

		/** Indicating on whether DER certificates should not be copied
		 * to the heap.
		 */
		int8_t cert_nocopy;

		/** DTLS role, client by default. */
		int8_t role;

		/** NULL-terminated list of allowed application layer
		 * protocols.
		 */
		const char *alpn_list[ALPN_MAX_PROTOCOLS];

		/** Session cache enabled on a socket. */
		bool cache_enabled;

		/** Socket TX timeout (K_FOREVER by default). */
		k_timeout_t timeout_tx;

		/** Socket RX timeout (K_FOREVER by default). */
		k_timeout_t timeout_rx;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/** DTLS handshake timeout bounds. They default to the mbedTLS
		 *  MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN/MAX values.
		 */
		uint32_t dtls_handshake_timeout_min;
		uint32_t dtls_handshake_timeout_max;

		/** DTLS Connection ID settings (disabled by default). */
		struct tls_dtls_cid dtls_cid;

		/** true by default. NOTE(review): presumably selects whether
		 *  the DTLS handshake runs as part of connect().
		 */
		bool dtls_handshake_on_connect;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK)
		/** User certificate verification callback settings. */
		struct tls_cert_verify_cb cert_verify;
#endif /* CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK */
	} options;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	/** Context information for DTLS timing. */
	struct dtls_timing_context dtls_timing;

	/** mbedTLS cookie context for DTLS */
	mbedtls_ssl_cookie_ctx cookie;

	/** DTLS peer address. */
	struct net_sockaddr dtls_peer_addr;

	/** DTLS peer address length. 0 while no peer is known. */
	net_socklen_t dtls_peer_addrlen;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(CONFIG_MBEDTLS)
	/** mbedTLS context. */
	mbedtls_ssl_context ssl;

	/** mbedTLS configuration. */
	mbedtls_ssl_config config;

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	/** mbedTLS structure for CA chain. */
	mbedtls_x509_crt ca_chain;

	/** mbedTLS structure for own certificate. */
	mbedtls_x509_crt own_cert;

	/** mbedTLS structure for own private key. */
	mbedtls_pk_context priv_key;
#endif /* MBEDTLS_X509_CRT_PARSE_C */

#endif /* CONFIG_MBEDTLS */
};
249 
250 
/* A global pool of TLS contexts. */
static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS];

/* Client-side session cache. Entries are keyed by peer address and managed
 * by tls_session_save()/tls_session_get().
 */
static struct tls_session_cache client_cache[CONFIG_NET_SOCKETS_TLS_MAX_CLIENT_SESSION_COUNT];

#if defined(MBEDTLS_SSL_CACHE_C)
/* mbedTLS session cache, reset by tls_session_purge().
 * NOTE(review): presumably attached to server-role contexts elsewhere.
 */
static mbedtls_ssl_cache_context server_cache;
#endif

/* A mutex for protecting TLS context allocation (taken by tls_alloc() while
 * scanning and claiming pool slots).
 */
static struct k_mutex context_lock;

/* Arbitrary delay value to wait if mbedTLS reports it cannot proceed for
 * reasons other than TX/RX block.
 */
#define TLS_WAIT_MS 100
267 
tls_session_cache_reset(void)268 static void tls_session_cache_reset(void)
269 {
270 	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
271 		if (client_cache[i].session != NULL) {
272 			mbedtls_free(client_cache[i].session);
273 		}
274 	}
275 
276 	(void)memset(client_cache, 0, sizeof(client_cache));
277 }
278 
net_socket_is_tls(void * obj)279 bool net_socket_is_tls(void *obj)
280 {
281 	return PART_OF_ARRAY(tls_contexts, (struct tls_context *)obj);
282 }
283 
/* Random number callback handed to mbedTLS. The @p ctx argument is unused.
 * With CONFIG_CSPRNG_ENABLED the CSPRNG result is returned as-is. Otherwise
 * sys_rand_get() fills the buffer and success is always reported.
 */
static int tls_ctr_drbg_random(void *ctx, unsigned char *buf, size_t len)
{
	ARG_UNUSED(ctx);

#if !defined(CONFIG_CSPRNG_ENABLED)
	sys_rand_get(buf, len);

	return 0;
#else
	return sys_csrand_get(buf, len);
#endif
}
296 
297 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
298 /* mbedTLS-defined function for setting timer. */
dtls_timing_set_delay(void * data,uint32_t int_ms,uint32_t fin_ms)299 static void dtls_timing_set_delay(void *data, uint32_t int_ms, uint32_t fin_ms)
300 {
301 	struct dtls_timing_context *ctx = data;
302 
303 	ctx->int_ms = int_ms;
304 	ctx->fin_ms = fin_ms;
305 
306 	if (fin_ms != 0U) {
307 		ctx->snapshot = k_uptime_get_32();
308 	}
309 }
310 
311 /* mbedTLS-defined function for getting timer status.
312  * The return values are specified by mbedTLS. The callback must return:
313  *   -1 if cancelled (fin_ms == 0),
314  *    0 if none of the delays have passed,
315  *    1 if only the intermediate delay has passed,
316  *    2 if the final delay has passed.
317  */
dtls_timing_get_delay(void * data)318 static int dtls_timing_get_delay(void *data)
319 {
320 	struct dtls_timing_context *timing = data;
321 	unsigned long elapsed_ms;
322 
323 	NET_ASSERT(timing);
324 
325 	if (timing->fin_ms == 0U) {
326 		return -1;
327 	}
328 
329 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
330 
331 	if (elapsed_ms >= timing->fin_ms) {
332 		return 2;
333 	}
334 
335 	if (elapsed_ms >= timing->int_ms) {
336 		return 1;
337 	}
338 
339 	return 0;
340 }
341 
dtls_get_remaining_timeout(struct tls_context * ctx)342 static int dtls_get_remaining_timeout(struct tls_context *ctx)
343 {
344 	struct dtls_timing_context *timing = &ctx->dtls_timing;
345 	uint32_t elapsed_ms;
346 
347 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
348 
349 	if (timing->fin_ms == 0U) {
350 		return SYS_FOREVER_MS;
351 	}
352 
353 	if (elapsed_ms >= timing->fin_ms) {
354 		return 0;
355 	}
356 
357 	return timing->fin_ms - elapsed_ms;
358 }
359 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
360 
361 /* Initialize TLS internals. */
tls_init(void)362 static int tls_init(void)
363 {
364 
365 #if !defined(CONFIG_ENTROPY_HAS_DRIVER)
366 	NET_WARN("No entropy device on the system, "
367 		 "TLS communication is insecure!");
368 #endif
369 
370 	(void)memset(tls_contexts, 0, sizeof(tls_contexts));
371 	(void)memset(client_cache, 0, sizeof(client_cache));
372 
373 	k_mutex_init(&context_lock);
374 
375 #if defined(MBEDTLS_SSL_CACHE_C)
376 	mbedtls_ssl_cache_init(&server_cache);
377 #endif
378 
379 	return 0;
380 }
381 
382 SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT);
383 
is_handshake_complete(struct tls_context * ctx)384 static inline bool is_handshake_complete(struct tls_context *ctx)
385 {
386 	return k_sem_count_get(&ctx->tls_established) != 0;
387 }
388 
/*
 * Copied from include/mbedtls/ssl_internal.h
 *
 * Maximum length we can advertise as our max content length for
 * RFC 6066 max_fragment_length extension negotiation purposes
 * (the lesser of both sizes, if they are unequal.)
 *
 * The value is evaluated inside a preprocessor #if further down, so it must
 * remain a plain constant expression.
 */
#define MBEDTLS_TLS_EXT_ADV_CONTENT_LEN (                            \
	(MBEDTLS_SSL_IN_CONTENT_LEN > MBEDTLS_SSL_OUT_CONTENT_LEN)   \
	? (MBEDTLS_SSL_OUT_CONTENT_LEN)				     \
	: (MBEDTLS_SSL_IN_CONTENT_LEN)				     \
	)
401 
402 #if defined(CONFIG_NET_SOCKETS_TLS_SET_MAX_FRAGMENT_LENGTH) &&	\
403 	defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) &&		\
404 	(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN < 16384)
405 
406 BUILD_ASSERT(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN >= 512,
407 	     "Too small content length!");
408 
tls_mfl_code_from_content_len(size_t len)409 static inline unsigned char tls_mfl_code_from_content_len(size_t len)
410 {
411 	if (len >= 4096) {
412 		return MBEDTLS_SSL_MAX_FRAG_LEN_4096;
413 	} else if (len >= 2048) {
414 		return MBEDTLS_SSL_MAX_FRAG_LEN_2048;
415 	} else if (len >= 1024) {
416 		return MBEDTLS_SSL_MAX_FRAG_LEN_1024;
417 	} else if (len >= 512) {
418 		return MBEDTLS_SSL_MAX_FRAG_LEN_512;
419 	} else {
420 		return MBEDTLS_SSL_MAX_FRAG_LEN_INVALID;
421 	}
422 }
423 
tls_set_max_frag_len(mbedtls_ssl_config * config,enum net_sock_type type)424 static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
425 {
426 	unsigned char mfl_code;
427 	size_t len = MBEDTLS_TLS_EXT_ADV_CONTENT_LEN;
428 
429 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
430 	if (type == NET_SOCK_DGRAM && len > CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH) {
431 		len = CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH;
432 	}
433 #endif
434 	mfl_code = tls_mfl_code_from_content_len(len);
435 
436 	mbedtls_ssl_conf_max_frag_len(config, mfl_code);
437 }
438 #else
tls_set_max_frag_len(mbedtls_ssl_config * config,enum net_sock_type type)439 static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type) {}
440 #endif
441 
442 /* Allocate TLS context. */
tls_alloc(void)443 static struct tls_context *tls_alloc(void)
444 {
445 	int i;
446 	struct tls_context *tls = NULL;
447 
448 	k_mutex_lock(&context_lock, K_FOREVER);
449 
450 	for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) {
451 		if (!tls_contexts[i].is_used) {
452 			tls = &tls_contexts[i];
453 			(void)memset(tls, 0, sizeof(*tls));
454 			tls->is_used = true;
455 			tls->options.verify_level = -1;
456 			tls->options.timeout_tx = K_FOREVER;
457 			tls->options.timeout_rx = K_FOREVER;
458 			tls->sock = -1;
459 
460 			NET_DBG("Allocated TLS context, %p", tls);
461 			break;
462 		}
463 	}
464 
465 	k_mutex_unlock(&context_lock);
466 
467 	if (tls) {
468 		k_sem_init(&tls->tls_established, 0, 1);
469 
470 		mbedtls_ssl_init(&tls->ssl);
471 		mbedtls_ssl_config_init(&tls->config);
472 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
473 		mbedtls_ssl_cookie_init(&tls->cookie);
474 		tls->options.dtls_handshake_timeout_min =
475 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN;
476 		tls->options.dtls_handshake_timeout_max =
477 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX;
478 		tls->options.dtls_cid.cid_len = 0;
479 		tls->options.dtls_cid.enabled = false;
480 		tls->options.dtls_handshake_on_connect = true;
481 #endif
482 #if defined(MBEDTLS_X509_CRT_PARSE_C)
483 		mbedtls_x509_crt_init(&tls->ca_chain);
484 		mbedtls_x509_crt_init(&tls->own_cert);
485 		mbedtls_pk_init(&tls->priv_key);
486 #endif
487 
488 #if defined(CONFIG_MBEDTLS_DEBUG)
489 		mbedtls_ssl_conf_dbg(&tls->config, zephyr_mbedtls_debug, NULL);
490 #endif
491 	} else {
492 		NET_WARN("Failed to allocate TLS context");
493 	}
494 
495 	return tls;
496 }
497 
498 /* Allocate new TLS context and copy the content from the source context. */
tls_clone(struct tls_context * source_tls)499 static struct tls_context *tls_clone(struct tls_context *source_tls)
500 {
501 	struct tls_context *target_tls;
502 
503 	target_tls = tls_alloc();
504 	if (!target_tls) {
505 		return NULL;
506 	}
507 
508 	target_tls->tls_version = source_tls->tls_version;
509 	target_tls->type = source_tls->type;
510 
511 	memcpy(&target_tls->options, &source_tls->options,
512 	       sizeof(target_tls->options));
513 
514 #if defined(MBEDTLS_X509_CRT_PARSE_C)
515 	if (target_tls->options.is_hostname_set) {
516 		mbedtls_ssl_set_hostname(&target_tls->ssl,
517 					 source_tls->ssl.hostname);
518 	}
519 #endif
520 
521 	return target_tls;
522 }
523 
524 /* Release TLS context. */
tls_release(struct tls_context * tls)525 static int tls_release(struct tls_context *tls)
526 {
527 	if (!PART_OF_ARRAY(tls_contexts, tls)) {
528 		NET_ERR("Invalid TLS context");
529 		return -EBADF;
530 	}
531 
532 	if (!tls->is_used) {
533 		NET_ERR("Deallocating unused TLS context");
534 		return -EBADF;
535 	}
536 
537 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
538 	mbedtls_ssl_cookie_free(&tls->cookie);
539 #endif
540 	mbedtls_ssl_config_free(&tls->config);
541 	mbedtls_ssl_free(&tls->ssl);
542 #if defined(MBEDTLS_X509_CRT_PARSE_C)
543 	mbedtls_x509_crt_free(&tls->ca_chain);
544 	mbedtls_x509_crt_free(&tls->own_cert);
545 	mbedtls_pk_free(&tls->priv_key);
546 #endif
547 
548 	tls->is_used = false;
549 
550 	return 0;
551 }
552 
peer_addr_cmp(const struct net_sockaddr * addr,const struct net_sockaddr * peer_addr)553 static bool peer_addr_cmp(const struct net_sockaddr *addr,
554 			  const struct net_sockaddr *peer_addr)
555 {
556 	if (addr->sa_family != peer_addr->sa_family) {
557 		return false;
558 	}
559 
560 	if (IS_ENABLED(CONFIG_NET_IPV6) && peer_addr->sa_family == NET_AF_INET6) {
561 		struct net_sockaddr_in6 *addr1 = net_sin6(peer_addr);
562 		struct net_sockaddr_in6 *addr2 = net_sin6(addr);
563 
564 		return (addr1->sin6_port == addr2->sin6_port) &&
565 			net_ipv6_addr_cmp(&addr1->sin6_addr, &addr2->sin6_addr);
566 	} else if (IS_ENABLED(CONFIG_NET_IPV4) && peer_addr->sa_family == NET_AF_INET) {
567 		struct net_sockaddr_in *addr1 = net_sin(peer_addr);
568 		struct net_sockaddr_in *addr2 = net_sin(addr);
569 
570 		return (addr1->sin_port == addr2->sin_port) &&
571 			net_ipv4_addr_cmp(&addr1->sin_addr, &addr2->sin_addr);
572 	}
573 
574 	return false;
575 }
576 
/* Serialize @p session into the client session cache under @p peer_addr.
 *
 * Slot selection, in order of preference:
 *   1. an existing entry for the same peer (overwritten),
 *   2. a free slot,
 *   3. the oldest stored entry (smallest timestamp), which is evicted.
 *
 * Returns 0 on success. Returns -ENOMEM if the cache has no slots or the
 * buffer cannot be allocated or filled, and -EINVAL if mbedTLS does not
 * report a serialized size.
 */
static int tls_session_save(const struct net_sockaddr *peer_addr,
			    mbedtls_ssl_session *session)
{
	struct tls_session_cache *entry = NULL;
	size_t session_len = 0;
	int ret;

	for (size_t i = 0; i < ARRAY_SIZE(client_cache); i++) {
		struct tls_session_cache *slot = &client_cache[i];

		if (slot->session == NULL) {
			/* New entry: prefer a free slot over evicting a used one. */
			if (entry == NULL || entry->session != NULL) {
				entry = slot;
			}
			continue;
		}

		if (peer_addr_cmp(&slot->peer_addr, peer_addr)) {
			/* Reuse old entry for given address. */
			entry = slot;
			break;
		}

		/* Remember the oldest (smallest timestamp) used entry as the
		 * eviction candidate, unless a free slot was already found.
		 * The original comparison was inverted and picked the newest.
		 */
		if (entry == NULL ||
		    (entry->session != NULL &&
		     slot->timestamp < entry->timestamp)) {
			entry = slot;
		}
	}

	if (entry == NULL) {
		/* Only possible when the cache has no slots at all. */
		return -ENOMEM;
	}

	/* Allocate session and save */

	if (entry->session != NULL) {
		mbedtls_free(entry->session);
		entry->session = NULL;
	}

	/* Size probe: without an output buffer mbedTLS only reports the
	 * required length. session_len stays 0 if the probe cannot size it.
	 */
	(void)mbedtls_ssl_session_save(session, NULL, 0, &session_len);
	if (session_len == 0) {
		NET_ERR("Failed to determine session length.");
		return -EINVAL;
	}

	entry->session = mbedtls_calloc(1, session_len);
	if (entry->session == NULL) {
		NET_ERR("Failed to allocate session buffer.");
		return -ENOMEM;
	}

	ret = mbedtls_ssl_session_save(session, entry->session, session_len,
				       &session_len);
	if (ret < 0) {
		NET_ERR("Failed to serialize session, err: -0x%x.", -ret);
		mbedtls_free(entry->session);
		entry->session = NULL;
		return -ENOMEM;
	}

	entry->session_len = session_len;
	entry->timestamp = k_uptime_get();
	memcpy(&entry->peer_addr, peer_addr, sizeof(*peer_addr));

	return 0;
}
636 
/* Look up the cached session for @p peer_addr and deserialize it into
 * @p session. Returns 0 on success and -ENOENT when nothing is cached for the
 * peer. Returns -EIO when the cached data fails to load; the entry is then
 * discarded.
 */
static int tls_session_get(const struct net_sockaddr *peer_addr,
			   mbedtls_ssl_session *session)
{
	struct tls_session_cache *const end = client_cache + ARRAY_SIZE(client_cache);
	struct tls_session_cache *entry;
	int ret;

	for (entry = client_cache; entry < end; entry++) {
		if (entry->session != NULL &&
		    peer_addr_cmp(&entry->peer_addr, peer_addr)) {
			break;
		}
	}

	if (entry == end) {
		return -ENOENT;
	}

	ret = mbedtls_ssl_session_load(session, entry->session,
				       entry->session_len);
	if (ret >= 0) {
		return 0;
	}

	/* Discard corrupted session data. */
	mbedtls_free(entry->session);
	entry->session = NULL;
	NET_ERR("Failed to load TLS session %d", ret);

	return -EIO;
}
667 
/* Save the session negotiated on @p context into the client session cache,
 * keyed by the peer address @p addr / @p addrlen. Does nothing unless session
 * caching is enabled on the socket. Failures are logged only.
 */
static void tls_session_store(struct tls_context *context,
			      const struct net_sockaddr *addr,
			      net_socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct net_sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* addrlen comes from the caller; never copy more than the local
	 * buffer can hold (previously an unchecked stack overflow).
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = mbedtls_ssl_get_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to obtain session for %p", context);
		goto exit;
	}

	ret = tls_session_save(&peer_addr, &session);
	if (ret < 0) {
		NET_ERR("Failed to save session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
697 
/* Load a cached session for the peer @p addr / @p addrlen, if one exists,
 * into @p context's SSL context so the next handshake can resume it. Does
 * nothing unless session caching is enabled on the socket. A missing session
 * is logged at debug level; load errors are logged as errors.
 */
static void tls_session_restore(struct tls_context *context,
				const struct net_sockaddr *addr,
				net_socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct net_sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* addrlen comes from the caller; never copy more than the local
	 * buffer can hold (previously an unchecked stack overflow).
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = tls_session_get(&peer_addr, &session);
	if (ret < 0) {
		NET_DBG("Session not found for %p", context);
		goto exit;
	}

	ret = mbedtls_ssl_set_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to set session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
727 
/* Forget every cached TLS session. This clears the client-side cache and,
 * when built with MBEDTLS_SSL_CACHE_C, frees and re-initializes the mbedTLS
 * session cache so that it stays usable afterwards.
 */
static void tls_session_purge(void)
{
	tls_session_cache_reset();

#if defined(MBEDTLS_SSL_CACHE_C)
	/* Free + init is the mbedTLS way to empty a cache in place. */
	mbedtls_ssl_cache_free(&server_cache);
	mbedtls_ssl_cache_init(&server_cache);
#endif
}
737 
/* Milliseconds remaining of a @p timeout ms window that began at uptime
 * @p start (presumably a k_uptime_get_32() value; confirm at call sites).
 * The unsigned difference is converted to int, so the result turns negative
 * once the window has been exceeded.
 */
static inline int time_left(uint32_t start, uint32_t timeout)
{
	uint32_t now = k_uptime_get_32();

	return timeout - (now - start);
}
744 
wait(int sock,int timeout,int event)745 static int wait(int sock, int timeout, int event)
746 {
747 	struct zsock_pollfd fds = {
748 		.fd = sock,
749 		.events = event,
750 	};
751 	int ret;
752 
753 	ret = zsock_poll(&fds, 1, timeout);
754 	if (ret < 0) {
755 		return ret;
756 	}
757 
758 	if (ret == 1) {
759 		if (fds.revents & ZSOCK_POLLNVAL) {
760 			return -EBADF;
761 		}
762 
763 		if (fds.revents & ZSOCK_POLLERR) {
764 			int optval;
765 			net_socklen_t optlen = sizeof(optval);
766 
767 			if (zsock_getsockopt(fds.fd, ZSOCK_SOL_SOCKET, ZSOCK_SO_ERROR,
768 					     &optval, &optlen) == 0) {
769 				NET_ERR("TLS underlying socket poll error %d",
770 					-optval);
771 				return -optval;
772 			}
773 
774 			return -EIO;
775 		}
776 	}
777 
778 	return 0;
779 }
780 
wait_for_reason(int sock,int timeout,int reason)781 static int wait_for_reason(int sock, int timeout, int reason)
782 {
783 	if (reason == MBEDTLS_ERR_SSL_WANT_READ) {
784 		return wait(sock, timeout, ZSOCK_POLLIN);
785 	}
786 
787 	if (reason == MBEDTLS_ERR_SSL_WANT_WRITE) {
788 		return wait(sock, timeout, ZSOCK_POLLOUT);
789 	}
790 
791 	/* Any other reason - no way to monitor, just wait for some time. */
792 	k_msleep(TLS_WAIT_MS);
793 
794 	return 0;
795 }
796 
is_blocking(int sock,int flags)797 static bool is_blocking(int sock, int flags)
798 {
799 	int sock_flags = zsock_fcntl(sock, ZVFS_F_GETFL, 0);
800 
801 	if (sock_flags == -1) {
802 		return false;
803 	}
804 
805 	return !((flags & ZSOCK_MSG_DONTWAIT) || (sock_flags & ZVFS_O_NONBLOCK));
806 }
807 
/* Convert a kernel timeout to milliseconds. K_FOREVER maps to
 * SYS_FOREVER_MS and K_NO_WAIT to 0. Any other timeout is converted from
 * ticks, rounding down.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
818 
ctx_set_lock(struct tls_context * ctx,struct k_mutex * lock)819 static void ctx_set_lock(struct tls_context *ctx, struct k_mutex *lock)
820 {
821 	ctx->lock = lock;
822 }
823 
824 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* True when @p peer_addr / @p addrlen match the DTLS peer recorded in
 * @p context (same length, family, port and IP address).
 */
static bool dtls_is_peer_addr_valid(struct tls_context *context,
				    const struct net_sockaddr *peer_addr,
				    net_socklen_t addrlen)
{
	return context->dtls_peer_addrlen == addrlen &&
	       peer_addr_cmp(&context->dtls_peer_addr, peer_addr);
}
835 
/* Record @p peer_addr as the DTLS peer of @p context. An address longer than
 * the storage is ignored, and the previous peer stays in place.
 */
static void dtls_peer_address_set(struct tls_context *context,
				  const struct net_sockaddr *peer_addr,
				  net_socklen_t addrlen)
{
	if (addrlen > sizeof(context->dtls_peer_addr)) {
		return;
	}

	memcpy(&context->dtls_peer_addr, peer_addr, addrlen);
	context->dtls_peer_addrlen = addrlen;
}
845 
/* Copy the recorded DTLS peer address into @p peer_addr. The copy is
 * truncated to the caller's buffer size given in *@p addrlen, and
 * *@p addrlen is set to the number of bytes copied.
 */
static void dtls_peer_address_get(struct tls_context *context,
				  struct net_sockaddr *peer_addr,
				  net_socklen_t *addrlen)
{
	net_socklen_t copy_len = *addrlen;

	if (context->dtls_peer_addrlen < copy_len) {
		copy_len = context->dtls_peer_addrlen;
	}

	memcpy(peer_addr, &context->dtls_peer_addr, copy_len);
	*addrlen = copy_len;
}
855 
dtls_tx(void * ctx,const unsigned char * buf,size_t len)856 static int dtls_tx(void *ctx, const unsigned char *buf, size_t len)
857 {
858 	struct tls_context *tls_ctx = ctx;
859 	ssize_t sent;
860 
861 	sent = zsock_sendto(tls_ctx->sock, buf, len, ZSOCK_MSG_DONTWAIT,
862 			    &tls_ctx->dtls_peer_addr,
863 			    tls_ctx->dtls_peer_addrlen);
864 	if (sent < 0) {
865 		if (errno == EAGAIN) {
866 			return MBEDTLS_ERR_SSL_WANT_WRITE;
867 		}
868 
869 		return MBEDTLS_ERR_NET_SEND_FAILED;
870 	}
871 
872 	return sent;
873 }
874 
dtls_rx(void * ctx,unsigned char * buf,size_t len)875 static int dtls_rx(void *ctx, unsigned char *buf, size_t len)
876 {
877 	struct tls_context *tls_ctx = ctx;
878 	net_socklen_t addrlen = sizeof(struct net_sockaddr);
879 	struct net_sockaddr addr;
880 	int err;
881 	ssize_t received;
882 
883 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
884 				  ZSOCK_MSG_DONTWAIT, &addr, &addrlen);
885 	if (received < 0) {
886 		if (errno == EAGAIN) {
887 			return MBEDTLS_ERR_SSL_WANT_READ;
888 		}
889 
890 		return MBEDTLS_ERR_NET_RECV_FAILED;
891 	}
892 
893 	if (tls_ctx->dtls_peer_addrlen == 0) {
894 		/* Only allow to store peer address for DTLS servers. */
895 		if (tls_ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
896 			dtls_peer_address_set(tls_ctx, &addr, addrlen);
897 
898 			err = mbedtls_ssl_set_client_transport_id(
899 				&tls_ctx->ssl,
900 				(const unsigned char *)&addr, addrlen);
901 			if (err < 0) {
902 				return err;
903 			}
904 		} else {
905 			/* For clients it's incorrect to receive when
906 			 * no peer has been set up.
907 			 */
908 			return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED;
909 		}
910 	} else if (!dtls_is_peer_addr_valid(tls_ctx, &addr, addrlen)) {
911 		return MBEDTLS_ERR_SSL_WANT_READ;
912 	}
913 
914 	return received;
915 }
916 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
917 
tls_tx(void * ctx,const unsigned char * buf,size_t len)918 static int tls_tx(void *ctx, const unsigned char *buf, size_t len)
919 {
920 	struct tls_context *tls_ctx = ctx;
921 	ssize_t sent;
922 
923 	sent = zsock_sendto(tls_ctx->sock, buf, len,
924 			    ZSOCK_MSG_DONTWAIT, NULL, 0);
925 	if (sent < 0) {
926 		if (errno == EAGAIN) {
927 			return MBEDTLS_ERR_SSL_WANT_WRITE;
928 		}
929 
930 		return MBEDTLS_ERR_NET_SEND_FAILED;
931 	}
932 
933 	return sent;
934 }
935 
tls_rx(void * ctx,unsigned char * buf,size_t len)936 static int tls_rx(void *ctx, unsigned char *buf, size_t len)
937 {
938 	struct tls_context *tls_ctx = ctx;
939 	ssize_t received;
940 
941 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
942 				  ZSOCK_MSG_DONTWAIT, NULL, 0);
943 	if (received < 0) {
944 		if (errno == EAGAIN) {
945 			return MBEDTLS_ERR_SSL_WANT_READ;
946 		}
947 
948 		return MBEDTLS_ERR_NET_RECV_FAILED;
949 	}
950 
951 	return received;
952 }
953 
954 #if defined(MBEDTLS_X509_CRT_PARSE_C)
/* Treat a certificate buffer as PEM when it is NUL-terminated (its last byte
 * is '\0') and contains a PEM certificate header. Callers use this to decide
 * whether the DER no-copy parser may be used.
 */
static bool crt_is_pem(const unsigned char *buf, size_t buflen)
{
	if (buflen == 0 || buf[buflen - 1] != '\0') {
		return false;
	}

	return strstr((const char *)buf, "-----BEGIN CERTIFICATE-----") != NULL;
}
960 #endif
961 
/* Parse @p ca_cert and append it to the context's CA chain.
 *
 * The certificate is parsed by copy when it is PEM or when cert_nocopy is
 * ZSOCK_TLS_CERT_NOCOPY_NONE. Otherwise the DER buffer is referenced in place
 * by mbedTLS. Returns 0 on success, -EINVAL on a parse error, or -ENOTSUP
 * when X.509 support is not compiled in.
 */
static int tls_add_ca_certificate(struct tls_context *tls,
				  struct tls_credential *ca_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool copy_cert = tls->options.cert_nocopy == ZSOCK_TLS_CERT_NOCOPY_NONE ||
			 crt_is_pem(ca_cert->buf, ca_cert->len);
	int err;

	if (copy_cert) {
		err = mbedtls_x509_crt_parse(&tls->ca_chain, ca_cert->buf,
					     ca_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->ca_chain,
							ca_cert->buf,
							ca_cert->len);
	}

	if (err != 0) {
		NET_ERR("Failed to parse CA certificate, err: -0x%x", -err);
		return -EINVAL;
	}

	return 0;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(ca_cert);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
988 
/* Point the SSL configuration at the parsed CA chain (no CRL) and select the
 * default X.509 certificate profile. Does nothing without X.509 support.
 */
static void tls_set_ca_chain(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	mbedtls_ssl_config *conf = &tls->config;

	mbedtls_ssl_conf_ca_chain(conf, &tls->ca_chain, NULL);
	mbedtls_ssl_conf_cert_profile(conf, &mbedtls_x509_crt_profile_default);
#else
	ARG_UNUSED(tls);
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
997 
/* Parse @p own_cert and append it to the context's own certificate chain.
 *
 * The certificate is parsed by copy when it is PEM or when cert_nocopy is
 * ZSOCK_TLS_CERT_NOCOPY_NONE. Otherwise the DER buffer is referenced in place
 * by mbedTLS. Returns 0 on success, -EINVAL on a parse error (now logged,
 * like CA certificate failures), or -ENOTSUP without X.509 support.
 */
static int tls_add_own_cert(struct tls_context *tls,
			    struct tls_credential *own_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int err;

	if (tls->options.cert_nocopy == ZSOCK_TLS_CERT_NOCOPY_NONE ||
	    crt_is_pem(own_cert->buf, own_cert->len)) {
		err = mbedtls_x509_crt_parse(&tls->own_cert,
					     own_cert->buf, own_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->own_cert,
							own_cert->buf,
							own_cert->len);
	}

	if (err != 0) {
		NET_ERR("Failed to parse own certificate, err: -0x%x", -err);
		return -EINVAL;
	}

	return 0;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(own_cert);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1023 
/* Install the own certificate chain and private key on the SSL config.
 * Returns 0 on success, -ENOMEM if mbedTLS rejects the pair, or -ENOTSUP
 * without X.509 support.
 */
static int tls_set_own_cert(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_conf_own_cert(&tls->config, &tls->own_cert,
				      &tls->priv_key) != 0) {
		return -ENOMEM;
	}

	return 0;
#else
	ARG_UNUSED(tls);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1038 
/* Parse @p priv_key (no password) into the context's private key, using
 * tls_ctr_drbg_random as the RNG required by mbedTLS.
 *
 * Returns 0 on success, -EINVAL on a parse error (now logged, consistent
 * with CA certificate handling; only the error code is printed, never key
 * material), or -ENOTSUP without X.509 support.
 */
static int tls_set_private_key(struct tls_context *tls,
			       struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int err;

	err = mbedtls_pk_parse_key(&tls->priv_key, priv_key->buf,
				   priv_key->len, NULL, 0,
				   tls_ctr_drbg_random, NULL);
	if (err != 0) {
		NET_ERR("Failed to parse private key, err: -0x%x", -err);
		return -EINVAL;
	}

	return 0;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(priv_key);

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1057 
/* Configure a pre-shared key and its identity on the SSL config. Returns 0 on
 * success, -EINVAL if mbedTLS rejects them, or -ENOTSUP when PSK handshakes
 * are not compiled in.
 */
static int tls_set_psk(struct tls_context *tls,
		       struct tls_credential *psk,
		       struct tls_credential *psk_id)
{
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	const unsigned char *identity = (const unsigned char *)psk_id->buf;

	if (mbedtls_ssl_conf_psk(&tls->config, psk->buf, psk->len,
				 identity, psk_id->len) != 0) {
		return -EINVAL;
	}

	return 0;
#else
	ARG_UNUSED(tls);
	ARG_UNUSED(psk);
	ARG_UNUSED(psk_id);

	return -ENOTSUP;
#endif
}
1076 
tls_set_credential(struct tls_context * tls,struct tls_credential * cred)1077 static int tls_set_credential(struct tls_context *tls,
1078 			      struct tls_credential *cred)
1079 {
1080 	switch (cred->type) {
1081 	case TLS_CREDENTIAL_CA_CERTIFICATE:
1082 		return tls_add_ca_certificate(tls, cred);
1083 
1084 	case TLS_CREDENTIAL_PUBLIC_CERTIFICATE:
1085 		return tls_add_own_cert(tls, cred);
1086 
1087 	case TLS_CREDENTIAL_PRIVATE_KEY:
1088 		return tls_set_private_key(tls, cred);
1089 	break;
1090 
1091 	case TLS_CREDENTIAL_PSK:
1092 	{
1093 		struct tls_credential *psk_id =
1094 			credential_get(cred->tag, TLS_CREDENTIAL_PSK_ID);
1095 		if (!psk_id) {
1096 			return -ENOENT;
1097 		}
1098 
1099 		return tls_set_psk(tls, cred, psk_id);
1100 	}
1101 
1102 	case TLS_CREDENTIAL_PSK_ID:
1103 		/* Ignore PSK ID - it will be used together
1104 		 * with PSK
1105 		 */
1106 		break;
1107 
1108 	default:
1109 		return -EINVAL;
1110 	}
1111 
1112 	return 0;
1113 }
1114 
tls_mbedtls_set_credentials(struct tls_context * tls)1115 static int tls_mbedtls_set_credentials(struct tls_context *tls)
1116 {
1117 	struct tls_credential *cred;
1118 	sec_tag_t tag;
1119 	int i, err = 0;
1120 	bool tag_found, ca_cert_present = false, own_cert_present = false;
1121 
1122 	credentials_lock();
1123 
1124 	for (i = 0; i < tls->options.sec_tag_list.sec_tag_count; i++) {
1125 		tag = tls->options.sec_tag_list.sec_tags[i];
1126 		cred = NULL;
1127 		tag_found = false;
1128 
1129 		while ((cred = credential_next_get(tag, cred)) != NULL) {
1130 			tag_found = true;
1131 
1132 			err = tls_set_credential(tls, cred);
1133 			if (err != 0) {
1134 				goto exit;
1135 			}
1136 
1137 			if (cred->type == TLS_CREDENTIAL_CA_CERTIFICATE) {
1138 				ca_cert_present = true;
1139 			} else if (cred->type == TLS_CREDENTIAL_PUBLIC_CERTIFICATE) {
1140 				own_cert_present = true;
1141 			}
1142 		}
1143 
1144 		if (!tag_found) {
1145 			err = -ENOENT;
1146 			goto exit;
1147 		}
1148 	}
1149 
1150 exit:
1151 	credentials_unlock();
1152 
1153 	if (err == 0) {
1154 		if (ca_cert_present) {
1155 			tls_set_ca_chain(tls);
1156 		}
1157 		if (own_cert_present) {
1158 			err = tls_set_own_cert(tls);
1159 		}
1160 	}
1161 
1162 	return err;
1163 }
1164 
tls_mbedtls_reset(struct tls_context * context)1165 static int tls_mbedtls_reset(struct tls_context *context)
1166 {
1167 	int ret;
1168 
1169 	ret = mbedtls_ssl_session_reset(&context->ssl);
1170 	if (ret != 0) {
1171 		return ret;
1172 	}
1173 
1174 	k_sem_reset(&context->tls_established);
1175 
1176 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1177 	/* Server role: reset the address so that a new
1178 	 *              client can connect w/o a need to reopen a socket
1179 	 * Client role: keep peer addr so socket can continue to be used
1180 	 *              even on handshake timeout
1181 	 */
1182 	if (context->options.role == MBEDTLS_SSL_IS_SERVER) {
1183 		(void)memset(&context->dtls_peer_addr, 0,
1184 			     sizeof(context->dtls_peer_addr));
1185 		context->dtls_peer_addrlen = 0;
1186 	}
1187 #endif
1188 
1189 	return 0;
1190 }
1191 
tls_mbedtls_handshake(struct tls_context * context,k_timeout_t timeout)1192 static int tls_mbedtls_handshake(struct tls_context *context,
1193 				 k_timeout_t timeout)
1194 {
1195 	k_timepoint_t end;
1196 	int ret;
1197 
1198 	context->handshake_in_progress = true;
1199 
1200 	end = sys_timepoint_calc(timeout);
1201 
1202 	while ((ret = mbedtls_ssl_handshake(&context->ssl)) != 0) {
1203 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
1204 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
1205 		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
1206 		    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
1207 			int timeout_ms;
1208 
1209 			/* Blocking timeout. */
1210 			timeout = sys_timepoint_timeout(end);
1211 			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
1212 				ret = -EAGAIN;
1213 				break;
1214 			}
1215 
1216 			/* Block. */
1217 			timeout_ms = timeout_to_ms(&timeout);
1218 
1219 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1220 			if (context->type == NET_SOCK_DGRAM) {
1221 				int timeout_dtls =
1222 					dtls_get_remaining_timeout(context);
1223 
1224 				if (timeout_dtls != SYS_FOREVER_MS) {
1225 					if (timeout_ms == SYS_FOREVER_MS) {
1226 						timeout_ms = timeout_dtls;
1227 					} else {
1228 						timeout_ms = MIN(timeout_dtls,
1229 								 timeout_ms);
1230 					}
1231 				}
1232 			}
1233 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1234 
1235 			ret = wait_for_reason(context->sock, timeout_ms, ret);
1236 			if (ret != 0) {
1237 				break;
1238 			}
1239 
1240 			continue;
1241 		} else if (ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED) {
1242 			ret = tls_mbedtls_reset(context);
1243 			if (ret == 0) {
1244 				if (!K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
1245 					continue;
1246 				}
1247 
1248 				ret = -EAGAIN;
1249 				break;
1250 			}
1251 		} else if (ret == MBEDTLS_ERR_SSL_TIMEOUT) {
1252 			/* MbedTLS API documentation requires session to
1253 			 * be reset in this case
1254 			 */
1255 			ret = tls_mbedtls_reset(context);
1256 			if (ret == 0) {
1257 				NET_ERR("TLS handshake timeout");
1258 				context->error = ETIMEDOUT;
1259 				ret = -ETIMEDOUT;
1260 				break;
1261 			}
1262 		} else {
1263 			/* MbedTLS API documentation requires session to
1264 			 * be reset in other error cases
1265 			 */
1266 			NET_ERR("TLS handshake error: -0x%x", -ret);
1267 			ret = tls_mbedtls_reset(context);
1268 			if (ret == 0) {
1269 				context->error = ECONNABORTED;
1270 				ret = -ECONNABORTED;
1271 				break;
1272 			}
1273 		}
1274 
1275 		/* Avoid constant loop if tls_mbedtls_reset fails */
1276 		NET_ERR("TLS reset error: -0x%x", -ret);
1277 		context->error = ECONNABORTED;
1278 		ret = -ECONNABORTED;
1279 		break;
1280 	}
1281 
1282 	if (ret == 0) {
1283 		k_sem_give(&context->tls_established);
1284 	}
1285 
1286 	context->handshake_in_progress = false;
1287 
1288 	return ret;
1289 }
1290 
tls_mbedtls_init(struct tls_context * context,bool is_server)1291 static int tls_mbedtls_init(struct tls_context *context, bool is_server)
1292 {
1293 	int role, type, ret;
1294 
1295 	role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
1296 
1297 	type = (context->type == NET_SOCK_STREAM) ?
1298 		MBEDTLS_SSL_TRANSPORT_STREAM :
1299 		MBEDTLS_SSL_TRANSPORT_DATAGRAM;
1300 
1301 	if (type == MBEDTLS_SSL_TRANSPORT_STREAM) {
1302 		mbedtls_ssl_set_bio(&context->ssl, context,
1303 				    tls_tx, tls_rx, NULL);
1304 	} else {
1305 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1306 		mbedtls_ssl_set_bio(&context->ssl, context,
1307 				    dtls_tx, dtls_rx, NULL);
1308 #else
1309 		return -ENOTSUP;
1310 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1311 	}
1312 
1313 	ret = mbedtls_ssl_config_defaults(&context->config, role, type,
1314 					  MBEDTLS_SSL_PRESET_DEFAULT);
1315 	if (ret != 0) {
1316 		/* According to mbedTLS API documentation,
1317 		 * mbedtls_ssl_config_defaults can fail due to memory
1318 		 * allocation failure
1319 		 */
1320 		return -ENOMEM;
1321 	}
1322 	tls_set_max_frag_len(&context->config, context->type);
1323 
1324 #if defined(MBEDTLS_SSL_RENEGOTIATION)
1325 	mbedtls_ssl_conf_legacy_renegotiation(&context->config,
1326 					   MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE);
1327 	mbedtls_ssl_conf_renegotiation(&context->config,
1328 				       MBEDTLS_SSL_RENEGOTIATION_ENABLED);
1329 #endif
1330 
1331 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1332 	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1333 		/* DTLS requires timer callbacks to operate */
1334 		mbedtls_ssl_set_timer_cb(&context->ssl,
1335 					 &context->dtls_timing,
1336 					 dtls_timing_set_delay,
1337 					 dtls_timing_get_delay);
1338 		mbedtls_ssl_conf_handshake_timeout(&context->config,
1339 				context->options.dtls_handshake_timeout_min,
1340 				context->options.dtls_handshake_timeout_max);
1341 
1342 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1343 		if (context->options.dtls_cid.enabled) {
1344 			ret = mbedtls_ssl_conf_cid(
1345 					&context->config,
1346 					context->options.dtls_cid.cid_len,
1347 					MBEDTLS_SSL_UNEXPECTED_CID_IGNORE);
1348 			if (ret != 0) {
1349 				return -EINVAL;
1350 			}
1351 		}
1352 #endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1353 
1354 		/* Configure cookie for DTLS server */
1355 		if (role == MBEDTLS_SSL_IS_SERVER) {
1356 			ret = mbedtls_ssl_cookie_setup(&context->cookie,
1357 						       tls_ctr_drbg_random,
1358 						       NULL);
1359 			if (ret != 0) {
1360 				return -ENOMEM;
1361 			}
1362 
1363 			mbedtls_ssl_conf_dtls_cookies(&context->config,
1364 						      mbedtls_ssl_cookie_write,
1365 						      mbedtls_ssl_cookie_check,
1366 						      &context->cookie);
1367 
1368 			mbedtls_ssl_conf_read_timeout(
1369 					&context->config,
1370 					CONFIG_NET_SOCKETS_DTLS_TIMEOUT);
1371 		}
1372 	}
1373 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1374 
1375 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1376 	/* For TLS clients, set hostname to empty string to enforce it's
1377 	 * verification - only if hostname option was not set. Otherwise
1378 	 * depend on user configuration.
1379 	 */
1380 	if (!is_server && !context->options.is_hostname_set) {
1381 		mbedtls_ssl_set_hostname(&context->ssl, "");
1382 	}
1383 #endif
1384 
1385 	/* If verification level was specified explicitly, set it. Otherwise,
1386 	 * use mbedTLS default values (required for client, none for server)
1387 	 */
1388 	if (context->options.verify_level != -1) {
1389 		mbedtls_ssl_conf_authmode(&context->config,
1390 					  context->options.verify_level);
1391 	}
1392 
1393 	mbedtls_ssl_conf_rng(&context->config,
1394 			     tls_ctr_drbg_random,
1395 			     NULL);
1396 
1397 	ret = tls_mbedtls_set_credentials(context);
1398 	if (ret != 0) {
1399 		return ret;
1400 	}
1401 
1402 	if (context->options.ciphersuites[0] != 0) {
1403 		/* Specific ciphersuites configured, so use them */
1404 		NET_DBG("Using user-specified ciphersuites");
1405 		mbedtls_ssl_conf_ciphersuites(&context->config,
1406 					      context->options.ciphersuites);
1407 	}
1408 
1409 #if defined(CONFIG_MBEDTLS_SSL_ALPN)
1410 	if (ALPN_MAX_PROTOCOLS && context->options.alpn_list[0] != NULL) {
1411 		ret = mbedtls_ssl_conf_alpn_protocols(&context->config,
1412 				context->options.alpn_list);
1413 		if (ret != 0) {
1414 			return -EINVAL;
1415 		}
1416 	}
1417 #endif /* CONFIG_MBEDTLS_SSL_ALPN */
1418 
1419 #if defined(MBEDTLS_SSL_CACHE_C)
1420 	if (is_server && context->options.cache_enabled) {
1421 		mbedtls_ssl_conf_session_cache(&context->config, &server_cache,
1422 					       mbedtls_ssl_cache_get,
1423 					       mbedtls_ssl_cache_set);
1424 	}
1425 #endif
1426 
1427 #if defined(MBEDTLS_SSL_EARLY_DATA)
1428 	mbedtls_ssl_conf_early_data(&context->config, MBEDTLS_SSL_EARLY_DATA_ENABLED);
1429 #endif
1430 
1431 #if defined(CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK)
1432 	if (context->options.cert_verify.cb != NULL) {
1433 		mbedtls_ssl_conf_verify(&context->config,
1434 					context->options.cert_verify.cb,
1435 					context->options.cert_verify.ctx);
1436 	}
1437 #endif /* CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK */
1438 
1439 	ret = mbedtls_ssl_setup(&context->ssl,
1440 				&context->config);
1441 	if (ret != 0) {
1442 		/* According to mbedTLS API documentation,
1443 		 * mbedtls_ssl_setup can fail due to memory allocation failure
1444 		 */
1445 		return -ENOMEM;
1446 	}
1447 
1448 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) && defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1449 	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1450 		if (context->options.dtls_cid.enabled) {
1451 			ret = mbedtls_ssl_set_cid(&context->ssl, MBEDTLS_SSL_CID_ENABLED,
1452 						  context->options.dtls_cid.cid,
1453 						  context->options.dtls_cid.cid_len);
1454 			if (ret != 0) {
1455 				return -EINVAL;
1456 			}
1457 		}
1458 	}
1459 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS && CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1460 
1461 	context->is_initialized = true;
1462 
1463 	return 0;
1464 }
1465 
tls_check_cert(struct tls_credential * cert)1466 static int tls_check_cert(struct tls_credential *cert)
1467 {
1468 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1469 	mbedtls_x509_crt cert_ctx;
1470 	int err;
1471 
1472 	mbedtls_x509_crt_init(&cert_ctx);
1473 
1474 	if (crt_is_pem(cert->buf, cert->len)) {
1475 		err = mbedtls_x509_crt_parse(&cert_ctx, cert->buf, cert->len);
1476 	} else {
1477 		/* For DER case, use the no copy version of the parsing function
1478 		 * to avoid unnecessary heap allocations.
1479 		 */
1480 		err = mbedtls_x509_crt_parse_der_nocopy(&cert_ctx, cert->buf,
1481 							cert->len);
1482 	}
1483 
1484 	if (err != 0) {
1485 		NET_ERR("Failed to parse %s on tag %d, err: -0x%x",
1486 			"certificate", cert->tag, -err);
1487 		return -EINVAL;
1488 	}
1489 
1490 	mbedtls_x509_crt_free(&cert_ctx);
1491 
1492 	return err;
1493 #else
1494 	NET_ERR("TLS with certificates disabled. "
1495 		"Reconfigure mbed TLS to support certificate based key exchange.");
1496 
1497 	return -ENOTSUP;
1498 #endif /* MBEDTLS_X509_CRT_PARSE_C */
1499 }
1500 
tls_check_priv_key(struct tls_credential * priv_key)1501 static int tls_check_priv_key(struct tls_credential *priv_key)
1502 {
1503 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1504 	mbedtls_pk_context key_ctx;
1505 	int err;
1506 
1507 	mbedtls_pk_init(&key_ctx);
1508 
1509 	err = mbedtls_pk_parse_key(&key_ctx, priv_key->buf,
1510 				   priv_key->len, NULL, 0,
1511 				   tls_ctr_drbg_random, NULL);
1512 	if (err != 0) {
1513 		NET_ERR("Failed to parse %s on tag %d, err: -0x%x",
1514 			"private key", priv_key->tag, -err);
1515 		err = -EINVAL;
1516 	}
1517 
1518 	mbedtls_pk_free(&key_ctx);
1519 
1520 	return err;
1521 #else
1522 	NET_ERR("TLS with certificates disabled. "
1523 		"Reconfigure mbed TLS to support certificate based key exchange.");
1524 
1525 	return -ENOTSUP;
1526 #endif /* MBEDTLS_X509_CRT_PARSE_C */
1527 }
1528 
tls_check_psk(struct tls_credential * psk)1529 static int tls_check_psk(struct tls_credential *psk)
1530 {
1531 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
1532 	struct tls_credential *psk_id;
1533 
1534 	psk_id = credential_get(psk->tag, TLS_CREDENTIAL_PSK_ID);
1535 	if (psk_id == NULL) {
1536 		NET_ERR("No matching PSK ID found for tag %d", psk->tag);
1537 		return -EINVAL;
1538 	}
1539 
1540 	if (psk->len == 0 || psk_id->len == 0) {
1541 		NET_ERR("PSK or PSK ID empty on tag %d", psk->tag);
1542 		return -EINVAL;
1543 	}
1544 
1545 	return 0;
1546 #else
1547 	NET_ERR("TLS with PSK disabled. "
1548 		"Reconfigure mbed TLS to support PSK based key exchange.");
1549 
1550 	return -ENOTSUP;
1551 #endif
1552 }
1553 
tls_check_credentials(const sec_tag_t * sec_tags,int sec_tag_count)1554 static int tls_check_credentials(const sec_tag_t *sec_tags, int sec_tag_count)
1555 {
1556 	int err = 0;
1557 
1558 	credentials_lock();
1559 
1560 	for (int i = 0; i < sec_tag_count; i++) {
1561 		sec_tag_t tag = sec_tags[i];
1562 		struct tls_credential *cred = NULL;
1563 		bool tag_found = false;
1564 
1565 		while ((cred = credential_next_get(tag, cred)) != NULL) {
1566 			tag_found = true;
1567 
1568 			switch (cred->type) {
1569 			case TLS_CREDENTIAL_CA_CERTIFICATE:
1570 				__fallthrough;
1571 			case TLS_CREDENTIAL_PUBLIC_CERTIFICATE:
1572 				err = tls_check_cert(cred);
1573 				if (err != 0) {
1574 					goto exit;
1575 				}
1576 
1577 				break;
1578 			case TLS_CREDENTIAL_PRIVATE_KEY:
1579 				err = tls_check_priv_key(cred);
1580 				if (err != 0) {
1581 					goto exit;
1582 				}
1583 
1584 				break;
1585 			case TLS_CREDENTIAL_PSK:
1586 				err = tls_check_psk(cred);
1587 				if (err != 0) {
1588 					goto exit;
1589 				}
1590 
1591 				break;
1592 			case TLS_CREDENTIAL_PSK_ID:
1593 				/* Ignore PSK ID - it will be verified together
1594 				 * with PSK.
1595 				 */
1596 				break;
1597 			default:
1598 				return -EINVAL;
1599 			}
1600 		}
1601 
1602 		/* If no credential is found with such a tag, report an error. */
1603 		if (!tag_found) {
1604 			NET_ERR("No TLS credential found with tag %d", tag);
1605 			err = -ENOENT;
1606 			goto exit;
1607 		}
1608 	}
1609 
1610 exit:
1611 	credentials_unlock();
1612 
1613 	return err;
1614 }
1615 
tls_opt_sec_tag_list_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1616 static int tls_opt_sec_tag_list_set(struct tls_context *context,
1617 				    const void *optval, net_socklen_t optlen)
1618 {
1619 	int sec_tag_cnt;
1620 	int ret;
1621 
1622 	if (!optval) {
1623 		return -EINVAL;
1624 	}
1625 
1626 	if (optlen % sizeof(sec_tag_t) != 0) {
1627 		return -EINVAL;
1628 	}
1629 
1630 	sec_tag_cnt = optlen / sizeof(sec_tag_t);
1631 	if (sec_tag_cnt >
1632 		ARRAY_SIZE(context->options.sec_tag_list.sec_tags)) {
1633 		return -EINVAL;
1634 	}
1635 
1636 	ret = tls_check_credentials((const sec_tag_t *)optval, sec_tag_cnt);
1637 	if (ret < 0) {
1638 		return ret;
1639 	}
1640 
1641 	memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);
1642 	context->options.sec_tag_list.sec_tag_count = sec_tag_cnt;
1643 
1644 	return 0;
1645 }
1646 
sock_opt_protocol_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1647 static int sock_opt_protocol_get(struct tls_context *context,
1648 				 void *optval, net_socklen_t *optlen)
1649 {
1650 	int protocol = (int)context->tls_version;
1651 
1652 	if (*optlen != sizeof(protocol)) {
1653 		return -EINVAL;
1654 	}
1655 
1656 	*(int *)optval = protocol;
1657 
1658 	return 0;
1659 }
1660 
tls_opt_sec_tag_list_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1661 static int tls_opt_sec_tag_list_get(struct tls_context *context,
1662 				    void *optval, net_socklen_t *optlen)
1663 {
1664 	int len;
1665 
1666 	if (*optlen % sizeof(sec_tag_t) != 0 || *optlen == 0) {
1667 		return -EINVAL;
1668 	}
1669 
1670 	len = MIN(context->options.sec_tag_list.sec_tag_count *
1671 		  sizeof(sec_tag_t), *optlen);
1672 
1673 	memcpy(optval, context->options.sec_tag_list.sec_tags, len);
1674 	*optlen = len;
1675 
1676 	return 0;
1677 }
1678 
tls_opt_hostname_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1679 static int tls_opt_hostname_set(struct tls_context *context,
1680 				const void *optval, net_socklen_t optlen)
1681 {
1682 	ARG_UNUSED(optlen);
1683 
1684 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1685 	if (mbedtls_ssl_set_hostname(&context->ssl, optval) != 0) {
1686 		return -EINVAL;
1687 	}
1688 #else
1689 	return -ENOPROTOOPT;
1690 #endif
1691 
1692 	context->options.is_hostname_set = true;
1693 
1694 	return 0;
1695 }
1696 
tls_opt_ciphersuite_list_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1697 static int tls_opt_ciphersuite_list_set(struct tls_context *context,
1698 					const void *optval, net_socklen_t optlen)
1699 {
1700 	int cipher_cnt;
1701 
1702 	if (!optval) {
1703 		return -EINVAL;
1704 	}
1705 
1706 	if (optlen % sizeof(int) != 0) {
1707 		return -EINVAL;
1708 	}
1709 
1710 	cipher_cnt = optlen / sizeof(int);
1711 
1712 	/* + 1 for 0-termination. */
1713 	if (cipher_cnt + 1 > ARRAY_SIZE(context->options.ciphersuites)) {
1714 		return -EINVAL;
1715 	}
1716 
1717 	memcpy(context->options.ciphersuites, optval, optlen);
1718 	context->options.ciphersuites[cipher_cnt] = 0;
1719 
1720 	mbedtls_ssl_conf_ciphersuites(&context->config,
1721 				      context->options.ciphersuites);
1722 	return 0;
1723 }
1724 
tls_opt_ciphersuite_list_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1725 static int tls_opt_ciphersuite_list_get(struct tls_context *context,
1726 					void *optval, net_socklen_t *optlen)
1727 {
1728 	const int *selected_ciphers;
1729 	int cipher_cnt, i = 0;
1730 	int *ciphers = optval;
1731 
1732 	if (*optlen % sizeof(int) != 0 || *optlen == 0) {
1733 		return -EINVAL;
1734 	}
1735 
1736 	if (context->options.ciphersuites[0] == 0) {
1737 		/* No specific ciphersuites configured, return all available. */
1738 		selected_ciphers = mbedtls_ssl_list_ciphersuites();
1739 	} else {
1740 		selected_ciphers = context->options.ciphersuites;
1741 	}
1742 
1743 	cipher_cnt = *optlen / sizeof(int);
1744 	while (selected_ciphers[i] != 0) {
1745 		ciphers[i] = selected_ciphers[i];
1746 
1747 		if (++i == cipher_cnt) {
1748 			break;
1749 		}
1750 	}
1751 
1752 	*optlen = i * sizeof(int);
1753 
1754 	return 0;
1755 }
1756 
tls_opt_ciphersuite_used_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1757 static int tls_opt_ciphersuite_used_get(struct tls_context *context,
1758 					void *optval, net_socklen_t *optlen)
1759 {
1760 	const char *ciph;
1761 
1762 	if (*optlen != sizeof(int)) {
1763 		return -EINVAL;
1764 	}
1765 
1766 	ciph = mbedtls_ssl_get_ciphersuite(&context->ssl);
1767 	if (ciph == NULL) {
1768 		return -ENOTCONN;
1769 	}
1770 
1771 	*(int *)optval = mbedtls_ssl_get_ciphersuite_id(ciph);
1772 
1773 	return 0;
1774 }
1775 
tls_opt_alpn_list_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1776 static int tls_opt_alpn_list_set(struct tls_context *context,
1777 				 const void *optval, net_socklen_t optlen)
1778 {
1779 	int alpn_cnt;
1780 
1781 	if (!ALPN_MAX_PROTOCOLS) {
1782 		return -EINVAL;
1783 	}
1784 
1785 	if (!optval) {
1786 		return -EINVAL;
1787 	}
1788 
1789 	if (optlen % sizeof(const char *) != 0) {
1790 		return -EINVAL;
1791 	}
1792 
1793 	alpn_cnt = optlen / sizeof(const char *);
1794 	/* + 1 for NULL-termination. */
1795 	if (alpn_cnt + 1 > ARRAY_SIZE(context->options.alpn_list)) {
1796 		return -EINVAL;
1797 	}
1798 
1799 	memcpy(context->options.alpn_list, optval, optlen);
1800 	context->options.alpn_list[alpn_cnt] = NULL;
1801 
1802 	return 0;
1803 }
1804 
1805 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
tls_opt_dtls_handshake_timeout_get(struct tls_context * context,void * optval,net_socklen_t * optlen,bool is_max)1806 static int tls_opt_dtls_handshake_timeout_get(struct tls_context *context,
1807 					      void *optval, net_socklen_t *optlen,
1808 					      bool is_max)
1809 {
1810 	uint32_t *val = (uint32_t *)optval;
1811 
1812 	if (sizeof(uint32_t) != *optlen) {
1813 		return -EINVAL;
1814 	}
1815 
1816 	if (is_max) {
1817 		*val = context->options.dtls_handshake_timeout_max;
1818 	} else {
1819 		*val = context->options.dtls_handshake_timeout_min;
1820 	}
1821 
1822 	return 0;
1823 }
1824 
tls_opt_dtls_handshake_timeout_set(struct tls_context * context,const void * optval,net_socklen_t optlen,bool is_max)1825 static int tls_opt_dtls_handshake_timeout_set(struct tls_context *context,
1826 					      const void *optval,
1827 					      net_socklen_t optlen, bool is_max)
1828 {
1829 	uint32_t *val = (uint32_t *)optval;
1830 
1831 	if (!optval) {
1832 		return -EINVAL;
1833 	}
1834 
1835 	if (sizeof(uint32_t) != optlen) {
1836 		return -EINVAL;
1837 	}
1838 
1839 	/* If mbedTLS context not inited, it will
1840 	 * use these values upon init.
1841 	 */
1842 	if (is_max) {
1843 		context->options.dtls_handshake_timeout_max = *val;
1844 	} else {
1845 		context->options.dtls_handshake_timeout_min = *val;
1846 	}
1847 
1848 	/* If mbedTLS context already inited, we need to
1849 	 * update mbedTLS config for it to take effect
1850 	 */
1851 	mbedtls_ssl_conf_handshake_timeout(&context->config,
1852 			context->options.dtls_handshake_timeout_min,
1853 			context->options.dtls_handshake_timeout_max);
1854 
1855 	return 0;
1856 }
1857 
tls_opt_dtls_connection_id_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1858 static int tls_opt_dtls_connection_id_set(struct tls_context *context,
1859 					  const void *optval, net_socklen_t optlen)
1860 {
1861 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1862 	int value;
1863 
1864 	if (optlen > 0 && optval == NULL) {
1865 		return -EINVAL;
1866 	}
1867 
1868 	if (optlen != sizeof(int)) {
1869 		return -EINVAL;
1870 	}
1871 
1872 	value = *((int *)optval);
1873 
1874 	switch (value) {
1875 	case ZSOCK_TLS_DTLS_CID_DISABLED:
1876 		context->options.dtls_cid.enabled = false;
1877 		context->options.dtls_cid.cid_len = 0;
1878 		break;
1879 	case ZSOCK_TLS_DTLS_CID_SUPPORTED:
1880 		context->options.dtls_cid.enabled = true;
1881 		context->options.dtls_cid.cid_len = 0;
1882 		break;
1883 	case ZSOCK_TLS_DTLS_CID_ENABLED:
1884 		context->options.dtls_cid.enabled = true;
1885 		if (context->options.dtls_cid.cid_len == 0) {
1886 			/* generate random self cid */
1887 #if defined(CONFIG_CSPRNG_ENABLED)
1888 			sys_csrand_get(context->options.dtls_cid.cid,
1889 				       MBEDTLS_SSL_CID_OUT_LEN_MAX);
1890 #else
1891 			sys_rand_get(context->options.dtls_cid.cid,
1892 				     MBEDTLS_SSL_CID_OUT_LEN_MAX);
1893 #endif
1894 			context->options.dtls_cid.cid_len = MBEDTLS_SSL_CID_OUT_LEN_MAX;
1895 		}
1896 		break;
1897 	default:
1898 		return -EINVAL;
1899 	}
1900 
1901 	return 0;
1902 #else
1903 	return -ENOPROTOOPT;
1904 #endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1905 }
1906 
tls_opt_dtls_connection_id_value_set(struct tls_context * context,const void * optval,net_socklen_t optlen)1907 static int tls_opt_dtls_connection_id_value_set(struct tls_context *context,
1908 						const void *optval,
1909 						net_socklen_t optlen)
1910 {
1911 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1912 	if (optlen > 0 && optval == NULL) {
1913 		return -EINVAL;
1914 	}
1915 
1916 	if (optlen > MBEDTLS_SSL_CID_IN_LEN_MAX) {
1917 		return -EINVAL;
1918 	}
1919 
1920 	context->options.dtls_cid.cid_len = optlen;
1921 	memcpy(context->options.dtls_cid.cid, optval, optlen);
1922 
1923 	return 0;
1924 #else
1925 	return -ENOPROTOOPT;
1926 #endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1927 }
1928 
tls_opt_dtls_connection_id_value_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1929 static int tls_opt_dtls_connection_id_value_get(struct tls_context *context,
1930 						void *optval, net_socklen_t *optlen)
1931 {
1932 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1933 
1934 	if (*optlen < context->options.dtls_cid.cid_len) {
1935 		return -EINVAL;
1936 	}
1937 
1938 	*optlen = context->options.dtls_cid.cid_len;
1939 	memcpy(optval, context->options.dtls_cid.cid, *optlen);
1940 
1941 	return 0;
1942 #else
1943 	return -ENOPROTOOPT;
1944 #endif
1945 }
1946 
tls_opt_dtls_peer_connection_id_value_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1947 static int tls_opt_dtls_peer_connection_id_value_get(struct tls_context *context,
1948 						     void *optval,
1949 						     net_socklen_t *optlen)
1950 {
1951 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1952 	int enabled = false;
1953 	int ret;
1954 	size_t optlen_local;
1955 
1956 	if (!context->is_initialized) {
1957 		return -ENOTCONN;
1958 	}
1959 
1960 	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled, optval, &optlen_local);
1961 	if (enabled) {
1962 		*optlen = optlen_local;
1963 	} else {
1964 		*optlen = 0;
1965 	}
1966 	return ret;
1967 #else
1968 	return -ENOPROTOOPT;
1969 #endif
1970 }
1971 
tls_opt_dtls_connection_id_status_get(struct tls_context * context,void * optval,net_socklen_t * optlen)1972 static int tls_opt_dtls_connection_id_status_get(struct tls_context *context,
1973 					  void *optval, net_socklen_t *optlen)
1974 {
1975 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1976 	struct tls_dtls_cid cid;
1977 	int ret;
1978 	int val;
1979 	int enabled;
1980 	bool have_self_cid;
1981 	bool have_peer_cid;
1982 
1983 	if (sizeof(int) != *optlen) {
1984 		return -EINVAL;
1985 	}
1986 
1987 	if (!context->is_initialized) {
1988 		return -ENOTCONN;
1989 	}
1990 
1991 	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
1992 				       cid.cid,
1993 				       &cid.cid_len);
1994 	if (ret) {
1995 		/* Handshake is not complete */
1996 		return -EAGAIN;
1997 	}
1998 
1999 	cid.enabled = (enabled == MBEDTLS_SSL_CID_ENABLED);
2000 	have_self_cid = (context->options.dtls_cid.cid_len != 0);
2001 	have_peer_cid = (cid.cid_len != 0);
2002 
2003 	if (!context->options.dtls_cid.enabled) {
2004 		val = ZSOCK_TLS_DTLS_CID_STATUS_DISABLED;
2005 	} else if (have_self_cid && have_peer_cid) {
2006 		val = ZSOCK_TLS_DTLS_CID_STATUS_BIDIRECTIONAL;
2007 	} else if (have_self_cid) {
2008 		val = ZSOCK_TLS_DTLS_CID_STATUS_DOWNLINK;
2009 	} else if (have_peer_cid) {
2010 		val = ZSOCK_TLS_DTLS_CID_STATUS_UPLINK;
2011 	} else {
2012 		val = ZSOCK_TLS_DTLS_CID_STATUS_DISABLED;
2013 	}
2014 
2015 	*((int *)optval) = val;
2016 	return 0;
2017 #else
2018 	return -ENOPROTOOPT;
2019 #endif
2020 }
2021 
tls_opt_dtls_handshake_on_connect_set(struct tls_context * context,const void * optval,net_socklen_t optlen)2022 static int tls_opt_dtls_handshake_on_connect_set(struct tls_context *context,
2023 						 const void *optval,
2024 						 net_socklen_t optlen)
2025 {
2026 	int *val = (int *)optval;
2027 
2028 	if (!optval) {
2029 		return -EINVAL;
2030 	}
2031 
2032 	if (sizeof(int) != optlen) {
2033 		return -EINVAL;
2034 	}
2035 
2036 	context->options.dtls_handshake_on_connect = (bool)*val;
2037 
2038 	return 0;
2039 }
2040 
tls_opt_dtls_handshake_on_connect_get(struct tls_context * context,void * optval,net_socklen_t * optlen)2041 static int tls_opt_dtls_handshake_on_connect_get(struct tls_context *context,
2042 						 void *optval,
2043 						 net_socklen_t *optlen)
2044 {
2045 	if (*optlen != sizeof(int)) {
2046 		return -EINVAL;
2047 	}
2048 
2049 	*(int *)optval = context->options.dtls_handshake_on_connect;
2050 
2051 	return 0;
2052 }
2053 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2054 
tls_opt_alpn_list_get(struct tls_context * context,void * optval,net_socklen_t * optlen)2055 static int tls_opt_alpn_list_get(struct tls_context *context,
2056 				 void *optval, net_socklen_t *optlen)
2057 {
2058 	const char **alpn_list = context->options.alpn_list;
2059 	int alpn_cnt, i = 0;
2060 	const char **ret_list = optval;
2061 
2062 	if (!ALPN_MAX_PROTOCOLS) {
2063 		return -EINVAL;
2064 	}
2065 
2066 	if (*optlen % sizeof(const char *) != 0 || *optlen == 0) {
2067 		return -EINVAL;
2068 	}
2069 
2070 	alpn_cnt = *optlen / sizeof(const char *);
2071 	while (alpn_list[i] != NULL) {
2072 		ret_list[i] = alpn_list[i];
2073 
2074 		if (++i == alpn_cnt) {
2075 			break;
2076 		}
2077 	}
2078 
2079 	*optlen = i * sizeof(const char *);
2080 
2081 	return 0;
2082 }
2083 
tls_opt_session_cache_set(struct tls_context * context,const void * optval,net_socklen_t optlen)2084 static int tls_opt_session_cache_set(struct tls_context *context,
2085 				     const void *optval, net_socklen_t optlen)
2086 {
2087 	int *val = (int *)optval;
2088 
2089 	if (!optval) {
2090 		return -EINVAL;
2091 	}
2092 
2093 	if (sizeof(int) != optlen) {
2094 		return -EINVAL;
2095 	}
2096 
2097 	context->options.cache_enabled = (*val == ZSOCK_TLS_SESSION_CACHE_ENABLED);
2098 
2099 	return 0;
2100 }
2101 
tls_opt_session_cache_get(struct tls_context * context,void * optval,net_socklen_t * optlen)2102 static int tls_opt_session_cache_get(struct tls_context *context,
2103 				     void *optval, net_socklen_t *optlen)
2104 {
2105 	int cache_enabled = context->options.cache_enabled ?
2106 			    ZSOCK_TLS_SESSION_CACHE_ENABLED :
2107 			    ZSOCK_TLS_SESSION_CACHE_DISABLED;
2108 
2109 	if (*optlen != sizeof(cache_enabled)) {
2110 		return -EINVAL;
2111 	}
2112 
2113 	*(int *)optval = cache_enabled;
2114 
2115 	return 0;
2116 }
2117 
tls_opt_cert_verify_result_get(struct tls_context * context,void * optval,net_socklen_t * optlen)2118 static int tls_opt_cert_verify_result_get(struct tls_context *context,
2119 					  void *optval, net_socklen_t *optlen)
2120 {
2121 	if (*optlen != sizeof(uint32_t)) {
2122 		return -EINVAL;
2123 	}
2124 
2125 	*(uint32_t *)optval = mbedtls_ssl_get_verify_result(&context->ssl);
2126 
2127 	return 0;
2128 }
2129 
/* Purge the TLS session cache via tls_session_purge().
 *
 * The option carries no payload, so the context, value and length are all
 * ignored. Always succeeds.
 */
static int tls_opt_session_cache_purge_set(struct tls_context *context,
					   const void *optval, net_socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(context);

	tls_session_purge();
	return 0;
}
2141 
/* Set the peer certificate verification level for the handshake.
 *
 * optval must point to an int equal to MBEDTLS_SSL_VERIFY_NONE,
 * MBEDTLS_SSL_VERIFY_OPTIONAL or MBEDTLS_SSL_VERIFY_REQUIRED; anything else
 * (or a NULL / wrongly sized value) yields -EINVAL.
 */
static int tls_opt_peer_verify_set(struct tls_context *context,
				   const void *optval, net_socklen_t optlen)
{
	int level;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	level = *(const int *)optval;

	switch (level) {
	case MBEDTLS_SSL_VERIFY_NONE:
	case MBEDTLS_SSL_VERIFY_OPTIONAL:
	case MBEDTLS_SSL_VERIFY_REQUIRED:
		context->options.verify_level = level;
		return 0;
	default:
		return -EINVAL;
	}
}
2167 
/* Select the certificate no-copy mode.
 *
 * optval must point to an int equal to ZSOCK_TLS_CERT_NOCOPY_NONE or
 * ZSOCK_TLS_CERT_NOCOPY_OPTIONAL; otherwise -EINVAL is returned.
 */
static int tls_opt_cert_nocopy_set(struct tls_context *context,
				   const void *optval, net_socklen_t optlen)
{
	const int *mode = optval;
	bool is_valid;

	if (mode == NULL) {
		return -EINVAL;
	}

	if (optlen != sizeof(*mode)) {
		return -EINVAL;
	}

	is_valid = (*mode == ZSOCK_TLS_CERT_NOCOPY_NONE) ||
		   (*mode == ZSOCK_TLS_CERT_NOCOPY_OPTIONAL);
	if (!is_valid) {
		return -EINVAL;
	}

	context->options.cert_nocopy = *mode;

	return 0;
}
2192 
/* Choose the (D)TLS role of the socket.
 *
 * optval must point to an int equal to MBEDTLS_SSL_IS_CLIENT or
 * MBEDTLS_SSL_IS_SERVER; otherwise -EINVAL is returned.
 */
static int tls_opt_dtls_role_set(struct tls_context *context,
				 const void *optval, net_socklen_t optlen)
{
	const int *requested = optval;

	if (requested == NULL || optlen != sizeof(*requested)) {
		return -EINVAL;
	}

	switch (*requested) {
	case MBEDTLS_SSL_IS_CLIENT:
	case MBEDTLS_SSL_IS_SERVER:
		break;
	default:
		return -EINVAL;
	}

	context->options.role = *requested;

	return 0;
}
2216 
2217 #if defined(CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK)
tls_opt_cert_verify_callback_set(struct tls_context * context,const void * optval,net_socklen_t optlen)2218 static int tls_opt_cert_verify_callback_set(struct tls_context *context,
2219 					    const void *optval,
2220 					    net_socklen_t optlen)
2221 {
2222 	struct tls_cert_verify_cb *cert_verify;
2223 
2224 	if (!optval) {
2225 		return -EINVAL;
2226 	}
2227 
2228 	if (optlen != sizeof(struct tls_cert_verify_cb)) {
2229 		return -EINVAL;
2230 	}
2231 
2232 	cert_verify = (struct tls_cert_verify_cb *)optval;
2233 	if (cert_verify->cb == NULL) {
2234 		return -EINVAL;
2235 	}
2236 
2237 	context->options.cert_verify = *cert_verify;
2238 
2239 	return 0;
2240 }
2241 #else /* CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK */
tls_opt_cert_verify_callback_set(struct tls_context * context,const void * optval,net_socklen_t optlen)2242 static int tls_opt_cert_verify_callback_set(struct tls_context *context,
2243 					    const void *optval,
2244 					    net_socklen_t optlen)
2245 {
2246 	NET_ERR("TLS_CERT_VERIFY_CALLBACK option requires "
2247 		"CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK enabled");
2248 
2249 	return -ENOPROTOOPT;
2250 }
2251 #endif /* CONFIG_NET_SOCKETS_TLS_CERT_VERIFY_CALLBACK */
2252 
protocol_check(int family,int type,int * proto)2253 static int protocol_check(int family, int type, int *proto)
2254 {
2255 	if (family != NET_AF_INET && family != NET_AF_INET6) {
2256 		return -EAFNOSUPPORT;
2257 	}
2258 
2259 	if (*proto >= NET_IPPROTO_TLS_1_0 && *proto <= NET_IPPROTO_TLS_1_3) {
2260 		if (type != NET_SOCK_STREAM) {
2261 			return -EPROTOTYPE;
2262 		}
2263 
2264 		*proto = NET_IPPROTO_TCP;
2265 	} else if (*proto >= NET_IPPROTO_DTLS_1_0 && *proto <= NET_IPPROTO_DTLS_1_2) {
2266 		if (!IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS)) {
2267 			return -EPROTONOSUPPORT;
2268 		}
2269 
2270 		if (type != NET_SOCK_DGRAM) {
2271 			return -EPROTOTYPE;
2272 		}
2273 
2274 		*proto = NET_IPPROTO_UDP;
2275 	} else {
2276 		return -EPROTONOSUPPORT;
2277 	}
2278 
2279 	return 0;
2280 }
2281 
ztls_socket(int family,int type,int proto)2282 static int ztls_socket(int family, int type, int proto)
2283 {
2284 	enum net_ip_protocol_secure tls_proto = proto;
2285 	int fd = zvfs_reserve_fd();
2286 	int sock = -1;
2287 	int ret;
2288 	struct tls_context *ctx;
2289 
2290 	if (fd < 0) {
2291 		return -1;
2292 	}
2293 
2294 	ret = protocol_check(family, type, &proto);
2295 	if (ret < 0) {
2296 		errno = -ret;
2297 		goto free_fd;
2298 	}
2299 
2300 	ctx = tls_alloc();
2301 	if (ctx == NULL) {
2302 		errno = ENOMEM;
2303 		goto free_fd;
2304 	}
2305 
2306 	sock = zsock_socket(family, type, proto);
2307 	if (sock < 0) {
2308 		goto release_tls;
2309 	}
2310 
2311 	ctx->tls_version = tls_proto;
2312 	ctx->type = (proto == NET_IPPROTO_TCP) ? NET_SOCK_STREAM : NET_SOCK_DGRAM;
2313 	ctx->sock = sock;
2314 
2315 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
2316 			    ZVFS_MODE_IFSOCK);
2317 
2318 	return fd;
2319 
2320 release_tls:
2321 	(void)tls_release(ctx);
2322 
2323 free_fd:
2324 	zvfs_free_fd(fd);
2325 
2326 	return -1;
2327 }
2328 
ztls_close_ctx(struct tls_context * ctx,int sock)2329 int ztls_close_ctx(struct tls_context *ctx, int sock)
2330 {
2331 	int ret, err = 0;
2332 
2333 	/* Try to send close notification. */
2334 	ctx->flags = 0;
2335 
2336 	(void)mbedtls_ssl_close_notify(&ctx->ssl);
2337 
2338 	err = tls_release(ctx);
2339 	ret = zsock_close(ctx->sock);
2340 
2341 	if (ret == 0) {
2342 		(void)sock_obj_core_dealloc(sock);
2343 	}
2344 
2345 	/* In case close fails, we propagate errno value set by close.
2346 	 * In case close succeeds, but tls_release fails, set errno
2347 	 * according to tls_release return value.
2348 	 */
2349 	if (ret == 0 && err < 0) {
2350 		errno = -err;
2351 		ret = -1;
2352 	}
2353 
2354 	return ret;
2355 }
2356 
/* Connect a TLS/DTLS socket and, where applicable, perform the handshake.
 *
 * The underlying connect() always runs in blocking mode. If the underlying
 * socket has O_NONBLOCK set, the flag is cleared for the call and restored
 * afterwards, whether or not connect() succeeded.
 *
 * For DTLS, the peer address is remembered. The TLS handshake is then run
 * for stream sockets, and for DTLS sockets with dtls_handshake_on_connect
 * set. The handshake is bounded by CONFIG_NET_SOCKETS_TLS_CONNECT_TIMEOUT.
 *
 * Returns 0 on success, or -1 with errno set on failure.
 */
int ztls_connect_ctx(struct tls_context *ctx, const struct net_sockaddr *addr,
		     net_socklen_t addrlen)
{
	int ret;
	int sock_flags;
	bool is_non_block;

	sock_flags = zsock_fcntl(ctx->sock, ZVFS_F_GETFL, 0);
	if (sock_flags < 0) {
		/* Report through errno like every other failure path rather
		 * than returning a negative errno value to the caller.
		 */
		ret = -EIO;
		goto error;
	}

	is_non_block = sock_flags & ZVFS_O_NONBLOCK;
	if (is_non_block) {
		(void)zsock_fcntl(ctx->sock, ZVFS_F_SETFL,
				  sock_flags & ~ZVFS_O_NONBLOCK);
	}

	ret = zsock_connect(ctx->sock, addr, addrlen);

	/* Restore the caller's flags on failure too. Otherwise a failed
	 * connect() would leave a non-blocking socket in blocking mode.
	 * Keep the errno reported by connect() across the fcntl() call.
	 */
	if (is_non_block) {
		int connect_errno = errno;

		(void)zsock_fcntl(ctx->sock, ZVFS_F_SETFL, sock_flags);
		errno = connect_errno;
	}

	if (ret < 0) {
		/* errno already set by zsock_connect(). */
		return ret;
	}

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (ctx->type == NET_SOCK_DGRAM) {
		dtls_peer_address_set(ctx, addr, addrlen);
	}
#endif

	if (ctx->type == NET_SOCK_STREAM
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	    || (ctx->type == NET_SOCK_DGRAM && ctx->options.dtls_handshake_on_connect)
#endif
	    ) {
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}

		/* Do not use any socket flags during the handshake. */
		ctx->flags = 0;

		tls_session_restore(ctx, addr, addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket
		 * even for non-blocking socket.
		 */
		ret = tls_mbedtls_handshake(
			ctx, K_MSEC(CONFIG_NET_SOCKETS_TLS_CONNECT_TIMEOUT));
		if (ret < 0) {
			/* A blocking caller gets ETIMEDOUT instead of EAGAIN
			 * when the handshake runs out of time.
			 */
			if ((ret == -EAGAIN) && !is_non_block) {
				ret = -ETIMEDOUT;
			}

			goto error;
		}

		tls_session_store(ctx, addr, addrlen);
	}

	return 0;

error:
	errno = -ret;
	return -1;
}
2427 
/* Accept a connection on a listening TLS socket and run the server-side
 * handshake on it.
 *
 * A file descriptor is reserved up front. The parent's lock is released
 * around the potentially blocking zsock_accept() and re-taken afterwards.
 * The child context is cloned from the parent and bound to the accepted
 * transport socket, and the handshake is bounded by
 * CONFIG_NET_SOCKETS_TLS_CONNECT_TIMEOUT. On any failure, the child
 * context, the accepted socket and the fd are all released again.
 *
 * Returns the new fd, or -1 with errno set.
 */
int ztls_accept_ctx(struct tls_context *parent, struct net_sockaddr *addr,
		    net_socklen_t *addrlen)
{
	struct tls_context *child = NULL;
	int ret, err, fd, sock;

	fd = zvfs_reserve_fd();
	if (fd < 0) {
		return -1;
	}

	k_mutex_unlock(parent->lock);
	sock = zsock_accept(parent->sock, addr, addrlen);
	k_mutex_lock(parent->lock, K_FOREVER);
	if (sock < 0) {
		ret = -errno;
		goto error;
	}

	child = tls_clone(parent);
	if (child == NULL) {
		ret = -ENOMEM;
		goto error;
	}

	/* Bind the transport socket before publishing the fd, so the fd
	 * never refers to a context whose sock field has not been set yet.
	 */
	child->sock = sock;

	zvfs_finalize_typed_fd(fd, child, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
			       ZVFS_MODE_IFSOCK);

	ret = tls_mbedtls_init(child, true);
	if (ret < 0) {
		goto error;
	}

	/* Do not use any socket flags during the handshake. */
	child->flags = 0;

	/* TODO For simplicity, TLS handshake blocks the socket even for
	 * non-blocking socket.
	 */
	ret = tls_mbedtls_handshake(
		child, K_MSEC(CONFIG_NET_SOCKETS_TLS_CONNECT_TIMEOUT));
	if (ret < 0) {
		/* A blocking listener reports a handshake timeout as ETIMEDOUT. */
		if ((ret == -EAGAIN) && is_blocking(parent->sock, 0)) {
			ret = -ETIMEDOUT;
		}

		goto error;
	}

	return fd;

error:
	if (child != NULL) {
		err = tls_release(child);
		__ASSERT(err == 0, "TLS context release failed");
	}

	if (sock >= 0) {
		err = zsock_close(sock);
		__ASSERT(err == 0, "Child socket close failed");
	}

	zvfs_free_fd(fd);

	errno = -ret;
	return -1;
}
2498 
send_tls(struct tls_context * ctx,const void * buf,size_t len,int flags)2499 static ssize_t send_tls(struct tls_context *ctx, const void *buf,
2500 			size_t len, int flags)
2501 {
2502 	const bool is_block = is_blocking(ctx->sock, flags);
2503 	k_timeout_t timeout;
2504 	k_timepoint_t end;
2505 	int ret;
2506 
2507 	if (ctx->error != 0) {
2508 		errno = ctx->error;
2509 		return -1;
2510 	}
2511 
2512 	if (ctx->session_closed) {
2513 		errno = ECONNABORTED;
2514 		return -1;
2515 	}
2516 
2517 	if (!is_block) {
2518 		timeout = K_NO_WAIT;
2519 	} else {
2520 		timeout = ctx->options.timeout_tx;
2521 	}
2522 
2523 	end = sys_timepoint_calc(timeout);
2524 
2525 	do {
2526 		ret = mbedtls_ssl_write(&ctx->ssl, buf, len);
2527 		if (ret >= 0) {
2528 			return ret;
2529 		}
2530 
2531 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2532 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2533 		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2534 		    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2535 			int timeout_ms;
2536 
2537 			if (!is_block) {
2538 				errno = EAGAIN;
2539 				break;
2540 			}
2541 
2542 			/* Blocking timeout. */
2543 			timeout = sys_timepoint_timeout(end);
2544 			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2545 				errno = EAGAIN;
2546 				break;
2547 			}
2548 
2549 			/* Block. */
2550 			timeout_ms = timeout_to_ms(&timeout);
2551 			ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2552 			if (ret != 0) {
2553 				errno = -ret;
2554 				break;
2555 			}
2556 		} else {
2557 			NET_ERR("TLS send error: -%x", -ret);
2558 
2559 			/* MbedTLS API documentation requires session to
2560 			 * be reset in other error cases
2561 			 */
2562 			ret = tls_mbedtls_reset(ctx);
2563 			if (ret != 0) {
2564 				ctx->error = ENOMEM;
2565 				errno = ENOMEM;
2566 			} else {
2567 				ctx->error = ECONNABORTED;
2568 				errno = ECONNABORTED;
2569 			}
2570 
2571 			break;
2572 		}
2573 	} while (true);
2574 
2575 	return -1;
2576 }
2577 
2578 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Send a datagram on a DTLS client socket.
 *
 * The destination is @p dest_addr or, if that is NULL, the peer address
 * stored earlier (e.g. by connect()). The first destination given is
 * stored as the peer, and a different destination afterwards is rejected
 * with EISCONN. The mbed TLS context is initialized on first use. If the
 * handshake is not complete yet, it is performed (restoring and storing
 * cached session data around it) before the payload is sent.
 *
 * Returns the number of bytes sent, or -1 with errno set.
 */
static ssize_t sendto_dtls_client(struct tls_context *ctx, const void *buf,
				  size_t len, int flags,
				  const struct net_sockaddr *dest_addr,
				  net_socklen_t addrlen)
{
	int ret;

	if (!dest_addr) {
		/* No address provided, check if we have stored one,
		 * otherwise return error.
		 */
		if (ctx->dtls_peer_addrlen == 0) {
			ret = -EDESTADDRREQ;
			goto error;
		}
	} else if (ctx->dtls_peer_addrlen == 0) {
		/* Address provided and no peer address stored. */
		dtls_peer_address_set(ctx, dest_addr, addrlen);
	} else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen)) {
		/* Address provided but it does not match stored one */
		ret = -EISCONN;
		goto error;
	}

	if (!ctx->is_initialized) {
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}
	}

	if (!is_handshake_complete(ctx)) {
		tls_session_restore(ctx, &ctx->dtls_peer_addr,
				    ctx->dtls_peer_addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket even for
		 * non-blocking socket.
		 * DTLS handshake timeout/retransmissions are limited by
		 * mbed TLS, so K_FOREVER is fine here, the function will not
		 * block indefinitely.
		 */
		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
		if (ret < 0) {
			goto error;
		}

		/* Client socket ready to use again. */
		ctx->error = 0;

		tls_session_store(ctx, &ctx->dtls_peer_addr,
				  ctx->dtls_peer_addrlen);
	}

	return send_tls(ctx, buf, len, flags);

error:
	errno = -ret;
	return -1;
}
2638 
/* Send a datagram on a DTLS server socket.
 *
 * Sending requires a completed handshake (ENOTCONN otherwise). If
 * @p dest_addr is given, it must match the connected peer (EISCONN
 * otherwise).
 *
 * Returns the number of bytes sent, or -1 with errno set.
 */
static ssize_t sendto_dtls_server(struct tls_context *ctx, const void *buf,
				  size_t len, int flags,
				  const struct net_sockaddr *dest_addr,
				  net_socklen_t addrlen)
{
	/* For DTLS server, require to have established DTLS connection
	 * in order to send data.
	 */
	if (!is_handshake_complete(ctx)) {
		errno = ENOTCONN;
		return -1;
	}

	/* Verify we are sending to a peer that we have connection with. */
	if (dest_addr && !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen)) {
		errno = EISCONN;
		return -1;
	}

	return send_tls(ctx, buf, len, flags);
}
2661 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2662 
/* sendto() entry point for TLS/DTLS sockets.
 *
 * Stores @p flags in the context, then dispatches:
 * - stream sockets go to send_tls(), which ignores the destination;
 * - datagram sockets go to the DTLS server or client path, depending on
 *   the configured role.
 *
 * Without DTLS support, datagram sockets fail with ENOTSUP.
 */
ssize_t ztls_sendto_ctx(struct tls_context *ctx, const void *buf, size_t len,
			int flags, const struct net_sockaddr *dest_addr,
			net_socklen_t addrlen)
{
	ctx->flags = flags;

	if (ctx->type != NET_SOCK_STREAM) {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
			return sendto_dtls_server(ctx, buf, len, flags,
						  dest_addr, addrlen);
		}

		return sendto_dtls_client(ctx, buf, len, flags,
					  dest_addr, addrlen);
#else
		errno = ENOTSUP;
		return -1;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	return send_tls(ctx, buf, len, flags);
}
2687 
dtls_sendmsg_merge_and_send(struct tls_context * ctx,const struct net_msghdr * msg,int flags)2688 static ssize_t dtls_sendmsg_merge_and_send(struct tls_context *ctx,
2689 					   const struct net_msghdr *msg,
2690 					   int flags)
2691 {
2692 	static K_MUTEX_DEFINE(sendmsg_lock);
2693 	static uint8_t sendmsg_buf[DTLS_SENDMSG_BUF_SIZE];
2694 	ssize_t len = 0;
2695 
2696 	k_mutex_lock(&sendmsg_lock, K_FOREVER);
2697 
2698 	for (int i = 0; i < msg->msg_iovlen; i++) {
2699 		struct net_iovec *vec = msg->msg_iov + i;
2700 
2701 		if (vec->iov_len > 0) {
2702 			if (len + vec->iov_len > sizeof(sendmsg_buf)) {
2703 				k_mutex_unlock(&sendmsg_lock);
2704 				errno = EMSGSIZE;
2705 				return -1;
2706 			}
2707 
2708 			memcpy(sendmsg_buf + len, vec->iov_base, vec->iov_len);
2709 			len += vec->iov_len;
2710 		}
2711 	}
2712 
2713 	if (len > 0) {
2714 		len = ztls_sendto_ctx(ctx, sendmsg_buf, len, flags,
2715 				      msg->msg_name, msg->msg_namelen);
2716 	}
2717 
2718 	k_mutex_unlock(&sendmsg_lock);
2719 
2720 	return len;
2721 }
2722 
tls_sendmsg_loop_and_send(struct tls_context * ctx,const struct net_msghdr * msg,int flags)2723 static ssize_t tls_sendmsg_loop_and_send(struct tls_context *ctx,
2724 					 const struct net_msghdr *msg,
2725 					 int flags)
2726 {
2727 	ssize_t len = 0;
2728 	ssize_t ret;
2729 
2730 	for (int i = 0; i < msg->msg_iovlen; i++) {
2731 		struct net_iovec *vec = msg->msg_iov + i;
2732 		size_t sent = 0;
2733 
2734 		if (vec->iov_len == 0) {
2735 			continue;
2736 		}
2737 
2738 		while (sent < vec->iov_len) {
2739 			uint8_t *ptr = (uint8_t *)vec->iov_base + sent;
2740 
2741 			ret = ztls_sendto_ctx(ctx, ptr, vec->iov_len - sent,
2742 					      flags, msg->msg_name,
2743 					      msg->msg_namelen);
2744 			if (ret < 0) {
2745 				return ret;
2746 			}
2747 			sent += ret;
2748 		}
2749 		len += sent;
2750 	}
2751 
2752 	return len;
2753 }
2754 
ztls_sendmsg_ctx(struct tls_context * ctx,const struct net_msghdr * msg,int flags)2755 ssize_t ztls_sendmsg_ctx(struct tls_context *ctx, const struct net_msghdr *msg,
2756 			 int flags)
2757 {
2758 	if (msg == NULL) {
2759 		errno = EINVAL;
2760 		return -1;
2761 	}
2762 
2763 	if (IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS) &&
2764 	    ctx->type == NET_SOCK_DGRAM) {
2765 		if (DTLS_SENDMSG_BUF_SIZE > 0) {
2766 			/* With one buffer only, there's no need to use
2767 			 * intermediate buffer.
2768 			 */
2769 			if (msghdr_non_empty_iov_count(msg) == 1) {
2770 				goto send_loop;
2771 			}
2772 
2773 			return dtls_sendmsg_merge_and_send(ctx, msg, flags);
2774 		}
2775 
2776 		/*
2777 		 * Current mbedTLS API (i.e. mbedtls_ssl_write()) allows only to send a single
2778 		 * contiguous buffer. This means that gather write using sendmsg() can only be
2779 		 * handled correctly if there is a single non-empty buffer in msg->msg_iov.
2780 		 */
2781 		if (msghdr_non_empty_iov_count(msg) > 1) {
2782 			errno = EMSGSIZE;
2783 			return -1;
2784 		}
2785 	}
2786 
2787 send_loop:
2788 	return tls_sendmsg_loop_and_send(ctx, msg, flags);
2789 }
2790 
recv_tls(struct tls_context * ctx,void * buf,size_t max_len,int flags)2791 static ssize_t recv_tls(struct tls_context *ctx, void *buf,
2792 			size_t max_len, int flags)
2793 {
2794 	size_t recv_len = 0;
2795 	const bool waitall = flags & ZSOCK_MSG_WAITALL;
2796 	const bool is_block = is_blocking(ctx->sock, flags);
2797 	k_timeout_t timeout;
2798 	k_timepoint_t end;
2799 	int ret;
2800 
2801 	if (ctx->error != 0) {
2802 		errno = ctx->error;
2803 		return -1;
2804 	}
2805 
2806 	if (ctx->session_closed) {
2807 		return 0;
2808 	}
2809 
2810 	if (!is_block) {
2811 		timeout = K_NO_WAIT;
2812 	} else {
2813 		timeout = ctx->options.timeout_rx;
2814 	}
2815 
2816 	end = sys_timepoint_calc(timeout);
2817 
2818 	do {
2819 		size_t read_len = max_len - recv_len;
2820 
2821 		ret = mbedtls_ssl_read(&ctx->ssl, (uint8_t *)buf + recv_len,
2822 				       read_len);
2823 		if (ret < 0) {
2824 			if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
2825 				/* Peer notified that it's closing the
2826 				 * connection.
2827 				 */
2828 				ctx->session_closed = true;
2829 				break;
2830 			}
2831 
2832 			if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
2833 				/* Client reconnect on the same socket is not
2834 				 * supported. See mbedtls_ssl_read API
2835 				 * documentation.
2836 				 */
2837 				ctx->session_closed = true;
2838 				break;
2839 			}
2840 
2841 			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2842 			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2843 			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2844 			    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
2845 			    ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET) {
2846 				int timeout_ms;
2847 
2848 				if (!is_block) {
2849 					ret = -EAGAIN;
2850 					goto err;
2851 				}
2852 
2853 				/* Blocking timeout. */
2854 				timeout = sys_timepoint_timeout(end);
2855 				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2856 					ret = -EAGAIN;
2857 					goto err;
2858 				}
2859 
2860 				timeout_ms = timeout_to_ms(&timeout);
2861 
2862 				/* Block. */
2863 				k_mutex_unlock(ctx->lock);
2864 				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2865 				k_mutex_lock(ctx->lock, K_FOREVER);
2866 
2867 				if (ret == 0) {
2868 					/* Retry. */
2869 					continue;
2870 				}
2871 			} else {
2872 				NET_ERR("TLS recv error: -%x", -ret);
2873 				ret = -EIO;
2874 			}
2875 
2876 err:
2877 			errno = -ret;
2878 			return -1;
2879 		}
2880 
2881 		if (ret == 0) {
2882 			break;
2883 		}
2884 
2885 		recv_len += ret;
2886 	} while ((recv_len == 0) || (waitall && (recv_len < max_len)));
2887 
2888 	return recv_len;
2889 }
2890 
2891 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Read one DTLS datagram of application data (shared client/server path).
 *
 * Blocking calls wait on the underlying socket with the context lock
 * dropped. Each wait is bounded both by the remaining options.timeout_rx
 * budget and by the value returned by dtls_get_remaining_timeout(). When
 * either of the two is SYS_FOREVER_MS, the finite one is used.
 *
 * After a successful read:
 * - the peer address is copied out if @p src_addr and @p addrlen are
 *   given;
 * - any part of the datagram that did not fit in @p buf is read out and
 *   discarded, so the next call starts on a fresh datagram.
 *
 * With ZSOCK_MSG_TRUNC, the returned length includes the discarded bytes.
 *
 * Returns the byte count on success. On failure it returns a negative
 * mbed TLS error code, NOT a negated errno; the wrappers translate it.
 * MBEDTLS_ERR_SSL_INTERNAL_ERROR is returned if the wait or the flush
 * fails. NOTE(review): a pending ctx->error instead returns -1 with errno
 * set. -1 is not distinguishable from an mbed TLS code, so the client and
 * server wrappers treat it as a generic error.
 */
static ssize_t recvfrom_dtls_common(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct net_sockaddr *src_addr,
				    net_socklen_t *addrlen)
{
	int ret;
	bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	end = sys_timepoint_calc(timeout);

	do {
		size_t remaining;

		ret = mbedtls_ssl_read(&ctx->ssl, buf, max_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
				int timeout_dtls, timeout_sock, timeout_ms;

				if (!is_block) {
					return ret;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					return ret;
				}

				timeout_dtls = dtls_get_remaining_timeout(ctx);
				timeout_sock = timeout_to_ms(&timeout);
				/* SYS_FOREVER_MS is negative: MAX() picks the
				 * finite timeout when one side is "forever".
				 */
				if (timeout_dtls == SYS_FOREVER_MS ||
				    timeout_sock == SYS_FOREVER_MS) {
					timeout_ms = MAX(timeout_dtls, timeout_sock);
				} else {
					timeout_ms = MIN(timeout_dtls, timeout_sock);
				}

				/* Block. */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				} else {
					return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				}
			} else {
				return ret;
			}
		}

		if (src_addr && addrlen) {
			dtls_peer_address_get(ctx, src_addr, addrlen);
		}

		/* mbedtls_ssl_get_bytes_avail() indicate the data length
		 * remaining in the current datagram.
		 */
		remaining = mbedtls_ssl_get_bytes_avail(&ctx->ssl);

		/* No more data in the datagram, or dummy read. */
		if ((remaining == 0) || (max_len == 0)) {
			return ret;
		}

		if (flags & ZSOCK_MSG_TRUNC) {
			ret += remaining;
		}

		/* Drain the rest of the datagram so it does not leak into the
		 * next read. size_t index avoids a signed/unsigned comparison
		 * with remaining.
		 */
		for (size_t i = 0; i < remaining; i++) {
			uint8_t byte;
			int err;

			err = mbedtls_ssl_read(&ctx->ssl, &byte, sizeof(byte));
			if (err <= 0) {
				NET_ERR("Error while flushing the rest of the"
					" datagram, err %d", err);
				ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				break;
			}
		}

		break;
	} while (true);

	return ret;
}
2998 
/* Receive a datagram on a DTLS client socket.
 *
 * Requires a completed handshake (ENOTCONN otherwise). mbed TLS error codes
 * from recvfrom_dtls_common() are mapped to errno values:
 * - peer close_notify: the session is reset and ENOTCONN becomes sticky;
 * - DTLS timeout: close_notify is sent and ETIMEDOUT becomes sticky;
 * - would-block: EAGAIN;
 * - anything else: the session is reset and ECONNABORTED becomes sticky.
 * If a reset fails, ENOMEM becomes sticky instead.
 *
 * Returns the number of bytes received, or -1 with errno set.
 */
static ssize_t recvfrom_dtls_client(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct net_sockaddr *src_addr,
				    net_socklen_t *addrlen)
{
	int ret;

	if (!is_handshake_complete(ctx)) {
		ret = -ENOTCONN;
		goto error;
	}

	ret = recvfrom_dtls_common(ctx, buf, max_len, flags, src_addr, addrlen);
	if (ret >= 0) {
		return ret;
	}

	switch (ret) {
	case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
		/* Peer notified that it's closing the connection. */
		ret = tls_mbedtls_reset(ctx);
		if (ret == 0) {
			ctx->error = ENOTCONN;
			ret = -ENOTCONN;
		} else {
			ctx->error = ENOMEM;
			ret = -ENOMEM;
		}
		break;

	case MBEDTLS_ERR_SSL_TIMEOUT:
		(void)mbedtls_ssl_close_notify(&ctx->ssl);
		ctx->error = ETIMEDOUT;
		ret = -ETIMEDOUT;
		break;

	case MBEDTLS_ERR_SSL_WANT_READ:
	case MBEDTLS_ERR_SSL_WANT_WRITE:
	case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
	case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
		ret = -EAGAIN;
		break;

	default:
		NET_ERR("DTLS client recv error: -%x", -ret);

		/* MbedTLS API documentation requires session to
		 * be reset in other error cases.
		 * ret must hold a negated errno for the common exit below.
		 * It must not keep the raw status returned by
		 * tls_mbedtls_reset(); setting errno here directly would be
		 * overwritten anyway.
		 */
		ret = tls_mbedtls_reset(ctx);
		if (ret != 0) {
			ctx->error = ENOMEM;
			ret = -ENOMEM;
		} else {
			ctx->error = ECONNABORTED;
			ret = -ECONNABORTED;
		}

		break;
	}

error:
	errno = -ret;
	return -1;
}
3064 
/* Receive a datagram on a DTLS server socket.
 *
 * Initializes the mbed TLS context on first use and runs the server
 * handshake when no session is established. The function loops so a
 * client can reconnect without the socket being closed. The session is
 * reset and the loop repeats when either of these happens:
 * - a handshake fails for a reason other than EAGAIN;
 * - the peer closes, reconnects or times out.
 *
 * Returns the number of bytes received, or -1 with errno set. Possible
 * errno values:
 * - EAGAIN on would-block;
 * - ECONNABORTED (sticky) on other TLS errors;
 * - ENOMEM when a session reset fails.
 */
static ssize_t recvfrom_dtls_server(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct net_sockaddr *src_addr,
				    net_socklen_t *addrlen)
{
	int ret;
	bool repeat;
	k_timeout_t timeout;

	if (!ctx->is_initialized) {
		ret = tls_mbedtls_init(ctx, true);
		if (ret < 0) {
			goto error;
		}
	}

	if (is_blocking(ctx->sock, flags)) {
		timeout = ctx->options.timeout_rx;
	} else {
		timeout = K_NO_WAIT;
	}

	/* Loop to enable DTLS reconnection for servers without closing
	 * a socket.
	 */
	do {
		repeat = false;

		if (!is_handshake_complete(ctx)) {
			ret = tls_mbedtls_handshake(ctx, timeout);
			if (ret < 0) {
				/* In case of EAGAIN, just exit. */
				if (ret == -EAGAIN) {
					break;
				}

				/* Reset the session and retry the handshake. */
				ret = tls_mbedtls_reset(ctx);
				if (ret == 0) {
					repeat = true;
				} else {
					ret = -ENOMEM;
				}

				continue;
			}

			/* Server socket ready to use again. */
			ctx->error = 0;
		}

		ret = recvfrom_dtls_common(ctx, buf, max_len, flags,
					   src_addr, addrlen);
		if (ret >= 0) {
			return ret;
		}

		switch (ret) {
		case MBEDTLS_ERR_SSL_TIMEOUT:
			(void)mbedtls_ssl_close_notify(&ctx->ssl);
			__fallthrough;
			/* fallthrough */

		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
		case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
			ret = tls_mbedtls_reset(ctx);
			if (ret == 0) {
				repeat = true;
			} else {
				ctx->error = ENOMEM;
				ret = -ENOMEM;
			}
			break;

		case MBEDTLS_ERR_SSL_WANT_READ:
		case MBEDTLS_ERR_SSL_WANT_WRITE:
		case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
		case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
			ret = -EAGAIN;
			break;

		default:
			NET_ERR("DTLS server recv error: -%x", -ret);

			/* ret must hold a negated errno for the common exit
			 * below. It must not keep the raw status returned by
			 * tls_mbedtls_reset(); setting errno here directly
			 * would be overwritten anyway.
			 */
			ret = tls_mbedtls_reset(ctx);
			if (ret != 0) {
				ctx->error = ENOMEM;
				ret = -ENOMEM;
			} else {
				ctx->error = ECONNABORTED;
				ret = -ECONNABORTED;
			}

			break;
		}
	} while (repeat);

error:
	errno = -ret;
	return -1;
}
3165 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3166 
/* recvfrom() entry point for TLS/DTLS sockets.
 *
 * ZSOCK_MSG_PEEK is rejected with ENOTSUP. Otherwise @p flags is stored in
 * the context and the call is dispatched:
 * - stream sockets go to recv_tls(), which ignores the source address;
 * - datagram sockets go to the DTLS server or client path, depending on
 *   the configured role.
 *
 * Without DTLS support, datagram sockets fail with ENOTSUP.
 */
ssize_t ztls_recvfrom_ctx(struct tls_context *ctx, void *buf, size_t max_len,
			  int flags, struct net_sockaddr *src_addr,
			  net_socklen_t *addrlen)
{
	if ((flags & ZSOCK_MSG_PEEK) != 0) {
		/* TODO mbedTLS does not support 'peeking' This could be
		 * bypassed by having intermediate buffer for peeking
		 */
		errno = ENOTSUP;
		return -1;
	}

	ctx->flags = flags;

	if (ctx->type != NET_SOCK_STREAM) {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
			return recvfrom_dtls_server(ctx, buf, max_len, flags,
						    src_addr, addrlen);
		}

		return recvfrom_dtls_client(ctx, buf, max_len, flags,
					    src_addr, addrlen);
#else
		errno = ENOTSUP;
		return -1;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	return recv_tls(ctx, buf, max_len, flags);
}
3200 
ztls_poll_prepare_pollin(struct tls_context * ctx)3201 static int ztls_poll_prepare_pollin(struct tls_context *ctx)
3202 {
3203 	/* If there already is mbedTLS data to read, there is no
3204 	 * need to set the k_poll_event object. Return EALREADY
3205 	 * so we won't block in the k_poll.
3206 	 */
3207 	if (!ctx->is_listening) {
3208 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
3209 			return -EALREADY;
3210 		}
3211 	}
3212 
3213 	return 0;
3214 }
3215 
ztls_poll_prepare_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)3216 static int ztls_poll_prepare_ctx(struct tls_context *ctx,
3217 				 struct zsock_pollfd *pfd,
3218 				 struct k_poll_event **pev,
3219 				 struct k_poll_event *pev_end)
3220 {
3221 	const struct fd_op_vtable *vtable;
3222 	struct k_mutex *lock;
3223 	void *obj;
3224 	int ret;
3225 	short events = pfd->events;
3226 
3227 	/* DTLS client should wait for the handshake to complete before
3228 	 * it actually starts to poll for data.
3229 	 */
3230 	if ((pfd->events & ZSOCK_POLLIN) && (ctx->type == NET_SOCK_DGRAM) &&
3231 	    (ctx->options.role == MBEDTLS_SSL_IS_CLIENT) &&
3232 	    !is_handshake_complete(ctx)) {
3233 		(*pev)->obj = &ctx->tls_established;
3234 		(*pev)->type = K_POLL_TYPE_SEM_AVAILABLE;
3235 		(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
3236 		(*pev)->state = K_POLL_STATE_NOT_READY;
3237 		(*pev)++;
3238 
3239 		/* Since k_poll_event is configured by the TLS layer in this
3240 		 * case, do not forward ZSOCK_POLLIN to the underlying socket.
3241 		 */
3242 		pfd->events &= ~ZSOCK_POLLIN;
3243 	}
3244 
3245 	obj = zvfs_get_fd_obj_and_vtable(
3246 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
3247 	if (obj == NULL) {
3248 		ret = -EBADF;
3249 		goto exit;
3250 	}
3251 
3252 	(void)k_mutex_lock(lock, K_FOREVER);
3253 
3254 	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
3255 				   pfd, pev, pev_end);
3256 	if (ret != 0) {
3257 		goto exit;
3258 	}
3259 
3260 	if (pfd->events & ZSOCK_POLLIN) {
3261 		ret = ztls_poll_prepare_pollin(ctx);
3262 	}
3263 
3264 exit:
3265 	/* Restore original events. */
3266 	pfd->events = events;
3267 
3268 	k_mutex_unlock(lock);
3269 
3270 	return ret;
3271 }
3272 
3273 #include <zephyr/net/net_core.h>
3274 
ztls_socket_data_check(struct tls_context * ctx)3275 static int ztls_socket_data_check(struct tls_context *ctx)
3276 {
3277 	int ret;
3278 
3279 	if (ctx->type == NET_SOCK_STREAM) {
3280 		if (!ctx->is_initialized) {
3281 			return -ENOTCONN;
3282 		}
3283 	}
3284 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3285 	else {
3286 		if (!ctx->is_initialized) {
3287 			bool is_server = ctx->options.role == MBEDTLS_SSL_IS_SERVER;
3288 
3289 			ret = tls_mbedtls_init(ctx, is_server);
3290 			if (ret < 0) {
3291 				return -ENOMEM;
3292 			}
3293 		}
3294 
3295 		if (!is_handshake_complete(ctx)) {
3296 			ret = tls_mbedtls_handshake(ctx, K_NO_WAIT);
3297 			if (ret < 0) {
3298 				if (ret == -EAGAIN) {
3299 					return 0;
3300 				}
3301 
3302 				ret = tls_mbedtls_reset(ctx);
3303 				if (ret != 0) {
3304 					return -ENOMEM;
3305 				}
3306 
3307 				return 0;
3308 			}
3309 
3310 			/* Socket ready to use again. */
3311 			ctx->error = 0;
3312 
3313 			return 0;
3314 		}
3315 	}
3316 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3317 
3318 	ctx->flags = ZSOCK_MSG_DONTWAIT;
3319 
3320 	ret = mbedtls_ssl_read(&ctx->ssl, NULL, 0);
3321 	if (ret < 0) {
3322 		if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
3323 			/* Don't reset the context for STREAM socket - the
3324 			 * application needs to reopen the socket anyway, and
3325 			 * resetting the context would result in an error instead
3326 			 * of 0 in a consecutive recv() call.
3327 			 */
3328 			if (ctx->type == NET_SOCK_DGRAM) {
3329 				ret = tls_mbedtls_reset(ctx);
3330 				if (ret != 0) {
3331 					return -ENOMEM;
3332 				}
3333 			} else {
3334 				ctx->session_closed = true;
3335 			}
3336 
3337 			return -ENOTCONN;
3338 		}
3339 
3340 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
3341 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
3342 			return 0;
3343 		}
3344 
3345 		NET_ERR("TLS data check error: -%x", -ret);
3346 
3347 		/* MbedTLS API documentation requires session to
3348 		 * be reset in other error cases
3349 		 */
3350 		if (tls_mbedtls_reset(ctx) != 0) {
3351 			return -ENOMEM;
3352 		}
3353 
3354 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3355 		if (ret == MBEDTLS_ERR_SSL_TIMEOUT && ctx->type == NET_SOCK_DGRAM) {
3356 			/* DTLS timeout interpreted as closing of connection. */
3357 			return -ENOTCONN;
3358 		}
3359 #endif
3360 		return -ECONNABORTED;
3361 	}
3362 
3363 	return mbedtls_ssl_get_bytes_avail(&ctx->ssl);
3364 }
3365 
ztls_poll_update_pollin(int fd,struct tls_context * ctx,struct zsock_pollfd * pfd)3366 static int ztls_poll_update_pollin(int fd, struct tls_context *ctx,
3367 				   struct zsock_pollfd *pfd)
3368 {
3369 	int ret;
3370 
3371 	if (!ctx->is_listening) {
3372 		/* Already had TLS data to read on socket. */
3373 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
3374 			pfd->revents |= ZSOCK_POLLIN;
3375 			goto next;
3376 		}
3377 	}
3378 
3379 	if (ctx->type == NET_SOCK_STREAM) {
3380 		if (!(pfd->revents & ZSOCK_POLLIN)) {
3381 			/* No new data on a socket. */
3382 			goto next;
3383 		}
3384 
3385 		if (ctx->is_listening) {
3386 			goto next;
3387 		}
3388 	}
3389 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3390 	else {
3391 		/* Perform data check without incoming data for completed DTLS connections.
3392 		 * This allows the connections to timeout with CONFIG_NET_SOCKETS_DTLS_TIMEOUT.
3393 		 */
3394 		if (!is_handshake_complete(ctx) && !(pfd->revents & ZSOCK_POLLIN)) {
3395 			goto next;
3396 		}
3397 	}
3398 #endif
3399 	ret = ztls_socket_data_check(ctx);
3400 	if (ret == -ENOTCONN || (pfd->revents & ZSOCK_POLLHUP)) {
3401 		/* Datagram does not return 0 on consecutive recv, but an error
3402 		 * code, hence clear POLLIN.
3403 		 */
3404 		if (ctx->type == NET_SOCK_DGRAM) {
3405 			pfd->revents &= ~ZSOCK_POLLIN;
3406 		}
3407 		pfd->revents |= ZSOCK_POLLHUP;
3408 		goto next;
3409 	} else if (ret < 0) {
3410 		ctx->error = -ret;
3411 		pfd->revents |= ZSOCK_POLLERR;
3412 		goto next;
3413 	} else if (ret == 0) {
3414 		goto again;
3415 	}
3416 
3417 next:
3418 	return 0;
3419 
3420 again:
3421 	/* Received encrypted data, but still not enough
3422 	 * to decrypt it and return data through socket,
3423 	 * ask for retry if no other events are set.
3424 	 */
3425 	pfd->revents &= ~ZSOCK_POLLIN;
3426 
3427 	return -EAGAIN;
3428 }
3429 
ztls_poll_update_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev)3430 static int ztls_poll_update_ctx(struct tls_context *ctx,
3431 				struct zsock_pollfd *pfd,
3432 				struct k_poll_event **pev)
3433 {
3434 	const struct fd_op_vtable *vtable;
3435 	struct k_mutex *lock;
3436 	void *obj;
3437 	int ret;
3438 	short events = pfd->events;
3439 
3440 	obj = zvfs_get_fd_obj_and_vtable(
3441 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
3442 	if (obj == NULL) {
3443 		return -EBADF;
3444 	}
3445 
3446 	(void)k_mutex_lock(lock, K_FOREVER);
3447 
3448 	/* Check if the socket was waiting for the handshake to complete. */
3449 	if ((pfd->events & ZSOCK_POLLIN) &&
3450 	    ((*pev)->obj == &ctx->tls_established)) {
3451 		/* In case handshake is complete, reconfigure the k_poll_event
3452 		 * to monitor the underlying socket now.
3453 		 */
3454 		if ((*pev)->state != K_POLL_STATE_NOT_READY) {
3455 			ret = zvfs_fdtable_call_ioctl(vtable, obj,
3456 						   ZFD_IOCTL_POLL_PREPARE,
3457 						   pfd, pev, *pev + 1);
3458 			if (ret != 0 && ret != -EALREADY) {
3459 				goto out;
3460 			}
3461 
3462 			/* Return -EAGAIN to signal to poll() that it should
3463 			 * make another iteration with the event reconfigured
3464 			 * above (if needed).
3465 			 */
3466 			ret = -EAGAIN;
3467 			goto out;
3468 		}
3469 
3470 		/* Handshake still not ready - skip ZSOCK_POLLIN verification
3471 		 * for the underlying socket.
3472 		 */
3473 		(*pev)++;
3474 		pfd->events &= ~ZSOCK_POLLIN;
3475 	}
3476 
3477 	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_UPDATE,
3478 				   pfd, pev);
3479 	if (ret != 0) {
3480 		goto exit;
3481 	}
3482 
3483 	if (pfd->events & ZSOCK_POLLIN) {
3484 		ret = ztls_poll_update_pollin(pfd->fd, ctx, pfd);
3485 		if (ret == -EAGAIN && pfd->revents == 0) {
3486 			(*pev - 1)->state = K_POLL_STATE_NOT_READY;
3487 			goto exit;
3488 		} else {
3489 			ret = 0;
3490 		}
3491 	}
3492 exit:
3493 	/* Restore original events. */
3494 	pfd->events = events;
3495 
3496 out:
3497 	k_mutex_unlock(lock);
3498 
3499 	return ret;
3500 }
3501 
3502 /* Return true if needed to retry rightoff or false otherwise. */
poll_offload_dtls_client_retry(struct tls_context * ctx,struct zsock_pollfd * pfd)3503 static bool poll_offload_dtls_client_retry(struct tls_context *ctx,
3504 					   struct zsock_pollfd *pfd)
3505 {
3506 	/* DTLS client should wait for the handshake to complete before it
3507 	 * reports that data is ready.
3508 	 */
3509 	if ((ctx->type != NET_SOCK_DGRAM) ||
3510 	    (ctx->options.role != MBEDTLS_SSL_IS_CLIENT)) {
3511 		return false;
3512 	}
3513 
3514 	if (ctx->handshake_in_progress) {
3515 		/* Add some sleep to allow lower priority threads to proceed
3516 		 * with handshake.
3517 		 */
3518 		k_msleep(10);
3519 
3520 		pfd->revents &= ~ZSOCK_POLLIN;
3521 		return true;
3522 	} else if (!is_handshake_complete(ctx)) {
3523 		uint8_t byte;
3524 		int ret;
3525 
3526 		/* Handshake didn't start yet - just drop the incoming data -
3527 		 * it's the client who should initiate the handshake.
3528 		 */
3529 		ret = zsock_recv(ctx->sock, &byte, sizeof(byte),
3530 				 ZSOCK_MSG_DONTWAIT);
3531 		if (ret < 0) {
3532 			pfd->revents |= ZSOCK_POLLERR;
3533 		}
3534 
3535 		pfd->revents &= ~ZSOCK_POLLIN;
3536 		return true;
3537 	}
3538 
3539 	/* Handshake complete, just proceed. */
3540 	return false;
3541 }
3542 
ztls_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)3543 static int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int timeout)
3544 {
3545 	int fd_backup[CONFIG_ZVFS_POLL_MAX];
3546 	const struct fd_op_vtable *vtable;
3547 	void *ctx;
3548 	int ret = 0;
3549 	int result;
3550 	int i;
3551 	bool retry;
3552 	int remaining;
3553 	uint32_t entry = k_uptime_get_32();
3554 
3555 	/* Overwrite TLS file descriptors with underlying ones. */
3556 	for (i = 0; i < nfds; i++) {
3557 		fd_backup[i] = fds[i].fd;
3558 
3559 		ctx = zvfs_get_fd_obj(fds[i].fd,
3560 				   (const struct fd_op_vtable *)
3561 						     &tls_sock_fd_op_vtable,
3562 				   0);
3563 		if (ctx == NULL) {
3564 			continue;
3565 		}
3566 
3567 		if (fds[i].events & ZSOCK_POLLIN) {
3568 			ret = ztls_poll_prepare_pollin(ctx);
3569 			/* In case data is already available in mbedtls,
3570 			 * do not wait in poll.
3571 			 */
3572 			if (ret == -EALREADY) {
3573 				timeout = 0;
3574 			}
3575 		}
3576 
3577 		fds[i].fd = ((struct tls_context *)ctx)->sock;
3578 	}
3579 
3580 	/* Get offloaded sockets vtable. */
3581 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
3582 				      (const struct fd_op_vtable **)&vtable,
3583 				      NULL);
3584 	if (ctx == NULL) {
3585 		errno = EINVAL;
3586 		goto exit;
3587 	}
3588 
3589 	remaining = timeout;
3590 
3591 	do {
3592 		for (i = 0; i < nfds; i++) {
3593 			fds[i].revents = 0;
3594 		}
3595 
3596 		ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
3597 					   fds, nfds, remaining);
3598 		if (ret < 0) {
3599 			goto exit;
3600 		}
3601 
3602 		retry = false;
3603 		ret = 0;
3604 
3605 		for (i = 0; i < nfds; i++) {
3606 			ctx = zvfs_get_fd_obj(fd_backup[i],
3607 					   (const struct fd_op_vtable *)
3608 							&tls_sock_fd_op_vtable,
3609 					   0);
3610 			if (ctx != NULL) {
3611 				if (fds[i].events & ZSOCK_POLLIN) {
3612 					if (poll_offload_dtls_client_retry(
3613 							ctx, &fds[i])) {
3614 						retry = true;
3615 						continue;
3616 					}
3617 
3618 					result = ztls_poll_update_pollin(
3619 						    fd_backup[i], ctx, &fds[i]);
3620 					if (result == -EAGAIN) {
3621 						retry = true;
3622 					}
3623 				}
3624 			}
3625 
3626 			if (fds[i].revents != 0) {
3627 				ret++;
3628 			}
3629 		}
3630 
3631 		if (retry) {
3632 			if (ret > 0 || timeout == 0) {
3633 				goto exit;
3634 			}
3635 
3636 			if (timeout > 0) {
3637 				remaining = time_left(entry, timeout);
3638 				if (remaining <= 0) {
3639 					goto exit;
3640 				}
3641 			}
3642 		}
3643 	} while (retry);
3644 
3645 exit:
3646 	/* Restore original fds. */
3647 	for (i = 0; i < nfds; i++) {
3648 		fds[i].fd = fd_backup[i];
3649 	}
3650 
3651 	return ret;
3652 }
3653 
ztls_getsockopt_ctx(struct tls_context * ctx,int level,int optname,void * optval,net_socklen_t * optlen)3654 int ztls_getsockopt_ctx(struct tls_context *ctx, int level, int optname,
3655 			void *optval, net_socklen_t *optlen)
3656 {
3657 	int err;
3658 
3659 	if (!optval || !optlen) {
3660 		errno = EINVAL;
3661 		return -1;
3662 	}
3663 
3664 	if ((level == ZSOCK_SOL_SOCKET) && (optname == ZSOCK_SO_PROTOCOL)) {
3665 		/* Protocol type is overridden during socket creation. Its
3666 		 * value is restored here to return current value.
3667 		 */
3668 		err = sock_opt_protocol_get(ctx, optval, optlen);
3669 		if (err < 0) {
3670 			errno = -err;
3671 			return -1;
3672 		}
3673 		return err;
3674 	}
3675 
3676 	/* In case error was set on a socket at the TLS layer (for example due
3677 	 * to receiving TLS alert), handle SO_ERROR here, and report that error.
3678 	 * Otherwise, forward the SO_ERROR option request to the underlying
3679 	 * TCP/UDP socket to handle.
3680 	 */
3681 	if ((level == ZSOCK_SOL_SOCKET) && (optname == ZSOCK_SO_ERROR) && ctx->error != 0) {
3682 		if (*optlen != sizeof(int)) {
3683 			errno = EINVAL;
3684 			return -1;
3685 		}
3686 
3687 		*(int *)optval = ctx->error;
3688 
3689 		return 0;
3690 	}
3691 
3692 	if (level != ZSOCK_SOL_TLS) {
3693 		return zsock_getsockopt(ctx->sock, level, optname,
3694 					optval, optlen);
3695 	}
3696 
3697 	switch (optname) {
3698 	case ZSOCK_TLS_SEC_TAG_LIST:
3699 		err =  tls_opt_sec_tag_list_get(ctx, optval, optlen);
3700 		break;
3701 
3702 	case ZSOCK_TLS_CIPHERSUITE_LIST:
3703 		err = tls_opt_ciphersuite_list_get(ctx, optval, optlen);
3704 		break;
3705 
3706 	case ZSOCK_TLS_CIPHERSUITE_USED:
3707 		err = tls_opt_ciphersuite_used_get(ctx, optval, optlen);
3708 		break;
3709 
3710 	case ZSOCK_TLS_ALPN_LIST:
3711 		err = tls_opt_alpn_list_get(ctx, optval, optlen);
3712 		break;
3713 
3714 	case ZSOCK_TLS_SESSION_CACHE:
3715 		err = tls_opt_session_cache_get(ctx, optval, optlen);
3716 		break;
3717 
3718 	case ZSOCK_TLS_CERT_VERIFY_RESULT:
3719 		err = tls_opt_cert_verify_result_get(ctx, optval, optlen);
3720 		break;
3721 
3722 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3723 	case ZSOCK_TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3724 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3725 							 optlen, false);
3726 		break;
3727 
3728 	case ZSOCK_TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3729 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3730 							 optlen, true);
3731 		break;
3732 
3733 	case ZSOCK_TLS_DTLS_CID_STATUS:
3734 		err = tls_opt_dtls_connection_id_status_get(ctx, optval,
3735 							    optlen);
3736 		break;
3737 
3738 	case ZSOCK_TLS_DTLS_CID_VALUE:
3739 		err = tls_opt_dtls_connection_id_value_get(ctx, optval, optlen);
3740 		break;
3741 
3742 	case ZSOCK_TLS_DTLS_PEER_CID_VALUE:
3743 		err = tls_opt_dtls_peer_connection_id_value_get(ctx, optval,
3744 								optlen);
3745 		break;
3746 
3747 	case ZSOCK_TLS_DTLS_HANDSHAKE_ON_CONNECT:
3748 		err = tls_opt_dtls_handshake_on_connect_get(ctx, optval, optlen);
3749 		break;
3750 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3751 
3752 	default:
3753 		/* Unknown or write-only option. */
3754 		err = -ENOPROTOOPT;
3755 		break;
3756 	}
3757 
3758 	if (err < 0) {
3759 		errno = -err;
3760 		return -1;
3761 	}
3762 
3763 	return 0;
3764 }
3765 
set_timeout_opt(k_timeout_t * timeout,const void * optval,net_socklen_t optlen)3766 static int set_timeout_opt(k_timeout_t *timeout, const void *optval,
3767 			   net_socklen_t optlen)
3768 {
3769 	const struct zsock_timeval *tval = optval;
3770 
3771 	if (optlen != sizeof(struct zsock_timeval)) {
3772 		return -EINVAL;
3773 	}
3774 
3775 	if (tval->tv_sec == 0 && tval->tv_usec == 0) {
3776 		*timeout = K_FOREVER;
3777 	} else {
3778 		*timeout = K_USEC(tval->tv_sec * 1000000ULL + tval->tv_usec);
3779 	}
3780 
3781 	return 0;
3782 }
3783 
ztls_setsockopt_ctx(struct tls_context * ctx,int level,int optname,const void * optval,net_socklen_t optlen)3784 int ztls_setsockopt_ctx(struct tls_context *ctx, int level, int optname,
3785 			const void *optval, net_socklen_t optlen)
3786 {
3787 	int err;
3788 
3789 	/* Underlying socket is used in non-blocking mode, hence implement
3790 	 * timeout at the TLS socket level.
3791 	 */
3792 	if ((level == ZSOCK_SOL_SOCKET) && (optname == ZSOCK_SO_SNDTIMEO)) {
3793 		err = set_timeout_opt(&ctx->options.timeout_tx, optval, optlen);
3794 		goto out;
3795 	}
3796 
3797 	if ((level == ZSOCK_SOL_SOCKET) && (optname == ZSOCK_SO_RCVTIMEO)) {
3798 		err = set_timeout_opt(&ctx->options.timeout_rx, optval, optlen);
3799 		goto out;
3800 	}
3801 
3802 	if (level != ZSOCK_SOL_TLS) {
3803 		return zsock_setsockopt(ctx->sock, level, optname,
3804 					optval, optlen);
3805 	}
3806 
3807 	switch (optname) {
3808 	case ZSOCK_TLS_SEC_TAG_LIST:
3809 		err =  tls_opt_sec_tag_list_set(ctx, optval, optlen);
3810 		break;
3811 
3812 	case ZSOCK_TLS_HOSTNAME:
3813 		err = tls_opt_hostname_set(ctx, optval, optlen);
3814 		break;
3815 
3816 	case ZSOCK_TLS_CIPHERSUITE_LIST:
3817 		err = tls_opt_ciphersuite_list_set(ctx, optval, optlen);
3818 		break;
3819 
3820 	case ZSOCK_TLS_PEER_VERIFY:
3821 		err = tls_opt_peer_verify_set(ctx, optval, optlen);
3822 		break;
3823 
3824 	case ZSOCK_TLS_CERT_NOCOPY:
3825 		err = tls_opt_cert_nocopy_set(ctx, optval, optlen);
3826 		break;
3827 
3828 	case ZSOCK_TLS_DTLS_ROLE:
3829 		err = tls_opt_dtls_role_set(ctx, optval, optlen);
3830 		break;
3831 
3832 	case ZSOCK_TLS_ALPN_LIST:
3833 		err = tls_opt_alpn_list_set(ctx, optval, optlen);
3834 		break;
3835 
3836 	case ZSOCK_TLS_SESSION_CACHE:
3837 		err = tls_opt_session_cache_set(ctx, optval, optlen);
3838 		break;
3839 
3840 	case ZSOCK_TLS_SESSION_CACHE_PURGE:
3841 		err = tls_opt_session_cache_purge_set(ctx, optval, optlen);
3842 		break;
3843 
3844 	case ZSOCK_TLS_CERT_VERIFY_CALLBACK:
3845 		err = tls_opt_cert_verify_callback_set(ctx, optval, optlen);
3846 		break;
3847 
3848 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3849 	case ZSOCK_TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3850 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3851 							 optlen, false);
3852 		break;
3853 
3854 	case ZSOCK_TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3855 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3856 							 optlen, true);
3857 		break;
3858 
3859 	case ZSOCK_TLS_DTLS_CID:
3860 		err = tls_opt_dtls_connection_id_set(ctx, optval, optlen);
3861 		break;
3862 
3863 	case ZSOCK_TLS_DTLS_CID_VALUE:
3864 		err = tls_opt_dtls_connection_id_value_set(ctx, optval, optlen);
3865 		break;
3866 
3867 	case ZSOCK_TLS_DTLS_HANDSHAKE_ON_CONNECT:
3868 		err = tls_opt_dtls_handshake_on_connect_set(ctx, optval, optlen);
3869 		break;
3870 
3871 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3872 
3873 	case ZSOCK_TLS_NATIVE:
3874 		/* Option handled at the socket dispatcher level. */
3875 		err = 0;
3876 		break;
3877 
3878 	default:
3879 		/* Unknown or read-only option. */
3880 		err = -ENOPROTOOPT;
3881 		break;
3882 	}
3883 
3884 out:
3885 	if (err < 0) {
3886 		errno = -err;
3887 		return -1;
3888 	}
3889 
3890 	return 0;
3891 }
3892 
#if defined(CONFIG_NET_TEST)
/* Test helper: expose the mbed TLS SSL context behind a TLS socket fd.
 *
 * Returns NULL when fd does not refer to a TLS socket. EBADF is passed to
 * zvfs_get_fd_obj() as the error code for that case.
 */
mbedtls_ssl_context *ztls_get_mbedtls_ssl_context(int fd)
{
	struct tls_context *tls = zvfs_get_fd_obj(
		fd, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable, EBADF);

	return (tls != NULL) ? &tls->ssl : NULL;
}
#endif /* CONFIG_NET_TEST */
3907 
tls_sock_read_vmeth(void * obj,void * buffer,size_t count)3908 static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count)
3909 {
3910 	return ztls_recvfrom_ctx(obj, buffer, count, 0, NULL, 0);
3911 }
3912 
tls_sock_write_vmeth(void * obj,const void * buffer,size_t count)3913 static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer,
3914 				    size_t count)
3915 {
3916 	return ztls_sendto_ctx(obj, buffer, count, 0, NULL, 0);
3917 }
3918 
tls_sock_ioctl_vmeth(void * obj,unsigned int request,va_list args)3919 static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
3920 {
3921 	struct tls_context *ctx = obj;
3922 
3923 	switch (request) {
3924 	/* fcntl() commands */
3925 	case ZVFS_F_GETFL:
3926 	case ZVFS_F_SETFL: {
3927 		const struct fd_op_vtable *vtable;
3928 		struct k_mutex *lock;
3929 		void *fd_obj;
3930 		int ret;
3931 
3932 		fd_obj = zvfs_get_fd_obj_and_vtable(ctx->sock,
3933 				(const struct fd_op_vtable **)&vtable, &lock);
3934 		if (fd_obj == NULL) {
3935 			errno = EBADF;
3936 			return -1;
3937 		}
3938 
3939 		(void)k_mutex_lock(lock, K_FOREVER);
3940 
3941 		/* Pass the call to the core socket implementation. */
3942 		ret = vtable->ioctl(fd_obj, request, args);
3943 
3944 		k_mutex_unlock(lock);
3945 
3946 		return ret;
3947 	}
3948 
3949 	case ZFD_IOCTL_SET_LOCK: {
3950 		struct k_mutex *lock;
3951 
3952 		lock = va_arg(args, struct k_mutex *);
3953 
3954 		ctx_set_lock(obj, lock);
3955 
3956 		return 0;
3957 	}
3958 
3959 	case ZFD_IOCTL_POLL_PREPARE: {
3960 		struct zsock_pollfd *pfd;
3961 		struct k_poll_event **pev;
3962 		struct k_poll_event *pev_end;
3963 
3964 		pfd = va_arg(args, struct zsock_pollfd *);
3965 		pev = va_arg(args, struct k_poll_event **);
3966 		pev_end = va_arg(args, struct k_poll_event *);
3967 
3968 		return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end);
3969 	}
3970 
3971 	case ZFD_IOCTL_POLL_UPDATE: {
3972 		struct zsock_pollfd *pfd;
3973 		struct k_poll_event **pev;
3974 
3975 		pfd = va_arg(args, struct zsock_pollfd *);
3976 		pev = va_arg(args, struct k_poll_event **);
3977 
3978 		return ztls_poll_update_ctx(obj, pfd, pev);
3979 	}
3980 
3981 	case ZFD_IOCTL_POLL_OFFLOAD: {
3982 		struct zsock_pollfd *fds;
3983 		int nfds;
3984 		int timeout;
3985 
3986 		fds = va_arg(args, struct zsock_pollfd *);
3987 		nfds = va_arg(args, int);
3988 		timeout = va_arg(args, int);
3989 
3990 		return ztls_poll_offload(fds, nfds, timeout);
3991 	}
3992 
3993 	default:
3994 		errno = EOPNOTSUPP;
3995 		return -1;
3996 	}
3997 }
3998 
tls_sock_shutdown_vmeth(void * obj,int how)3999 static int tls_sock_shutdown_vmeth(void *obj, int how)
4000 {
4001 	struct tls_context *ctx = obj;
4002 
4003 	return zsock_shutdown(ctx->sock, how);
4004 }
4005 
tls_sock_bind_vmeth(void * obj,const struct net_sockaddr * addr,net_socklen_t addrlen)4006 static int tls_sock_bind_vmeth(void *obj, const struct net_sockaddr *addr,
4007 			       net_socklen_t addrlen)
4008 {
4009 	struct tls_context *ctx = obj;
4010 
4011 	return zsock_bind(ctx->sock, addr, addrlen);
4012 }
4013 
tls_sock_connect_vmeth(void * obj,const struct net_sockaddr * addr,net_socklen_t addrlen)4014 static int tls_sock_connect_vmeth(void *obj, const struct net_sockaddr *addr,
4015 				  net_socklen_t addrlen)
4016 {
4017 	return ztls_connect_ctx(obj, addr, addrlen);
4018 }
4019 
tls_sock_listen_vmeth(void * obj,int backlog)4020 static int tls_sock_listen_vmeth(void *obj, int backlog)
4021 {
4022 	struct tls_context *ctx = obj;
4023 
4024 	ctx->is_listening = true;
4025 
4026 	return zsock_listen(ctx->sock, backlog);
4027 }
4028 
tls_sock_accept_vmeth(void * obj,struct net_sockaddr * addr,net_socklen_t * addrlen)4029 static int tls_sock_accept_vmeth(void *obj, struct net_sockaddr *addr,
4030 				 net_socklen_t *addrlen)
4031 {
4032 	return ztls_accept_ctx(obj, addr, addrlen);
4033 }
4034 
tls_sock_sendto_vmeth(void * obj,const void * buf,size_t len,int flags,const struct net_sockaddr * dest_addr,net_socklen_t addrlen)4035 static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len,
4036 				     int flags,
4037 				     const struct net_sockaddr *dest_addr,
4038 				     net_socklen_t addrlen)
4039 {
4040 	return ztls_sendto_ctx(obj, buf, len, flags, dest_addr, addrlen);
4041 }
4042 
tls_sock_sendmsg_vmeth(void * obj,const struct net_msghdr * msg,int flags)4043 static ssize_t tls_sock_sendmsg_vmeth(void *obj, const struct net_msghdr *msg,
4044 				      int flags)
4045 {
4046 	return ztls_sendmsg_ctx(obj, msg, flags);
4047 }
4048 
tls_sock_recvfrom_vmeth(void * obj,void * buf,size_t max_len,int flags,struct net_sockaddr * src_addr,net_socklen_t * addrlen)4049 static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len,
4050 				       int flags, struct net_sockaddr *src_addr,
4051 				       net_socklen_t *addrlen)
4052 {
4053 	return ztls_recvfrom_ctx(obj, buf, max_len, flags,
4054 				 src_addr, addrlen);
4055 }
4056 
tls_sock_getsockopt_vmeth(void * obj,int level,int optname,void * optval,net_socklen_t * optlen)4057 static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname,
4058 				     void *optval, net_socklen_t *optlen)
4059 {
4060 	return ztls_getsockopt_ctx(obj, level, optname, optval, optlen);
4061 }
4062 
tls_sock_setsockopt_vmeth(void * obj,int level,int optname,const void * optval,net_socklen_t optlen)4063 static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname,
4064 				     const void *optval, net_socklen_t optlen)
4065 {
4066 	return ztls_setsockopt_ctx(obj, level, optname, optval, optlen);
4067 }
4068 
tls_sock_close2_vmeth(void * obj,int sock)4069 static int tls_sock_close2_vmeth(void *obj, int sock)
4070 {
4071 	return ztls_close_ctx(obj, sock);
4072 }
4073 
tls_sock_getpeername_vmeth(void * obj,struct net_sockaddr * addr,net_socklen_t * addrlen)4074 static int tls_sock_getpeername_vmeth(void *obj, struct net_sockaddr *addr,
4075 				      net_socklen_t *addrlen)
4076 {
4077 	struct tls_context *ctx = obj;
4078 
4079 	return zsock_getpeername(ctx->sock, addr, addrlen);
4080 }
4081 
tls_sock_getsockname_vmeth(void * obj,struct net_sockaddr * addr,net_socklen_t * addrlen)4082 static int tls_sock_getsockname_vmeth(void *obj, struct net_sockaddr *addr,
4083 				      net_socklen_t *addrlen)
4084 {
4085 	struct tls_context *ctx = obj;
4086 
4087 	return zsock_getsockname(ctx->sock, addr, addrlen);
4088 }
4089 
4090 static const struct socket_op_vtable tls_sock_fd_op_vtable = {
4091 	.fd_vtable = {
4092 		.read = tls_sock_read_vmeth,
4093 		.write = tls_sock_write_vmeth,
4094 		.close2 = tls_sock_close2_vmeth,
4095 		.ioctl = tls_sock_ioctl_vmeth,
4096 	},
4097 	.shutdown = tls_sock_shutdown_vmeth,
4098 	.bind = tls_sock_bind_vmeth,
4099 	.connect = tls_sock_connect_vmeth,
4100 	.listen = tls_sock_listen_vmeth,
4101 	.accept = tls_sock_accept_vmeth,
4102 	.sendto = tls_sock_sendto_vmeth,
4103 	.sendmsg = tls_sock_sendmsg_vmeth,
4104 	.recvfrom = tls_sock_recvfrom_vmeth,
4105 	.getsockopt = tls_sock_getsockopt_vmeth,
4106 	.setsockopt = tls_sock_setsockopt_vmeth,
4107 	.getpeername = tls_sock_getpeername_vmeth,
4108 	.getsockname = tls_sock_getsockname_vmeth,
4109 };
4110 
tls_is_supported(int family,int type,int proto)4111 static bool tls_is_supported(int family, int type, int proto)
4112 {
4113 	if (protocol_check(family, type, &proto) == 0) {
4114 		return true;
4115 	}
4116 
4117 	return false;
4118 }
4119 
/* TLS sockets and regular sockets belong to the same address families, so
 * the TLS factory has to be consulted first in order to capture socket()
 * calls that request a secure protocol. Every other NET_AF_INET/NET_AF_INET6
 * call is then forwarded to the regular socket implementation. The
 * BUILD_ASSERT below enforces this ordering by requiring a smaller priority
 * value for TLS.
 */
BUILD_ASSERT(CONFIG_NET_SOCKETS_TLS_PRIORITY < CONFIG_NET_SOCKETS_PRIORITY_DEFAULT,
	     "CONFIG_NET_SOCKETS_TLS_PRIORITY have to be smaller than CONFIG_NET_SOCKETS_PRIORITY_DEFAULT");

/* Registered for NET_AF_UNSPEC; tls_is_supported() decides per
 * (family, type, proto) whether ztls_socket() creates the socket.
 */
NET_SOCKET_REGISTER(tls, CONFIG_NET_SOCKETS_TLS_PRIORITY, NET_AF_UNSPEC,
		    tls_is_supported, ztls_socket);
4130