1 /*
2  * Copyright (c) 2018 Intel Corporation
3  * Copyright (c) 2018 Nordic Semiconductor ASA
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <stdbool.h>
9 #include <zephyr/posix/fcntl.h>
10 
11 #include <zephyr/logging/log.h>
12 LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL);
13 
14 #include <zephyr/init.h>
15 #include <zephyr/sys/util.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/random/random.h>
18 #include <zephyr/internal/syscall_handler.h>
19 #include <zephyr/sys/fdtable.h>
20 
/* TODO: Remove all direct access to private fields.
 * According to the Mbed TLS migration guide:
 *
 * Direct access to fields of structures
 * (`struct` types) declared in public headers is no longer
 * supported. In Mbed TLS 3, the layout of structures is not
 * considered part of the stable API, and minor versions (3.1, 3.2,
 * etc.) may add, remove, rename, reorder or change the type of
 * structure fields.
 */
31 #if !defined(MBEDTLS_ALLOW_PRIVATE_ACCESS)
32 #define MBEDTLS_ALLOW_PRIVATE_ACCESS
33 #endif
34 
35 #if defined(CONFIG_MBEDTLS)
36 #if !defined(CONFIG_MBEDTLS_CFG_FILE)
37 #include "mbedtls/config.h"
38 #else
39 #include CONFIG_MBEDTLS_CFG_FILE
40 #endif /* CONFIG_MBEDTLS_CFG_FILE */
41 
42 #include <mbedtls/net_sockets.h>
43 #include <mbedtls/x509.h>
44 #include <mbedtls/x509_crt.h>
45 #include <mbedtls/ssl.h>
46 #include <mbedtls/ssl_cookie.h>
47 #include <mbedtls/error.h>
48 #include <mbedtls/platform.h>
49 #include <mbedtls/ssl_cache.h>
50 #endif /* CONFIG_MBEDTLS */
51 
52 #include "sockets_internal.h"
53 #include "tls_internal.h"
54 
55 #if defined(CONFIG_MBEDTLS_DEBUG)
56 #include <zephyr_mbedtls_priv.h>
57 #endif
58 
59 #if defined(CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS)
60 #define ALPN_MAX_PROTOCOLS (CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS + 1)
61 #else
62 #define ALPN_MAX_PROTOCOLS 0
63 #endif /* CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS */
64 
65 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
66 #define DTLS_SENDMSG_BUF_SIZE (CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE)
67 #else
68 #define DTLS_SENDMSG_BUF_SIZE 0
69 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
70 
71 static const struct socket_op_vtable tls_sock_fd_op_vtable;
72 
73 #ifndef MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED
74 #define MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE
75 #endif
76 
/**
 * A list of secure tags that TLS context should use.
 *
 * Consumers iterate over the first @c sec_tag_count entries of @c sec_tags.
 */
struct sec_tag_list {
	/** An array of secure tags referencing TLS credentials. */
	sec_tag_t sec_tags[CONFIG_NET_SOCKETS_TLS_MAX_CREDENTIALS];

	/** Number of configured secure tags. */
	int sec_tag_count;
};
85 
/**
 * Timer context for DTLS.
 *
 * Written by dtls_timing_set_delay() and read by dtls_timing_get_delay()
 * and dtls_get_remaining_timeout().
 */
struct dtls_timing_context {
	/** Current time, stored during timer set. Taken from
	 *  k_uptime_get_32() and refreshed only when a non-zero final delay
	 *  is set.
	 */
	uint32_t snapshot;

	/** Intermediate delay value. For details, refer to mbedTLS API
	 *  documentation (mbedtls_ssl_set_timer_t).
	 */
	uint32_t int_ms;

	/** Final delay value. For details, refer to mbedTLS API documentation
	 *  (mbedtls_ssl_set_timer_t). A value of 0 means the timer is
	 *  cancelled.
	 */
	uint32_t fin_ms;
};
101 
/**
 * TLS peer address/session ID mapping.
 *
 * One slot of the client session cache. A slot whose @c session pointer is
 * NULL is treated as free.
 */
struct tls_session_cache {
	/** Creation time, taken from k_uptime_get() when the entry is saved. */
	int64_t timestamp;

	/** Peer address. */
	struct sockaddr peer_addr;

	/** Session buffer holding the serialized mbedTLS session; allocated
	 *  with mbedtls_calloc() and released with mbedtls_free().
	 */
	uint8_t *session;

	/** Session length, i.e. number of bytes stored in @c session. */
	size_t session_len;
};
116 
117 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/** DTLS Connection ID (CID) settings kept among the socket options. */
struct tls_dtls_cid {
	/** CID enable flag; tls_alloc() initializes it to false. */
	bool enabled;
	/** CID bytes; the array is sized for the larger of the mbedTLS
	 *  incoming/outgoing CID maximums.
	 */
	unsigned char cid[MAX(MBEDTLS_SSL_CID_OUT_LEN_MAX,
			      MBEDTLS_SSL_CID_IN_LEN_MAX)];
	/** CID length; tls_alloc() initializes it to 0.
	 *  NOTE(review): presumably the number of valid bytes in cid - confirm.
	 */
	size_t cid_len;
};
124 #endif
125 
/**
 * TLS context information.
 *
 * One element of the static tls_contexts pool. Elements are handed out by
 * tls_alloc() and given back by tls_release().
 */
__net_socket struct tls_context {
	/** Underlying TCP/UDP socket. Set to -1 by tls_alloc(). */
	int sock;

	/** Information whether TLS context is used. */
	bool is_used : 1;

	/** Information whether TLS context was initialized. */
	bool is_initialized : 1;

	/** Information whether underlying socket is listening. */
	bool is_listening : 1;

	/** Information whether TLS handshake is currently in progress. */
	bool handshake_in_progress : 1;

	/** Session ended at the TLS/DTLS level. */
	bool session_closed : 1;

	/** Socket type. */
	enum net_sock_type type;

	/** Secure protocol version running on TLS context. */
	enum net_ip_protocol_secure tls_version;

	/** Socket flags passed to a socket call. */
	int flags;

	/** Indicates whether socket is in error state at TLS/DTLS level. */
	int error;

	/** Information whether TLS handshake is complete or not. A non-zero
	 *  semaphore count means the handshake is complete (see
	 *  is_handshake_complete()).
	 */
	struct k_sem tls_established;

	/** TLS socket mutex lock. Assigned through ctx_set_lock(). */
	struct k_mutex *lock;

	/** TLS specific option values. Copied as a whole by tls_clone(). */
	struct {
		/** Select which credentials to use with TLS. */
		struct sec_tag_list sec_tag_list;

		/** 0-terminated list of allowed ciphersuites (mbedTLS format).
		 */
		int ciphersuites[CONFIG_NET_SOCKETS_TLS_MAX_CIPHERSUITES + 1];

		/** Information if hostname was explicitly set on a socket. */
		bool is_hostname_set;

		/** Peer verification level. Set to -1 by tls_alloc(). */
		int8_t verify_level;

		/** Indicates whether DER certificates should not be copied
		 * to the heap.
		 */
		int8_t cert_nocopy;

		/** DTLS role, client by default. */
		int8_t role;

		/** NULL-terminated list of allowed application layer
		 * protocols.
		 */
		const char *alpn_list[ALPN_MAX_PROTOCOLS];

		/** Session cache enabled on a socket. */
		bool cache_enabled;

		/** Socket TX timeout. Set to K_FOREVER by tls_alloc(). */
		k_timeout_t timeout_tx;

		/** Socket RX timeout. Set to K_FOREVER by tls_alloc(). */
		k_timeout_t timeout_rx;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/** DTLS handshake timeout bounds; tls_alloc() sets them to
		 * the mbedTLS default minimum/maximum.
		 */
		uint32_t dtls_handshake_timeout_min;
		uint32_t dtls_handshake_timeout_max;

		/** DTLS Connection ID settings. */
		struct tls_dtls_cid dtls_cid;

		/** Set to true by tls_alloc().
		 * NOTE(review): presumably selects whether the DTLS handshake
		 * runs as part of connect - confirm against the connect path.
		 */
		bool dtls_handshake_on_connect;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	} options;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	/** Context information for DTLS timing. */
	struct dtls_timing_context dtls_timing;

	/** mbedTLS cookie context for DTLS */
	mbedtls_ssl_cookie_ctx cookie;

	/** DTLS peer address. */
	struct sockaddr dtls_peer_addr;

	/** DTLS peer address length. Zero means no peer address has been
	 *  stored yet (see dtls_rx()).
	 */
	socklen_t dtls_peer_addrlen;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(CONFIG_MBEDTLS)
	/** mbedTLS context. */
	mbedtls_ssl_context ssl;

	/** mbedTLS configuration. */
	mbedtls_ssl_config config;

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	/** mbedTLS structure for CA chain. */
	mbedtls_x509_crt ca_chain;

	/** mbedTLS structure for own certificate. */
	mbedtls_x509_crt own_cert;

	/** mbedTLS structure for own private key. */
	mbedtls_pk_context priv_key;
#endif /* MBEDTLS_X509_CRT_PARSE_C */

#endif /* CONFIG_MBEDTLS */
};
246 
247 
/* A global pool of TLS contexts. */
static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS];

/* Client session cache, looked up by peer address (see tls_session_save()
 * and tls_session_get()).
 */
static struct tls_session_cache client_cache[CONFIG_NET_SOCKETS_TLS_MAX_CLIENT_SESSION_COUNT];

#if defined(MBEDTLS_SSL_CACHE_C)
/* mbedTLS session cache context; initialized in tls_init() and
 * re-initialized by tls_session_purge().
 */
static mbedtls_ssl_cache_context server_cache;
#endif

/* A mutex for protecting TLS context allocation. */
static struct k_mutex context_lock;
259 
260 /* Arbitrary delay value to wait if mbedTLS reports it cannot proceed for
261  * reasons other than TX/RX block.
262  */
263 #define TLS_WAIT_MS 100
264 
tls_session_cache_reset(void)265 static void tls_session_cache_reset(void)
266 {
267 	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
268 		if (client_cache[i].session != NULL) {
269 			mbedtls_free(client_cache[i].session);
270 		}
271 	}
272 
273 	(void)memset(client_cache, 0, sizeof(client_cache));
274 }
275 
net_socket_is_tls(void * obj)276 bool net_socket_is_tls(void *obj)
277 {
278 	return PART_OF_ARRAY(tls_contexts, (struct tls_context *)obj);
279 }
280 
/* Random number callback handed to mbedTLS.
 *
 * Fills buf with len random bytes. With CONFIG_CSPRNG_ENABLED the result of
 * sys_csrand_get() is returned; otherwise sys_rand_get() is used and 0 is
 * returned. The ctx argument is not used.
 */
static int tls_ctr_drbg_random(void *ctx, unsigned char *buf, size_t len)
{
	int ret = 0;

	ARG_UNUSED(ctx);

#if defined(CONFIG_CSPRNG_ENABLED)
	ret = sys_csrand_get(buf, len);
#else
	sys_rand_get(buf, len);
#endif

	return ret;
}
293 
294 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
295 /* mbedTLS-defined function for setting timer. */
dtls_timing_set_delay(void * data,uint32_t int_ms,uint32_t fin_ms)296 static void dtls_timing_set_delay(void *data, uint32_t int_ms, uint32_t fin_ms)
297 {
298 	struct dtls_timing_context *ctx = data;
299 
300 	ctx->int_ms = int_ms;
301 	ctx->fin_ms = fin_ms;
302 
303 	if (fin_ms != 0U) {
304 		ctx->snapshot = k_uptime_get_32();
305 	}
306 }
307 
308 /* mbedTLS-defined function for getting timer status.
309  * The return values are specified by mbedTLS. The callback must return:
310  *   -1 if cancelled (fin_ms == 0),
311  *    0 if none of the delays have passed,
312  *    1 if only the intermediate delay has passed,
313  *    2 if the final delay has passed.
314  */
dtls_timing_get_delay(void * data)315 static int dtls_timing_get_delay(void *data)
316 {
317 	struct dtls_timing_context *timing = data;
318 	unsigned long elapsed_ms;
319 
320 	NET_ASSERT(timing);
321 
322 	if (timing->fin_ms == 0U) {
323 		return -1;
324 	}
325 
326 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
327 
328 	if (elapsed_ms >= timing->fin_ms) {
329 		return 2;
330 	}
331 
332 	if (elapsed_ms >= timing->int_ms) {
333 		return 1;
334 	}
335 
336 	return 0;
337 }
338 
dtls_get_remaining_timeout(struct tls_context * ctx)339 static int dtls_get_remaining_timeout(struct tls_context *ctx)
340 {
341 	struct dtls_timing_context *timing = &ctx->dtls_timing;
342 	uint32_t elapsed_ms;
343 
344 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
345 
346 	if (timing->fin_ms == 0U) {
347 		return SYS_FOREVER_MS;
348 	}
349 
350 	if (elapsed_ms >= timing->fin_ms) {
351 		return 0;
352 	}
353 
354 	return timing->fin_ms - elapsed_ms;
355 }
356 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
357 
358 /* Initialize TLS internals. */
tls_init(void)359 static int tls_init(void)
360 {
361 
362 #if !defined(CONFIG_ENTROPY_HAS_DRIVER)
363 	NET_WARN("No entropy device on the system, "
364 		 "TLS communication is insecure!");
365 #endif
366 
367 	(void)memset(tls_contexts, 0, sizeof(tls_contexts));
368 	(void)memset(client_cache, 0, sizeof(client_cache));
369 
370 	k_mutex_init(&context_lock);
371 
372 #if defined(MBEDTLS_SSL_CACHE_C)
373 	mbedtls_ssl_cache_init(&server_cache);
374 #endif
375 
376 	return 0;
377 }
378 
379 SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT);
380 
is_handshake_complete(struct tls_context * ctx)381 static inline bool is_handshake_complete(struct tls_context *ctx)
382 {
383 	return k_sem_count_get(&ctx->tls_established) != 0;
384 }
385 
386 /*
387  * Copied from include/mbedtls/ssl_internal.h
388  *
389  * Maximum length we can advertise as our max content length for
390  * RFC 6066 max_fragment_length extension negotiation purposes
391  * (the lesser of both sizes, if they are unequal.)
392  */
393 #define MBEDTLS_TLS_EXT_ADV_CONTENT_LEN (                            \
394 	(MBEDTLS_SSL_IN_CONTENT_LEN > MBEDTLS_SSL_OUT_CONTENT_LEN)   \
395 	? (MBEDTLS_SSL_OUT_CONTENT_LEN)				     \
396 	: (MBEDTLS_SSL_IN_CONTENT_LEN)				     \
397 	)
398 
399 #if defined(CONFIG_NET_SOCKETS_TLS_SET_MAX_FRAGMENT_LENGTH) &&	\
400 	defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) &&		\
401 	(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN < 16384)
402 
403 BUILD_ASSERT(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN >= 512,
404 	     "Too small content length!");
405 
tls_mfl_code_from_content_len(size_t len)406 static inline unsigned char tls_mfl_code_from_content_len(size_t len)
407 {
408 	if (len >= 4096) {
409 		return MBEDTLS_SSL_MAX_FRAG_LEN_4096;
410 	} else if (len >= 2048) {
411 		return MBEDTLS_SSL_MAX_FRAG_LEN_2048;
412 	} else if (len >= 1024) {
413 		return MBEDTLS_SSL_MAX_FRAG_LEN_1024;
414 	} else if (len >= 512) {
415 		return MBEDTLS_SSL_MAX_FRAG_LEN_512;
416 	} else {
417 		return MBEDTLS_SSL_MAX_FRAG_LEN_INVALID;
418 	}
419 }
420 
/* Configure the max fragment length extension on an mbedTLS config.
 *
 * Starts from MBEDTLS_TLS_EXT_ADV_CONTENT_LEN; for datagram sockets with
 * DTLS enabled the length is additionally capped at
 * CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH.
 */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	size_t content_len = MBEDTLS_TLS_EXT_ADV_CONTENT_LEN;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (type == SOCK_DGRAM) {
		if (content_len > CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH) {
			content_len = CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH;
		}
	}
#endif

	mbedtls_ssl_conf_max_frag_len(config,
				      tls_mfl_code_from_content_len(content_len));
}
435 #else
/* No-op variant used when max fragment length negotiation is not built in. */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	(void)config;
	(void)type;
}
437 #endif
438 
439 /* Allocate TLS context. */
tls_alloc(void)440 static struct tls_context *tls_alloc(void)
441 {
442 	int i;
443 	struct tls_context *tls = NULL;
444 
445 	k_mutex_lock(&context_lock, K_FOREVER);
446 
447 	for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) {
448 		if (!tls_contexts[i].is_used) {
449 			tls = &tls_contexts[i];
450 			(void)memset(tls, 0, sizeof(*tls));
451 			tls->is_used = true;
452 			tls->options.verify_level = -1;
453 			tls->options.timeout_tx = K_FOREVER;
454 			tls->options.timeout_rx = K_FOREVER;
455 			tls->sock = -1;
456 
457 			NET_DBG("Allocated TLS context, %p", tls);
458 			break;
459 		}
460 	}
461 
462 	k_mutex_unlock(&context_lock);
463 
464 	if (tls) {
465 		k_sem_init(&tls->tls_established, 0, 1);
466 
467 		mbedtls_ssl_init(&tls->ssl);
468 		mbedtls_ssl_config_init(&tls->config);
469 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
470 		mbedtls_ssl_cookie_init(&tls->cookie);
471 		tls->options.dtls_handshake_timeout_min =
472 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN;
473 		tls->options.dtls_handshake_timeout_max =
474 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX;
475 		tls->options.dtls_cid.cid_len = 0;
476 		tls->options.dtls_cid.enabled = false;
477 		tls->options.dtls_handshake_on_connect = true;
478 #endif
479 #if defined(MBEDTLS_X509_CRT_PARSE_C)
480 		mbedtls_x509_crt_init(&tls->ca_chain);
481 		mbedtls_x509_crt_init(&tls->own_cert);
482 		mbedtls_pk_init(&tls->priv_key);
483 #endif
484 
485 #if defined(CONFIG_MBEDTLS_DEBUG)
486 		mbedtls_ssl_conf_dbg(&tls->config, zephyr_mbedtls_debug, NULL);
487 #endif
488 	} else {
489 		NET_WARN("Failed to allocate TLS context");
490 	}
491 
492 	return tls;
493 }
494 
495 /* Allocate new TLS context and copy the content from the source context. */
tls_clone(struct tls_context * source_tls)496 static struct tls_context *tls_clone(struct tls_context *source_tls)
497 {
498 	struct tls_context *target_tls;
499 
500 	target_tls = tls_alloc();
501 	if (!target_tls) {
502 		return NULL;
503 	}
504 
505 	target_tls->tls_version = source_tls->tls_version;
506 	target_tls->type = source_tls->type;
507 
508 	memcpy(&target_tls->options, &source_tls->options,
509 	       sizeof(target_tls->options));
510 
511 #if defined(MBEDTLS_X509_CRT_PARSE_C)
512 	if (target_tls->options.is_hostname_set) {
513 		mbedtls_ssl_set_hostname(&target_tls->ssl,
514 					 source_tls->ssl.hostname);
515 	}
516 #endif
517 
518 	return target_tls;
519 }
520 
521 /* Release TLS context. */
tls_release(struct tls_context * tls)522 static int tls_release(struct tls_context *tls)
523 {
524 	if (!PART_OF_ARRAY(tls_contexts, tls)) {
525 		NET_ERR("Invalid TLS context");
526 		return -EBADF;
527 	}
528 
529 	if (!tls->is_used) {
530 		NET_ERR("Deallocating unused TLS context");
531 		return -EBADF;
532 	}
533 
534 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
535 	mbedtls_ssl_cookie_free(&tls->cookie);
536 #endif
537 	mbedtls_ssl_config_free(&tls->config);
538 	mbedtls_ssl_free(&tls->ssl);
539 #if defined(MBEDTLS_X509_CRT_PARSE_C)
540 	mbedtls_x509_crt_free(&tls->ca_chain);
541 	mbedtls_x509_crt_free(&tls->own_cert);
542 	mbedtls_pk_free(&tls->priv_key);
543 #endif
544 
545 	tls->is_used = false;
546 
547 	return 0;
548 }
549 
peer_addr_cmp(const struct sockaddr * addr,const struct sockaddr * peer_addr)550 static bool peer_addr_cmp(const struct sockaddr *addr,
551 			  const struct sockaddr *peer_addr)
552 {
553 	if (addr->sa_family != peer_addr->sa_family) {
554 		return false;
555 	}
556 
557 	if (IS_ENABLED(CONFIG_NET_IPV6) && peer_addr->sa_family == AF_INET6) {
558 		struct sockaddr_in6 *addr1 = net_sin6(peer_addr);
559 		struct sockaddr_in6 *addr2 = net_sin6(addr);
560 
561 		return (addr1->sin6_port == addr2->sin6_port) &&
562 			net_ipv6_addr_cmp(&addr1->sin6_addr, &addr2->sin6_addr);
563 	} else if (IS_ENABLED(CONFIG_NET_IPV4) && peer_addr->sa_family == AF_INET) {
564 		struct sockaddr_in *addr1 = net_sin(peer_addr);
565 		struct sockaddr_in *addr2 = net_sin(addr);
566 
567 		return (addr1->sin_port == addr2->sin_port) &&
568 			net_ipv4_addr_cmp(&addr1->sin_addr, &addr2->sin_addr);
569 	}
570 
571 	return false;
572 }
573 
/* Serialize a TLS session and store it in the client session cache.
 *
 * Slot selection, in order of preference:
 *   1. an occupied slot already holding a session for peer_addr,
 *   2. a free slot (session == NULL),
 *   3. the occupied slot with the oldest timestamp, which gets evicted.
 *
 * Returns 0 on success, -ENOMEM if no slot is available, the session
 * buffer cannot be allocated or the session cannot be serialized.
 */
static int tls_session_save(const struct sockaddr *peer_addr,
			    mbedtls_ssl_session *session)
{
	struct tls_session_cache *entry = NULL;
	size_t session_len = 0;
	int ret;

	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
		if (client_cache[i].session == NULL) {
			/* New entry - preferred over evicting a used one. */
			if (entry == NULL || entry->session != NULL) {
				entry = &client_cache[i];
			}
		} else {
			if (peer_addr_cmp(&client_cache[i].peer_addr, peer_addr)) {
				/* Reuse old entry for given address. */
				entry = &client_cache[i];
				break;
			}

			/* Remember the oldest entry and reuse if needed.
			 * A candidate replaces the current choice only when
			 * it was saved earlier (smaller timestamp).
			 */
			if (entry == NULL ||
			    (entry->session != NULL &&
			     client_cache[i].timestamp < entry->timestamp)) {
				entry = &client_cache[i];
			}
		}
	}

	if (entry == NULL) {
		/* Only possible if the cache has no slots at all. */
		return -ENOMEM;
	}

	/* Allocate session and save */

	if (entry->session != NULL) {
		mbedtls_free(entry->session);
		entry->session = NULL;
	}

	/* First pass with no buffer only queries the required length. */
	(void)mbedtls_ssl_session_save(session, NULL, 0, &session_len);

	entry->session = mbedtls_calloc(1, session_len);
	if (entry->session == NULL) {
		NET_ERR("Failed to allocate session buffer.");
		return -ENOMEM;
	}

	ret = mbedtls_ssl_session_save(session, entry->session, session_len,
				       &session_len);
	if (ret < 0) {
		NET_ERR("Failed to serialize session, err: -0x%x.", -ret);
		mbedtls_free(entry->session);
		entry->session = NULL;
		return -ENOMEM;
	}

	entry->session_len = session_len;
	entry->timestamp = k_uptime_get();
	memcpy(&entry->peer_addr, peer_addr, sizeof(*peer_addr));

	return 0;
}
633 
/* Look up a cached session for peer_addr and load it into session.
 *
 * Returns 0 on success, -ENOENT when no entry matches the address, and
 * -EIO when the stored data cannot be loaded (the entry's buffer is freed
 * in that case).
 */
static int tls_session_get(const struct sockaddr *peer_addr,
			   mbedtls_ssl_session *session)
{
	struct tls_session_cache *hit = NULL;
	int err;

	for (int idx = 0; idx < ARRAY_SIZE(client_cache) && hit == NULL; idx++) {
		struct tls_session_cache *slot = &client_cache[idx];

		if (slot->session == NULL) {
			continue;
		}

		if (peer_addr_cmp(&slot->peer_addr, peer_addr)) {
			hit = slot;
		}
	}

	if (hit == NULL) {
		return -ENOENT;
	}

	err = mbedtls_ssl_session_load(session, hit->session,
				       hit->session_len);
	if (err >= 0) {
		return 0;
	}

	/* Discard corrupted session data. */
	mbedtls_free(hit->session);
	hit->session = NULL;
	NET_ERR("Failed to load TLS session %d", err);

	return -EIO;
}
664 
/* Save the current session of a TLS context in the client session cache,
 * keyed by the peer address.
 *
 * Does nothing when session caching is not enabled on the socket. Failures
 * are only logged.
 *
 * The address is copied into a local, zero-initialized struct sockaddr; the
 * copy is limited to the size of that structure, as the caller-provided
 * addrlen may describe a larger storage object.
 */
static void tls_session_store(struct tls_context *context,
			      const struct sockaddr *addr,
			      socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	size_t copy_len = addrlen;
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	if (copy_len > sizeof(peer_addr)) {
		copy_len = sizeof(peer_addr);
	}

	memcpy(&peer_addr, addr, copy_len);
	mbedtls_ssl_session_init(&session);

	ret = mbedtls_ssl_get_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to obtain session for %p", context);
		goto exit;
	}

	ret = tls_session_save(&peer_addr, &session);
	if (ret < 0) {
		NET_ERR("Failed to save session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
694 
/* Try to resume a cached session for the given peer address on a TLS
 * context.
 *
 * Does nothing when session caching is not enabled on the socket. A missing
 * cache entry is logged at debug level; a failure to apply the session is
 * logged as an error.
 *
 * The address is copied into a local, zero-initialized struct sockaddr; the
 * copy is limited to the size of that structure, as the caller-provided
 * addrlen may describe a larger storage object.
 */
static void tls_session_restore(struct tls_context *context,
				const struct sockaddr *addr,
				socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	size_t copy_len = addrlen;
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	if (copy_len > sizeof(peer_addr)) {
		copy_len = sizeof(peer_addr);
	}

	memcpy(&peer_addr, addr, copy_len);
	mbedtls_ssl_session_init(&session);

	ret = tls_session_get(&peer_addr, &session);
	if (ret < 0) {
		NET_DBG("Session not found for %p", context);
		goto exit;
	}

	ret = mbedtls_ssl_set_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to set session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
724 
/* Remove all cached sessions: the client cache is reset and, when
 * MBEDTLS_SSL_CACHE_C is defined, the mbedTLS server cache is freed and
 * initialized again.
 */
static void tls_session_purge(void)
{
#if defined(MBEDTLS_SSL_CACHE_C)
	mbedtls_ssl_cache_context *srv_cache = &server_cache;
#endif

	tls_session_cache_reset();

#if defined(MBEDTLS_SSL_CACHE_C)
	mbedtls_ssl_cache_free(srv_cache);
	mbedtls_ssl_cache_init(srv_cache);
#endif
}
734 
/* Return timeout minus the time elapsed since start (both in the units of
 * k_uptime_get_32()). The subtraction is done on 32-bit unsigned values
 * before conversion to int.
 */
static inline int time_left(uint32_t start, uint32_t timeout)
{
	uint32_t now = k_uptime_get_32();
	uint32_t spent = now - start;

	return timeout - spent;
}
741 
wait(int sock,int timeout,int event)742 static int wait(int sock, int timeout, int event)
743 {
744 	struct zsock_pollfd fds = {
745 		.fd = sock,
746 		.events = event,
747 	};
748 	int ret;
749 
750 	ret = zsock_poll(&fds, 1, timeout);
751 	if (ret < 0) {
752 		return ret;
753 	}
754 
755 	if (ret == 1) {
756 		if (fds.revents & ZSOCK_POLLNVAL) {
757 			return -EBADF;
758 		}
759 
760 		if (fds.revents & ZSOCK_POLLERR) {
761 			int optval;
762 			socklen_t optlen = sizeof(optval);
763 
764 			if (zsock_getsockopt(fds.fd, SOL_SOCKET, SO_ERROR,
765 					     &optval, &optlen) == 0) {
766 				NET_ERR("TLS underlying socket poll error %d",
767 					-optval);
768 				return -optval;
769 			}
770 
771 			return -EIO;
772 		}
773 	}
774 
775 	return 0;
776 }
777 
wait_for_reason(int sock,int timeout,int reason)778 static int wait_for_reason(int sock, int timeout, int reason)
779 {
780 	if (reason == MBEDTLS_ERR_SSL_WANT_READ) {
781 		return wait(sock, timeout, ZSOCK_POLLIN);
782 	}
783 
784 	if (reason == MBEDTLS_ERR_SSL_WANT_WRITE) {
785 		return wait(sock, timeout, ZSOCK_POLLOUT);
786 	}
787 
788 	/* Any other reason - no way to monitor, just wait for some time. */
789 	k_msleep(TLS_WAIT_MS);
790 
791 	return 0;
792 }
793 
is_blocking(int sock,int flags)794 static bool is_blocking(int sock, int flags)
795 {
796 	int sock_flags = zsock_fcntl(sock, F_GETFL, 0);
797 
798 	if (sock_flags == -1) {
799 		return false;
800 	}
801 
802 	return !((flags & ZSOCK_MSG_DONTWAIT) || (sock_flags & O_NONBLOCK));
803 }
804 
/* Convert a kernel timeout to milliseconds: 0 for K_NO_WAIT,
 * SYS_FOREVER_MS for K_FOREVER, otherwise the tick count converted with
 * k_ticks_to_ms_floor32().
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
815 
ctx_set_lock(struct tls_context * ctx,struct k_mutex * lock)816 static void ctx_set_lock(struct tls_context *ctx, struct k_mutex *lock)
817 {
818 	ctx->lock = lock;
819 }
820 
821 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Check whether peer_addr matches the DTLS peer stored in the context:
 * both the address length and the address itself have to agree.
 */
static bool dtls_is_peer_addr_valid(struct tls_context *context,
				    const struct sockaddr *peer_addr,
				    socklen_t addrlen)
{
	bool same_len = (context->dtls_peer_addrlen == addrlen);

	return same_len && peer_addr_cmp(&context->dtls_peer_addr, peer_addr);
}
832 
/* Remember the DTLS peer address in the context. An address that does not
 * fit into the context's storage is silently ignored.
 */
static void dtls_peer_address_set(struct tls_context *context,
				  const struct sockaddr *peer_addr,
				  socklen_t addrlen)
{
	if (addrlen > sizeof(context->dtls_peer_addr)) {
		return;
	}

	memcpy(&context->dtls_peer_addr, peer_addr, addrlen);
	context->dtls_peer_addrlen = addrlen;
}
842 
/* Copy the stored DTLS peer address to the caller's buffer.
 *
 * On entry *addrlen holds the buffer size; at most that many bytes are
 * copied and *addrlen is updated to the number of bytes written.
 */
static void dtls_peer_address_get(struct tls_context *context,
				  struct sockaddr *peer_addr,
				  socklen_t *addrlen)
{
	socklen_t copy_len = *addrlen;

	if (context->dtls_peer_addrlen < copy_len) {
		copy_len = context->dtls_peer_addrlen;
	}

	memcpy(peer_addr, &context->dtls_peer_addr, copy_len);
	*addrlen = copy_len;
}
852 
dtls_tx(void * ctx,const unsigned char * buf,size_t len)853 static int dtls_tx(void *ctx, const unsigned char *buf, size_t len)
854 {
855 	struct tls_context *tls_ctx = ctx;
856 	ssize_t sent;
857 
858 	sent = zsock_sendto(tls_ctx->sock, buf, len, ZSOCK_MSG_DONTWAIT,
859 			    &tls_ctx->dtls_peer_addr,
860 			    tls_ctx->dtls_peer_addrlen);
861 	if (sent < 0) {
862 		if (errno == EAGAIN) {
863 			return MBEDTLS_ERR_SSL_WANT_WRITE;
864 		}
865 
866 		return MBEDTLS_ERR_NET_SEND_FAILED;
867 	}
868 
869 	return sent;
870 }
871 
dtls_rx(void * ctx,unsigned char * buf,size_t len)872 static int dtls_rx(void *ctx, unsigned char *buf, size_t len)
873 {
874 	struct tls_context *tls_ctx = ctx;
875 	socklen_t addrlen = sizeof(struct sockaddr);
876 	struct sockaddr addr;
877 	int err;
878 	ssize_t received;
879 
880 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
881 				  ZSOCK_MSG_DONTWAIT, &addr, &addrlen);
882 	if (received < 0) {
883 		if (errno == EAGAIN) {
884 			return MBEDTLS_ERR_SSL_WANT_READ;
885 		}
886 
887 		return MBEDTLS_ERR_NET_RECV_FAILED;
888 	}
889 
890 	if (tls_ctx->dtls_peer_addrlen == 0) {
891 		/* Only allow to store peer address for DTLS servers. */
892 		if (tls_ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
893 			dtls_peer_address_set(tls_ctx, &addr, addrlen);
894 
895 			err = mbedtls_ssl_set_client_transport_id(
896 				&tls_ctx->ssl,
897 				(const unsigned char *)&addr, addrlen);
898 			if (err < 0) {
899 				return err;
900 			}
901 		} else {
902 			/* For clients it's incorrect to receive when
903 			 * no peer has been set up.
904 			 */
905 			return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED;
906 		}
907 	} else if (!dtls_is_peer_addr_valid(tls_ctx, &addr, addrlen)) {
908 		return MBEDTLS_ERR_SSL_WANT_READ;
909 	}
910 
911 	return received;
912 }
913 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
914 
tls_tx(void * ctx,const unsigned char * buf,size_t len)915 static int tls_tx(void *ctx, const unsigned char *buf, size_t len)
916 {
917 	struct tls_context *tls_ctx = ctx;
918 	ssize_t sent;
919 
920 	sent = zsock_sendto(tls_ctx->sock, buf, len,
921 			    ZSOCK_MSG_DONTWAIT, NULL, 0);
922 	if (sent < 0) {
923 		if (errno == EAGAIN) {
924 			return MBEDTLS_ERR_SSL_WANT_WRITE;
925 		}
926 
927 		return MBEDTLS_ERR_NET_SEND_FAILED;
928 	}
929 
930 	return sent;
931 }
932 
tls_rx(void * ctx,unsigned char * buf,size_t len)933 static int tls_rx(void *ctx, unsigned char *buf, size_t len)
934 {
935 	struct tls_context *tls_ctx = ctx;
936 	ssize_t received;
937 
938 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
939 				  ZSOCK_MSG_DONTWAIT, NULL, 0);
940 	if (received < 0) {
941 		if (errno == EAGAIN) {
942 			return MBEDTLS_ERR_SSL_WANT_READ;
943 		}
944 
945 		return MBEDTLS_ERR_NET_RECV_FAILED;
946 	}
947 
948 	return received;
949 }
950 
951 #if defined(MBEDTLS_X509_CRT_PARSE_C)
/* Heuristic PEM detection: the buffer must be non-empty, end with a NUL
 * byte, and contain the PEM certificate header when read as a C string.
 */
static bool crt_is_pem(const unsigned char *buf, size_t buflen)
{
	static const char pem_header[] = "-----BEGIN CERTIFICATE-----";

	if (buflen == 0) {
		return false;
	}

	/* Without a terminating NUL the buffer cannot be scanned as a string. */
	if (buf[buflen - 1] != '\0') {
		return false;
	}

	return strstr((const char *)buf, pem_header) != NULL;
}
957 #endif
958 
/* Parse a CA certificate credential into the context's CA chain.
 *
 * The certificate is parsed with a copy unless the no-copy option is set
 * and the buffer does not look like PEM, in which case the DER data is
 * referenced in place. Returns 0 on success, -EINVAL on a parse error, or
 * -ENOTSUP when X.509 parsing is not built in.
 */
static int tls_add_ca_certificate(struct tls_context *tls,
				  struct tls_credential *ca_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool parse_with_copy =
		tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE ||
		crt_is_pem(ca_cert->buf, ca_cert->len);
	int err;

	if (parse_with_copy) {
		err = mbedtls_x509_crt_parse(&tls->ca_chain, ca_cert->buf,
					     ca_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->ca_chain,
							ca_cert->buf,
							ca_cert->len);
	}

	if (err == 0) {
		return 0;
	}

	NET_ERR("Failed to parse CA certificate, err: -0x%x", -err);

	return -EINVAL;
#else
	(void)tls;
	(void)ca_cert;

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
985 
/* Hand the context's CA chain (without a CRL) to the mbedTLS configuration
 * and select the default certificate profile. Does nothing when X.509
 * parsing is not built in.
 */
static void tls_set_ca_chain(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	mbedtls_ssl_config *conf = &tls->config;

	mbedtls_ssl_conf_ca_chain(conf, &tls->ca_chain, NULL);
	mbedtls_ssl_conf_cert_profile(conf, &mbedtls_x509_crt_profile_default);
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
994 
/* Parse the socket's own certificate credential into the context.
 *
 * The certificate is parsed with a copy unless the no-copy option is set
 * and the buffer does not look like PEM, in which case the DER data is
 * referenced in place. Returns 0 on success, -EINVAL on a parse error, or
 * -ENOTSUP when X.509 parsing is not built in.
 */
static int tls_add_own_cert(struct tls_context *tls,
			    struct tls_credential *own_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool parse_with_copy =
		tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE ||
		crt_is_pem(own_cert->buf, own_cert->len);
	int err;

	if (parse_with_copy) {
		err = mbedtls_x509_crt_parse(&tls->own_cert,
					     own_cert->buf, own_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->own_cert,
							own_cert->buf,
							own_cert->len);
	}

	return (err == 0) ? 0 : -EINVAL;
#else
	(void)tls;
	(void)own_cert;

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1020 
/* Register the context's own certificate and private key with the mbedTLS
 * configuration. Returns 0 on success, -ENOMEM if mbedTLS reports any
 * error, or -ENOTSUP when X.509 parsing is not built in.
 */
static int tls_set_own_cert(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_conf_own_cert(&tls->config, &tls->own_cert,
				      &tls->priv_key) != 0) {
		return -ENOMEM;
	}

	return 0;
#else
	(void)tls;

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1035 
/* Parse a private key credential (no password) into the context.
 *
 * Returns 0 on success, -EINVAL when the key cannot be parsed, or -ENOTSUP
 * when X.509 parsing is not built in.
 */
static int tls_set_private_key(struct tls_context *tls,
			       struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int err = mbedtls_pk_parse_key(&tls->priv_key, priv_key->buf,
				       priv_key->len, NULL, 0,
				       tls_ctr_drbg_random, NULL);

	return (err == 0) ? 0 : -EINVAL;
#else
	(void)tls;
	(void)priv_key;

	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1054 
/* Configure a pre-shared key and its identity on the mbedTLS configuration.
 *
 * Returns 0 on success, -EINVAL if mbedTLS rejects the PSK, or -ENOTSUP
 * when no PSK key exchange is built in.
 */
static int tls_set_psk(struct tls_context *tls,
		       struct tls_credential *psk,
		       struct tls_credential *psk_id)
{
#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
	const unsigned char *identity = (const unsigned char *)psk_id->buf;

	if (mbedtls_ssl_conf_psk(&tls->config, psk->buf, psk->len,
				 identity, psk_id->len) != 0) {
		return -EINVAL;
	}

	return 0;
#else
	(void)tls;
	(void)psk;
	(void)psk_id;

	return -ENOTSUP;
#endif
}
1073 
tls_set_credential(struct tls_context * tls,struct tls_credential * cred)1074 static int tls_set_credential(struct tls_context *tls,
1075 			      struct tls_credential *cred)
1076 {
1077 	switch (cred->type) {
1078 	case TLS_CREDENTIAL_CA_CERTIFICATE:
1079 		return tls_add_ca_certificate(tls, cred);
1080 
1081 	case TLS_CREDENTIAL_SERVER_CERTIFICATE:
1082 		return tls_add_own_cert(tls, cred);
1083 
1084 	case TLS_CREDENTIAL_PRIVATE_KEY:
1085 		return tls_set_private_key(tls, cred);
1086 	break;
1087 
1088 	case TLS_CREDENTIAL_PSK:
1089 	{
1090 		struct tls_credential *psk_id =
1091 			credential_get(cred->tag, TLS_CREDENTIAL_PSK_ID);
1092 		if (!psk_id) {
1093 			return -ENOENT;
1094 		}
1095 
1096 		return tls_set_psk(tls, cred, psk_id);
1097 	}
1098 
1099 	case TLS_CREDENTIAL_PSK_ID:
1100 		/* Ignore PSK ID - it will be used together
1101 		 * with PSK
1102 		 */
1103 		break;
1104 
1105 	default:
1106 		return -EINVAL;
1107 	}
1108 
1109 	return 0;
1110 }
1111 
tls_mbedtls_set_credentials(struct tls_context * tls)1112 static int tls_mbedtls_set_credentials(struct tls_context *tls)
1113 {
1114 	struct tls_credential *cred;
1115 	sec_tag_t tag;
1116 	int i, err = 0;
1117 	bool tag_found, ca_cert_present = false, own_cert_present = false;
1118 
1119 	credentials_lock();
1120 
1121 	for (i = 0; i < tls->options.sec_tag_list.sec_tag_count; i++) {
1122 		tag = tls->options.sec_tag_list.sec_tags[i];
1123 		cred = NULL;
1124 		tag_found = false;
1125 
1126 		while ((cred = credential_next_get(tag, cred)) != NULL) {
1127 			tag_found = true;
1128 
1129 			err = tls_set_credential(tls, cred);
1130 			if (err != 0) {
1131 				goto exit;
1132 			}
1133 
1134 			if (cred->type == TLS_CREDENTIAL_CA_CERTIFICATE) {
1135 				ca_cert_present = true;
1136 			} else if (cred->type == TLS_CREDENTIAL_SERVER_CERTIFICATE) {
1137 				own_cert_present = true;
1138 			}
1139 		}
1140 
1141 		if (!tag_found) {
1142 			err = -ENOENT;
1143 			goto exit;
1144 		}
1145 	}
1146 
1147 exit:
1148 	credentials_unlock();
1149 
1150 	if (err == 0) {
1151 		if (ca_cert_present) {
1152 			tls_set_ca_chain(tls);
1153 		}
1154 		if (own_cert_present) {
1155 			err = tls_set_own_cert(tls);
1156 		}
1157 	}
1158 
1159 	return err;
1160 }
1161 
tls_mbedtls_reset(struct tls_context * context)1162 static int tls_mbedtls_reset(struct tls_context *context)
1163 {
1164 	int ret;
1165 
1166 	ret = mbedtls_ssl_session_reset(&context->ssl);
1167 	if (ret != 0) {
1168 		return ret;
1169 	}
1170 
1171 	k_sem_reset(&context->tls_established);
1172 
1173 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1174 	/* Server role: reset the address so that a new
1175 	 *              client can connect w/o a need to reopen a socket
1176 	 * Client role: keep peer addr so socket can continue to be used
1177 	 *              even on handshake timeout
1178 	 */
1179 	if (context->options.role == MBEDTLS_SSL_IS_SERVER) {
1180 		(void)memset(&context->dtls_peer_addr, 0,
1181 			     sizeof(context->dtls_peer_addr));
1182 		context->dtls_peer_addrlen = 0;
1183 	}
1184 #endif
1185 
1186 	return 0;
1187 }
1188 
/* Drive the mbedTLS handshake to completion or failure.
 *
 * context->handshake_in_progress is set for the duration of the call.
 * mbedtls_ssl_handshake() is called repeatedly:
 *  - WANT_READ / WANT_WRITE / ASYNC_IN_PROGRESS / CRYPTO_IN_PROGRESS: wait on
 *    the underlying socket via wait_for_reason() for at most the time left
 *    until the deadline derived from @p timeout, then retry. -EAGAIN is
 *    returned once the deadline has passed; a non-zero wait_for_reason()
 *    result is returned as-is.
 *  - HELLO_VERIFY_REQUIRED: reset the session and retry, unless the timeout
 *    is K_NO_WAIT, in which case -EAGAIN is returned.
 *  - SSL_TIMEOUT: reset the session and return -ETIMEDOUT.
 *  - any other error: reset the session and return -ECONNABORTED.
 * If the session reset itself fails, -ECONNABORTED is returned.
 *
 * On success (return 0) the tls_established semaphore is given.
 */
static int tls_mbedtls_handshake(struct tls_context *context,
				 k_timeout_t timeout)
{
	k_timepoint_t end;
	int ret;

	context->handshake_in_progress = true;

	/* Absolute deadline for the whole handshake. */
	end = sys_timepoint_calc(timeout);

	while ((ret = mbedtls_ssl_handshake(&context->ssl)) != 0) {
		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
		    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
			int timeout_ms;

			/* Recompute the time left until the deadline; give up
			 * with -EAGAIN once nothing is left.
			 */
			timeout = sys_timepoint_timeout(end);
			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
				ret = -EAGAIN;
				break;
			}

			/* Convert the remaining time for the blocking wait. */
			timeout_ms = timeout_to_ms(&timeout);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
			if (context->type == SOCK_DGRAM) {
				/* For datagram sockets, do not wait longer
				 * than the value reported by
				 * dtls_get_remaining_timeout().
				 */
				int timeout_dtls =
					dtls_get_remaining_timeout(context);

				if (timeout_dtls != SYS_FOREVER_MS) {
					if (timeout_ms == SYS_FOREVER_MS) {
						timeout_ms = timeout_dtls;
					} else {
						timeout_ms = MIN(timeout_dtls,
								 timeout_ms);
					}
				}
			}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

			/* The mbedTLS code in ret tells wait_for_reason()
			 * what is being waited for.
			 */
			ret = wait_for_reason(context->sock, timeout_ms, ret);
			if (ret != 0) {
				break;
			}

			continue;
		} else if (ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED) {
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				/* timeout holds either the caller's value or
				 * the remaining time computed by an earlier
				 * wait iteration.
				 */
				if (!K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					continue;
				}

				ret = -EAGAIN;
				break;
			}
		} else if (ret == MBEDTLS_ERR_SSL_TIMEOUT) {
			/* MbedTLS API documentation requires session to
			 * be reset in this case
			 */
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				NET_ERR("TLS handshake timeout");
				context->error = ETIMEDOUT;
				ret = -ETIMEDOUT;
				break;
			}
		} else {
			/* MbedTLS API documentation requires session to
			 * be reset in other error cases
			 */
			NET_ERR("TLS handshake error: -0x%x", -ret);
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				context->error = ECONNABORTED;
				ret = -ECONNABORTED;
				break;
			}
		}

		/* Only reached when tls_mbedtls_reset() failed in one of the
		 * branches above; abort instead of looping forever.
		 */
		NET_ERR("TLS reset error: -0x%x", -ret);
		context->error = ECONNABORTED;
		ret = -ECONNABORTED;
		break;
	}

	if (ret == 0) {
		k_sem_give(&context->tls_established);
	}

	context->handshake_in_progress = false;

	return ret;
}
1287 
tls_mbedtls_init(struct tls_context * context,bool is_server)1288 static int tls_mbedtls_init(struct tls_context *context, bool is_server)
1289 {
1290 	int role, type, ret;
1291 
1292 	role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
1293 
1294 	type = (context->type == SOCK_STREAM) ?
1295 		MBEDTLS_SSL_TRANSPORT_STREAM :
1296 		MBEDTLS_SSL_TRANSPORT_DATAGRAM;
1297 
1298 	if (type == MBEDTLS_SSL_TRANSPORT_STREAM) {
1299 		mbedtls_ssl_set_bio(&context->ssl, context,
1300 				    tls_tx, tls_rx, NULL);
1301 	} else {
1302 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1303 		mbedtls_ssl_set_bio(&context->ssl, context,
1304 				    dtls_tx, dtls_rx, NULL);
1305 #else
1306 		return -ENOTSUP;
1307 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1308 	}
1309 
1310 	ret = mbedtls_ssl_config_defaults(&context->config, role, type,
1311 					  MBEDTLS_SSL_PRESET_DEFAULT);
1312 	if (ret != 0) {
1313 		/* According to mbedTLS API documentation,
1314 		 * mbedtls_ssl_config_defaults can fail due to memory
1315 		 * allocation failure
1316 		 */
1317 		return -ENOMEM;
1318 	}
1319 	tls_set_max_frag_len(&context->config, context->type);
1320 
1321 #if defined(MBEDTLS_SSL_RENEGOTIATION)
1322 	mbedtls_ssl_conf_legacy_renegotiation(&context->config,
1323 					   MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE);
1324 	mbedtls_ssl_conf_renegotiation(&context->config,
1325 				       MBEDTLS_SSL_RENEGOTIATION_ENABLED);
1326 #endif
1327 
1328 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1329 	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1330 		/* DTLS requires timer callbacks to operate */
1331 		mbedtls_ssl_set_timer_cb(&context->ssl,
1332 					 &context->dtls_timing,
1333 					 dtls_timing_set_delay,
1334 					 dtls_timing_get_delay);
1335 		mbedtls_ssl_conf_handshake_timeout(&context->config,
1336 				context->options.dtls_handshake_timeout_min,
1337 				context->options.dtls_handshake_timeout_max);
1338 
1339 #if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1340 		if (context->options.dtls_cid.enabled) {
1341 			ret = mbedtls_ssl_conf_cid(
1342 					&context->config,
1343 					context->options.dtls_cid.cid_len,
1344 					MBEDTLS_SSL_UNEXPECTED_CID_IGNORE);
1345 			if (ret != 0) {
1346 				return -EINVAL;
1347 			}
1348 		}
1349 #endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1350 
1351 		/* Configure cookie for DTLS server */
1352 		if (role == MBEDTLS_SSL_IS_SERVER) {
1353 			ret = mbedtls_ssl_cookie_setup(&context->cookie,
1354 						       tls_ctr_drbg_random,
1355 						       NULL);
1356 			if (ret != 0) {
1357 				return -ENOMEM;
1358 			}
1359 
1360 			mbedtls_ssl_conf_dtls_cookies(&context->config,
1361 						      mbedtls_ssl_cookie_write,
1362 						      mbedtls_ssl_cookie_check,
1363 						      &context->cookie);
1364 
1365 			mbedtls_ssl_conf_read_timeout(
1366 					&context->config,
1367 					CONFIG_NET_SOCKETS_DTLS_TIMEOUT);
1368 		}
1369 	}
1370 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1371 
1372 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1373 	/* For TLS clients, set hostname to empty string to enforce it's
1374 	 * verification - only if hostname option was not set. Otherwise
1375 	 * depend on user configuration.
1376 	 */
1377 	if (!is_server && !context->options.is_hostname_set) {
1378 		mbedtls_ssl_set_hostname(&context->ssl, "");
1379 	}
1380 #endif
1381 
1382 	/* If verification level was specified explicitly, set it. Otherwise,
1383 	 * use mbedTLS default values (required for client, none for server)
1384 	 */
1385 	if (context->options.verify_level != -1) {
1386 		mbedtls_ssl_conf_authmode(&context->config,
1387 					  context->options.verify_level);
1388 	}
1389 
1390 	mbedtls_ssl_conf_rng(&context->config,
1391 			     tls_ctr_drbg_random,
1392 			     NULL);
1393 
1394 	ret = tls_mbedtls_set_credentials(context);
1395 	if (ret != 0) {
1396 		return ret;
1397 	}
1398 
1399 	if (context->options.ciphersuites[0] != 0) {
1400 		/* Specific ciphersuites configured, so use them */
1401 		NET_DBG("Using user-specified ciphersuites");
1402 		mbedtls_ssl_conf_ciphersuites(&context->config,
1403 					      context->options.ciphersuites);
1404 	}
1405 
1406 #if defined(CONFIG_MBEDTLS_SSL_ALPN)
1407 	if (ALPN_MAX_PROTOCOLS && context->options.alpn_list[0] != NULL) {
1408 		ret = mbedtls_ssl_conf_alpn_protocols(&context->config,
1409 				context->options.alpn_list);
1410 		if (ret != 0) {
1411 			return -EINVAL;
1412 		}
1413 	}
1414 #endif /* CONFIG_MBEDTLS_SSL_ALPN */
1415 
1416 #if defined(MBEDTLS_SSL_CACHE_C)
1417 	if (is_server && context->options.cache_enabled) {
1418 		mbedtls_ssl_conf_session_cache(&context->config, &server_cache,
1419 					       mbedtls_ssl_cache_get,
1420 					       mbedtls_ssl_cache_set);
1421 	}
1422 #endif
1423 
1424 	ret = mbedtls_ssl_setup(&context->ssl,
1425 				&context->config);
1426 	if (ret != 0) {
1427 		/* According to mbedTLS API documentation,
1428 		 * mbedtls_ssl_setup can fail due to memory allocation failure
1429 		 */
1430 		return -ENOMEM;
1431 	}
1432 
1433 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) && defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
1434 	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1435 		if (context->options.dtls_cid.enabled) {
1436 			ret = mbedtls_ssl_set_cid(&context->ssl, MBEDTLS_SSL_CID_ENABLED,
1437 						  context->options.dtls_cid.cid,
1438 						  context->options.dtls_cid.cid_len);
1439 			if (ret != 0) {
1440 				return -EINVAL;
1441 			}
1442 		}
1443 	}
1444 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS && CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
1445 
1446 	context->is_initialized = true;
1447 
1448 	return 0;
1449 }
1450 
/* Store the list of security tags passed through setsockopt().
 *
 * @p optval must be a non-NULL array of sec_tag_t, @p optlen its size in
 * bytes. Returns -EINVAL when the pointer is NULL, the length is not a
 * multiple of sizeof(sec_tag_t) or the list does not fit in the context.
 */
static int tls_opt_sec_tag_list_set(struct tls_context *context,
				    const void *optval, socklen_t optlen)
{
	int sec_tag_cnt;

	if (optval == NULL || (optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	sec_tag_cnt = optlen / sizeof(sec_tag_t);
	if (sec_tag_cnt >
		ARRAY_SIZE(context->options.sec_tag_list.sec_tags)) {
		return -EINVAL;
	}

	context->options.sec_tag_list.sec_tag_count = sec_tag_cnt;
	memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);

	return 0;
}
1475 
/* Report the TLS/DTLS protocol value stored in the context as an int.
 *
 * Returns -EINVAL unless *optlen equals sizeof(int).
 */
static int sock_opt_protocol_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*out = (int)context->tls_version;

	return 0;
}
1489 
/* Copy the configured security tags into the caller's buffer.
 *
 * *optlen must be a non-zero multiple of sizeof(sec_tag_t). At most *optlen
 * bytes are copied; *optlen is updated to the number of bytes written.
 */
static int tls_opt_sec_tag_list_get(struct tls_context *context,
				    void *optval, socklen_t *optlen)
{
	int len;

	if (*optlen == 0 || (*optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	/* Truncate to whichever is smaller: stored list or caller buffer. */
	len = MIN(context->options.sec_tag_list.sec_tag_count *
		  sizeof(sec_tag_t), *optlen);

	memcpy(optval, context->options.sec_tag_list.sec_tags, len);
	*optlen = len;

	return 0;
}
1507 
/* Set the hostname on the mbedTLS SSL context and remember that the user
 * configured it explicitly.
 *
 * @p optlen is not used; @p optval is passed to mbedtls_ssl_set_hostname()
 * as-is. Returns -EINVAL when mbedTLS rejects the value and -ENOPROTOOPT
 * when MBEDTLS_X509_CRT_PARSE_C is not defined.
 */
static int tls_opt_hostname_set(struct tls_context *context,
				const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_set_hostname(&context->ssl, optval) != 0) {
		return -EINVAL;
	}

	context->options.is_hostname_set = true;

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1525 
/* Store a user-supplied ciphersuite list and hand it to mbedTLS.
 *
 * @p optval is an array of int ciphersuite identifiers, @p optlen its size
 * in bytes. The stored copy is 0-terminated, so one array slot is reserved
 * for the terminator. Returns -EINVAL on NULL input, a length that is not a
 * multiple of sizeof(int), or a list that does not fit.
 */
static int tls_opt_ciphersuite_list_set(struct tls_context *context,
					const void *optval, socklen_t optlen)
{
	int cipher_cnt;

	if (optval == NULL || (optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	cipher_cnt = optlen / sizeof(int);

	/* Room is needed for the entries plus the terminating 0. */
	if (cipher_cnt + 1 > ARRAY_SIZE(context->options.ciphersuites)) {
		return -EINVAL;
	}

	memcpy(context->options.ciphersuites, optval, optlen);
	context->options.ciphersuites[cipher_cnt] = 0;

	mbedtls_ssl_conf_ciphersuites(&context->config,
				      context->options.ciphersuites);

	return 0;
}
1553 
/* Return the ciphersuite list in effect for the socket.
 *
 * If no list was configured on the context, the list reported by
 * mbedtls_ssl_list_ciphersuites() is returned instead. *optlen must be a
 * non-zero multiple of sizeof(int); it is updated to the number of bytes
 * written. The output is truncated to the caller's buffer and is not
 * 0-terminated.
 */
static int tls_opt_ciphersuite_list_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const int *src;
	int *dst = optval;
	int capacity;
	int n;

	if (*optlen == 0 || (*optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	src = (context->options.ciphersuites[0] != 0) ?
		context->options.ciphersuites :
		mbedtls_ssl_list_ciphersuites();

	/* capacity is at least 1 thanks to the length check above. */
	capacity = *optlen / sizeof(int);
	for (n = 0; n < capacity && src[n] != 0; n++) {
		dst[n] = src[n];
	}

	*optlen = n * sizeof(int);

	return 0;
}
1585 
/* Report the identifier of the ciphersuite currently used by the session.
 *
 * Returns -EINVAL unless *optlen equals sizeof(int), and -ENOTCONN when
 * mbedtls_ssl_get_ciphersuite() returns no name.
 */
static int tls_opt_ciphersuite_used_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const char *suite_name;
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	suite_name = mbedtls_ssl_get_ciphersuite(&context->ssl);
	if (suite_name == NULL) {
		return -ENOTCONN;
	}

	/* mbedTLS exposes the name; translate it back to the numeric id. */
	*out = mbedtls_ssl_get_ciphersuite_id(suite_name);

	return 0;
}
1604 
/* Store the ALPN protocol list passed through setsockopt().
 *
 * @p optval is an array of string pointers, @p optlen its size in bytes.
 * Only the pointers are copied, not the strings they refer to. The stored
 * list is NULL-terminated, so one array slot is reserved for the terminator.
 * Returns -EINVAL when ALPN_MAX_PROTOCOLS is 0, on NULL input, on a length
 * that is not a multiple of the pointer size, or when the list does not fit.
 */
static int tls_opt_alpn_list_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int alpn_cnt;

	if (!ALPN_MAX_PROTOCOLS || optval == NULL) {
		return -EINVAL;
	}

	if ((optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	alpn_cnt = optlen / sizeof(const char *);

	/* Room is needed for the entries plus the terminating NULL. */
	if (alpn_cnt + 1 > ARRAY_SIZE(context->options.alpn_list)) {
		return -EINVAL;
	}

	memcpy(context->options.alpn_list, optval, optlen);
	context->options.alpn_list[alpn_cnt] = NULL;

	return 0;
}
1633 
1634 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Read the configured DTLS handshake timeout bound.
 *
 * @p is_max selects the maximum value, otherwise the minimum is returned.
 * Returns -EINVAL unless *optlen equals sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_get(struct tls_context *context,
					      void *optval, socklen_t *optlen,
					      bool is_max)
{
	uint32_t *out = optval;

	if (*optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	*out = is_max ? context->options.dtls_handshake_timeout_max
		      : context->options.dtls_handshake_timeout_min;

	return 0;
}
1653 
/* Update one bound of the DTLS handshake timeout.
 *
 * @p is_max selects which bound @p optval (a uint32_t) replaces. The value
 * is stored in the context options and the pair of bounds is pushed to the
 * mbedTLS configuration right away. Returns -EINVAL on NULL input or when
 * @p optlen is not sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_set(struct tls_context *context,
					      const void *optval,
					      socklen_t optlen, bool is_max)
{
	const uint32_t *new_val = optval;

	if (new_val == NULL || optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	/* Kept in the options so a later mbedTLS init picks the value up. */
	if (is_max) {
		context->options.dtls_handshake_timeout_max = *new_val;
	} else {
		context->options.dtls_handshake_timeout_min = *new_val;
	}

	/* Also applied immediately, for a configuration that is already
	 * initialized.
	 */
	mbedtls_ssl_conf_handshake_timeout(&context->config,
			context->options.dtls_handshake_timeout_min,
			context->options.dtls_handshake_timeout_max);

	return 0;
}
1686 
/* Select the DTLS Connection ID mode for the socket.
 *
 * @p optval points to an int holding one of:
 *  - TLS_DTLS_CID_DISABLED:  CID off, own CID length cleared;
 *  - TLS_DTLS_CID_SUPPORTED: CID on with an empty own CID;
 *  - TLS_DTLS_CID_ENABLED:   CID on; if no own CID has been stored yet, a
 *    random one of MBEDTLS_SSL_CID_OUT_LEN_MAX bytes is generated.
 *
 * Returns -EINVAL for bad arguments or an unknown mode, and -ENOPROTOOPT
 * when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_connection_id_set(struct tls_context *context,
					  const void *optval, socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	int mode;

	if ((optlen > 0 && optval == NULL) || optlen != sizeof(int)) {
		return -EINVAL;
	}

	mode = *(const int *)optval;

	if (mode == TLS_DTLS_CID_DISABLED) {
		context->options.dtls_cid.enabled = false;
		context->options.dtls_cid.cid_len = 0;
	} else if (mode == TLS_DTLS_CID_SUPPORTED) {
		context->options.dtls_cid.enabled = true;
		context->options.dtls_cid.cid_len = 0;
	} else if (mode == TLS_DTLS_CID_ENABLED) {
		context->options.dtls_cid.enabled = true;

		if (context->options.dtls_cid.cid_len == 0) {
			/* No own CID stored yet: generate a random one. */
#if defined(CONFIG_CSPRNG_ENABLED)
			sys_csrand_get(context->options.dtls_cid.cid,
				       MBEDTLS_SSL_CID_OUT_LEN_MAX);
#else
			sys_rand_get(context->options.dtls_cid.cid,
				     MBEDTLS_SSL_CID_OUT_LEN_MAX);
#endif
			context->options.dtls_cid.cid_len = MBEDTLS_SSL_CID_OUT_LEN_MAX;
		}
	} else {
		return -EINVAL;
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1735 
/* Store the socket's own DTLS Connection ID value.
 *
 * @p optval holds @p optlen bytes of CID; a zero length (with or without a
 * buffer) clears the stored CID. Returns -EINVAL when a non-zero length
 * comes with a NULL buffer or the length exceeds MBEDTLS_SSL_CID_IN_LEN_MAX,
 * and -ENOPROTOOPT when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_connection_id_value_set(struct tls_context *context,
						const void *optval,
						socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (optlen > 0 && optval == NULL) {
		return -EINVAL;
	}

	if (optlen > MBEDTLS_SSL_CID_IN_LEN_MAX) {
		return -EINVAL;
	}

	context->options.dtls_cid.cid_len = optlen;

	/* (NULL, 0) is an accepted input; memcpy() must not be handed a NULL
	 * source even for a zero-length copy, so skip it in that case.
	 */
	if (optlen > 0) {
		memcpy(context->options.dtls_cid.cid, optval, optlen);
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1757 
/* Copy the socket's own DTLS Connection ID into the caller's buffer.
 *
 * Returns -EINVAL when the buffer (*optlen bytes) is smaller than the stored
 * CID; otherwise *optlen is set to the CID length and that many bytes are
 * written. Returns -ENOPROTOOPT when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID
 * is not defined.
 */
static int tls_opt_dtls_connection_id_value_get(struct tls_context *context,
						void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (context->options.dtls_cid.cid_len > *optlen) {
		return -EINVAL;
	}

	*optlen = context->options.dtls_cid.cid_len;
	memcpy(optval, context->options.dtls_cid.cid, *optlen);

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1775 
/* Copy the peer's DTLS Connection ID into the caller's buffer.
 *
 * The CID is first fetched into a local buffer of
 * MBEDTLS_SSL_CID_OUT_LEN_MAX bytes, the size mbedtls_ssl_get_peer_cid()
 * requires, and only copied out if the caller's buffer is large enough.
 * Handing the caller's buffer to mbedTLS directly would let a short buffer
 * be overrun.
 *
 * @param context TLS context.
 * @param optval  Destination buffer; may be NULL to query only the length.
 * @param optlen  In: size of @p optval. Out: length of the peer CID, or 0
 *                when CID is not in use or the query failed.
 *
 * @return 0 on success; -ENOTCONN if the context is not initialized;
 *         -EINVAL if @p optval is too small; the mbedTLS error code if
 *         mbedtls_ssl_get_peer_cid() fails; -ENOPROTOOPT when
 *         CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_peer_connection_id_value_get(struct tls_context *context,
						     void *optval,
						     socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	unsigned char peer_cid[MBEDTLS_SSL_CID_OUT_LEN_MAX];
	size_t peer_cid_len = 0;
	int enabled = false;
	int ret;

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       peer_cid, &peer_cid_len);
	if (ret != 0) {
		/* NOTE(review): this is a raw mbedTLS error code, not an
		 * errno value - confirm what callers expect.
		 */
		*optlen = 0;
		return ret;
	}

	if (!enabled) {
		*optlen = 0;
		return 0;
	}

	if (optval != NULL) {
		if (*optlen < peer_cid_len) {
			return -EINVAL;
		}

		memcpy(optval, peer_cid, peer_cid_len);
	}

	*optlen = peer_cid_len;

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1797 
/* Report the negotiated DTLS Connection ID status as an int.
 *
 * The status is derived from whether the socket has an own CID stored and
 * whether the peer supplied one:
 *  - CID option disabled, or neither side has a CID: DISABLED;
 *  - both: BIDIRECTIONAL; own only: DOWNLINK; peer only: UPLINK.
 *
 * @return 0 on success; -EINVAL unless *optlen equals sizeof(int);
 *         -ENOTCONN if the context is not initialized; -EAGAIN when
 *         mbedtls_ssl_get_peer_cid() fails (handshake not complete);
 *         -ENOPROTOOPT when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not
 *         defined.
 */
static int tls_opt_dtls_connection_id_status_get(struct tls_context *context,
					  void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	/* Zero-initialized on purpose: mbedtls_ssl_get_peer_cid() leaves the
	 * length untouched when no CID was negotiated, and cid_len is read
	 * below in every case.
	 */
	struct tls_dtls_cid cid = { 0 };
	int ret;
	int val;
	int enabled = 0;
	bool have_self_cid;
	bool have_peer_cid;

	if (sizeof(int) != *optlen) {
		return -EINVAL;
	}

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       cid.cid,
				       &cid.cid_len);
	if (ret) {
		/* Handshake is not complete */
		return -EAGAIN;
	}

	cid.enabled = (enabled == MBEDTLS_SSL_CID_ENABLED);
	have_self_cid = (context->options.dtls_cid.cid_len != 0);
	have_peer_cid = cid.enabled && (cid.cid_len != 0);

	if (!context->options.dtls_cid.enabled) {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	} else if (have_self_cid && have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_BIDIRECTIONAL;
	} else if (have_self_cid) {
		val = TLS_DTLS_CID_STATUS_DOWNLINK;
	} else if (have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_UPLINK;
	} else {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	}

	*((int *)optval) = val;
	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1847 
/* Enable or disable running the DTLS handshake as part of connect().
 *
 * @p optval points to an int; any non-zero value enables the option.
 * Returns -EINVAL on NULL input or when @p optlen is not sizeof(int).
 */
static int tls_opt_dtls_handshake_on_connect_set(struct tls_context *context,
						 const void *optval,
						 socklen_t optlen)
{
	const int *flag = optval;

	if (flag == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.dtls_handshake_on_connect = (bool)*flag;

	return 0;
}
1866 
/* Report, as an int, whether the DTLS handshake is run during connect().
 *
 * Returns -EINVAL unless *optlen equals sizeof(int).
 */
static int tls_opt_dtls_handshake_on_connect_get(struct tls_context *context,
						 void *optval,
						 socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*out = context->options.dtls_handshake_on_connect;

	return 0;
}
1879 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1880 
/* Copy the configured ALPN protocol pointers into the caller's array.
 *
 * *optlen must be a non-zero multiple of the pointer size; it is updated to
 * the number of bytes written. The output is truncated to the caller's
 * array and is not NULL-terminated. Returns -EINVAL when ALPN_MAX_PROTOCOLS
 * is 0 or the length is invalid.
 */
static int tls_opt_alpn_list_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	const char **src = context->options.alpn_list;
	const char **dst = optval;
	int capacity;
	int n;

	if (!ALPN_MAX_PROTOCOLS) {
		return -EINVAL;
	}

	if (*optlen == 0 || (*optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	/* capacity is at least 1 thanks to the length check above. */
	capacity = *optlen / sizeof(const char *);
	for (n = 0; n < capacity && src[n] != NULL; n++) {
		dst[n] = src[n];
	}

	*optlen = n * sizeof(const char *);

	return 0;
}
1909 
/* Enable or disable session caching for the socket.
 *
 * @p optval points to an int; caching is enabled only when it equals
 * TLS_SESSION_CACHE_ENABLED. Returns -EINVAL on NULL input or when
 * @p optlen is not sizeof(int).
 */
static int tls_opt_session_cache_set(struct tls_context *context,
				     const void *optval, socklen_t optlen)
{
	const int *requested = optval;

	if (requested == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.cache_enabled =
		(*requested == TLS_SESSION_CACHE_ENABLED);

	return 0;
}
1927 
/* Report whether session caching is enabled, as TLS_SESSION_CACHE_ENABLED
 * or TLS_SESSION_CACHE_DISABLED.
 *
 * Returns -EINVAL unless *optlen equals sizeof(int).
 */
static int tls_opt_session_cache_get(struct tls_context *context,
				     void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	if (context->options.cache_enabled) {
		*out = TLS_SESSION_CACHE_ENABLED;
	} else {
		*out = TLS_SESSION_CACHE_DISABLED;
	}

	return 0;
}
1943 
/* Purge the TLS session cache by calling tls_session_purge().
 *
 * None of the arguments are used. Always returns 0.
 */
static int tls_opt_session_cache_purge_set(struct tls_context *context,
					   const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(context);

	tls_session_purge();

	return 0;
}
1955 
/* Set the peer verification level used when mbedTLS is initialized.
 *
 * @p optval points to an int holding MBEDTLS_SSL_VERIFY_NONE,
 * MBEDTLS_SSL_VERIFY_OPTIONAL or MBEDTLS_SSL_VERIFY_REQUIRED. Returns
 * -EINVAL on NULL input, a wrong @p optlen or any other value.
 */
static int tls_opt_peer_verify_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int level;
	bool known_level;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	level = *(const int *)optval;

	known_level = (level == MBEDTLS_SSL_VERIFY_NONE) ||
		      (level == MBEDTLS_SSL_VERIFY_OPTIONAL) ||
		      (level == MBEDTLS_SSL_VERIFY_REQUIRED);
	if (!known_level) {
		return -EINVAL;
	}

	context->options.verify_level = level;

	return 0;
}
1981 
/* Select whether certificates may be parsed without copying them.
 *
 * @p optval points to an int holding TLS_CERT_NOCOPY_NONE or
 * TLS_CERT_NOCOPY_OPTIONAL. Returns -EINVAL on NULL input, a wrong
 * @p optlen or any other value.
 */
static int tls_opt_cert_nocopy_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int mode;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	mode = *(const int *)optval;

	if (!(mode == TLS_CERT_NOCOPY_NONE ||
	      mode == TLS_CERT_NOCOPY_OPTIONAL)) {
		return -EINVAL;
	}

	context->options.cert_nocopy = mode;

	return 0;
}
2006 
/* Set the role (client or server) stored in the context options.
 *
 * @p optval points to an int holding MBEDTLS_SSL_IS_CLIENT or
 * MBEDTLS_SSL_IS_SERVER. Returns -EINVAL on NULL input, a wrong @p optlen
 * or any other value.
 */
static int tls_opt_dtls_role_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int requested_role;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	requested_role = *(const int *)optval;

	if (!(requested_role == MBEDTLS_SSL_IS_CLIENT ||
	      requested_role == MBEDTLS_SSL_IS_SERVER)) {
		return -EINVAL;
	}

	context->options.role = requested_role;

	return 0;
}
2030 
protocol_check(int family,int type,int * proto)2031 static int protocol_check(int family, int type, int *proto)
2032 {
2033 	if (family != AF_INET && family != AF_INET6) {
2034 		return -EAFNOSUPPORT;
2035 	}
2036 
2037 	if (*proto >= IPPROTO_TLS_1_0 && *proto <= IPPROTO_TLS_1_2) {
2038 		if (type != SOCK_STREAM) {
2039 			return -EPROTOTYPE;
2040 		}
2041 
2042 		*proto = IPPROTO_TCP;
2043 	} else if (*proto >= IPPROTO_DTLS_1_0 && *proto <= IPPROTO_DTLS_1_2) {
2044 		if (!IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS)) {
2045 			return -EPROTONOSUPPORT;
2046 		}
2047 
2048 		if (type != SOCK_DGRAM) {
2049 			return -EPROTOTYPE;
2050 		}
2051 
2052 		*proto = IPPROTO_UDP;
2053 	} else {
2054 		return -EPROTONOSUPPORT;
2055 	}
2056 
2057 	return 0;
2058 }
2059 
ztls_socket(int family,int type,int proto)2060 static int ztls_socket(int family, int type, int proto)
2061 {
2062 	enum net_ip_protocol_secure tls_proto = proto;
2063 	int fd = zvfs_reserve_fd();
2064 	int sock = -1;
2065 	int ret;
2066 	struct tls_context *ctx;
2067 
2068 	if (fd < 0) {
2069 		return -1;
2070 	}
2071 
2072 	ret = protocol_check(family, type, &proto);
2073 	if (ret < 0) {
2074 		errno = -ret;
2075 		goto free_fd;
2076 	}
2077 
2078 	ctx = tls_alloc();
2079 	if (ctx == NULL) {
2080 		errno = ENOMEM;
2081 		goto free_fd;
2082 	}
2083 
2084 	sock = zsock_socket(family, type, proto);
2085 	if (sock < 0) {
2086 		goto release_tls;
2087 	}
2088 
2089 	ctx->tls_version = tls_proto;
2090 	ctx->type = (proto == IPPROTO_TCP) ? SOCK_STREAM : SOCK_DGRAM;
2091 	ctx->sock = sock;
2092 
2093 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
2094 			    ZVFS_MODE_IFSOCK);
2095 
2096 	return fd;
2097 
2098 release_tls:
2099 	(void)tls_release(ctx);
2100 
2101 free_fd:
2102 	zvfs_free_fd(fd);
2103 
2104 	return -1;
2105 }
2106 
ztls_close_ctx(struct tls_context * ctx)2107 int ztls_close_ctx(struct tls_context *ctx)
2108 {
2109 	int ret, err = 0;
2110 
2111 	/* Try to send close notification. */
2112 	ctx->flags = 0;
2113 
2114 	(void)mbedtls_ssl_close_notify(&ctx->ssl);
2115 
2116 	err = tls_release(ctx);
2117 	ret = zsock_close(ctx->sock);
2118 
2119 	/* In case close fails, we propagate errno value set by close.
2120 	 * In case close succeeds, but tls_release fails, set errno
2121 	 * according to tls_release return value.
2122 	 */
2123 	if (ret == 0 && err < 0) {
2124 		errno = -err;
2125 		ret = -1;
2126 	}
2127 
2128 	return ret;
2129 }
2130 
/* Connect a TLS socket and, where applicable, perform the handshake.
 *
 * The underlying socket is connected in blocking mode even if O_NONBLOCK is
 * set; the original flags are restored afterwards whether or not the
 * connect succeeded. For stream sockets, and for datagram sockets with the
 * handshake-on-connect option set, mbedTLS is then initialized in the client
 * role and the handshake is run without a time limit, with session
 * restore/store around it.
 *
 * @param ctx     TLS context.
 * @param addr    Peer address.
 * @param addrlen Length of @p addr.
 *
 * @return 0 on success. On failure: -EIO if the socket flags cannot be read;
 *         the zsock_connect() result if the connect fails; otherwise -1 with
 *         errno set from the failing init/handshake step.
 */
int ztls_connect_ctx(struct tls_context *ctx, const struct sockaddr *addr,
		     socklen_t addrlen)
{
	int ret;
	int sock_flags;
	bool was_nonblocking;

	sock_flags = zsock_fcntl(ctx->sock, F_GETFL, 0);
	if (sock_flags < 0) {
		/* NOTE(review): this path returns a negative errno value
		 * while the other failure paths return -1 with errno set -
		 * confirm what callers expect.
		 */
		return -EIO;
	}

	was_nonblocking = (sock_flags & O_NONBLOCK) != 0;
	if (was_nonblocking) {
		(void)zsock_fcntl(ctx->sock, F_SETFL,
				  sock_flags & ~O_NONBLOCK);
	}

	ret = zsock_connect(ctx->sock, addr, addrlen);

	if (was_nonblocking) {
		/* Restore the caller's flags on both success and failure,
		 * so a failed connect does not leave the socket blocking.
		 * errno is preserved so the connect error is not clobbered.
		 */
		int connect_errno = errno;

		(void)zsock_fcntl(ctx->sock, F_SETFL, sock_flags);
		errno = connect_errno;
	}

	if (ret < 0) {
		return ret;
	}

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (ctx->type == SOCK_DGRAM) {
		dtls_peer_address_set(ctx, addr, addrlen);
	}
#endif

	if (ctx->type == SOCK_STREAM
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	    || (ctx->type == SOCK_DGRAM && ctx->options.dtls_handshake_on_connect)
#endif
	    ) {
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}

		/* Do not use any socket flags during the handshake. */
		ctx->flags = 0;

		tls_session_restore(ctx, addr, addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket
		 * even for non-blocking socket.
		 */
		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
		if (ret < 0) {
			goto error;
		}

		tls_session_store(ctx, addr, addrlen);
	}

	return 0;

error:
	errno = -ret;
	return -1;
}
2194 
ztls_accept_ctx(struct tls_context * parent,struct sockaddr * addr,socklen_t * addrlen)2195 int ztls_accept_ctx(struct tls_context *parent, struct sockaddr *addr,
2196 		    socklen_t *addrlen)
2197 {
2198 	struct tls_context *child = NULL;
2199 	int ret, err, fd, sock;
2200 
2201 	fd = zvfs_reserve_fd();
2202 	if (fd < 0) {
2203 		return -1;
2204 	}
2205 
2206 
2207 	k_mutex_unlock(parent->lock);
2208 	sock = zsock_accept(parent->sock, addr, addrlen);
2209 	k_mutex_lock(parent->lock, K_FOREVER);
2210 	if (sock < 0) {
2211 		ret = -errno;
2212 		goto error;
2213 	}
2214 
2215 	child = tls_clone(parent);
2216 	if (child == NULL) {
2217 		ret = -ENOMEM;
2218 		goto error;
2219 	}
2220 
2221 	zvfs_finalize_typed_fd(fd, child, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
2222 			    ZVFS_MODE_IFSOCK);
2223 
2224 	child->sock = sock;
2225 
2226 	ret = tls_mbedtls_init(child, true);
2227 	if (ret < 0) {
2228 		goto error;
2229 	}
2230 
2231 	/* Do not use any socket flags during the handshake. */
2232 	child->flags = 0;
2233 
2234 	/* TODO For simplicity, TLS handshake blocks the socket even for
2235 	 * non-blocking socket.
2236 	 */
2237 	ret = tls_mbedtls_handshake(child, K_FOREVER);
2238 	if (ret < 0) {
2239 		goto error;
2240 	}
2241 
2242 	return fd;
2243 
2244 error:
2245 	if (child != NULL) {
2246 		err = tls_release(child);
2247 		__ASSERT(err == 0, "TLS context release failed");
2248 	}
2249 
2250 	if (sock >= 0) {
2251 		err = zsock_close(sock);
2252 		__ASSERT(err == 0, "Child socket close failed");
2253 	}
2254 
2255 	zvfs_free_fd(fd);
2256 
2257 	errno = -ret;
2258 	return -1;
2259 }
2260 
send_tls(struct tls_context * ctx,const void * buf,size_t len,int flags)2261 static ssize_t send_tls(struct tls_context *ctx, const void *buf,
2262 			size_t len, int flags)
2263 {
2264 	const bool is_block = is_blocking(ctx->sock, flags);
2265 	k_timeout_t timeout;
2266 	k_timepoint_t end;
2267 	int ret;
2268 
2269 	if (ctx->error != 0) {
2270 		errno = ctx->error;
2271 		return -1;
2272 	}
2273 
2274 	if (ctx->session_closed) {
2275 		errno = ECONNABORTED;
2276 		return -1;
2277 	}
2278 
2279 	if (!is_block) {
2280 		timeout = K_NO_WAIT;
2281 	} else {
2282 		timeout = ctx->options.timeout_tx;
2283 	}
2284 
2285 	end = sys_timepoint_calc(timeout);
2286 
2287 	do {
2288 		ret = mbedtls_ssl_write(&ctx->ssl, buf, len);
2289 		if (ret >= 0) {
2290 			return ret;
2291 		}
2292 
2293 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2294 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2295 		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2296 		    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2297 			int timeout_ms;
2298 
2299 			if (!is_block) {
2300 				errno = EAGAIN;
2301 				break;
2302 			}
2303 
2304 			/* Blocking timeout. */
2305 			timeout = sys_timepoint_timeout(end);
2306 			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2307 				errno = EAGAIN;
2308 				break;
2309 			}
2310 
2311 			/* Block. */
2312 			timeout_ms = timeout_to_ms(&timeout);
2313 			ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2314 			if (ret != 0) {
2315 				errno = -ret;
2316 				break;
2317 			}
2318 		} else {
2319 			NET_ERR("TLS send error: -%x", -ret);
2320 
2321 			/* MbedTLS API documentation requires session to
2322 			 * be reset in other error cases
2323 			 */
2324 			ret = tls_mbedtls_reset(ctx);
2325 			if (ret != 0) {
2326 				ctx->error = ENOMEM;
2327 				errno = ENOMEM;
2328 			} else {
2329 				ctx->error = ECONNABORTED;
2330 				errno = ECONNABORTED;
2331 			}
2332 
2333 			break;
2334 		}
2335 	} while (true);
2336 
2337 	return -1;
2338 }
2339 
2340 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
sendto_dtls_client(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2341 static ssize_t sendto_dtls_client(struct tls_context *ctx, const void *buf,
2342 				  size_t len, int flags,
2343 				  const struct sockaddr *dest_addr,
2344 				  socklen_t addrlen)
2345 {
2346 	int ret;
2347 
2348 	if (!dest_addr) {
2349 		/* No address provided, check if we have stored one,
2350 		 * otherwise return error.
2351 		 */
2352 		if (ctx->dtls_peer_addrlen == 0) {
2353 			ret = -EDESTADDRREQ;
2354 			goto error;
2355 		}
2356 	} else if (ctx->dtls_peer_addrlen == 0) {
2357 		/* Address provided and no peer address stored. */
2358 		dtls_peer_address_set(ctx, dest_addr, addrlen);
2359 	} else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) {
2360 		/* Address provided but it does not match stored one */
2361 		ret = -EISCONN;
2362 		goto error;
2363 	}
2364 
2365 	if (!ctx->is_initialized) {
2366 		ret = tls_mbedtls_init(ctx, false);
2367 		if (ret < 0) {
2368 			goto error;
2369 		}
2370 	}
2371 
2372 	if (!is_handshake_complete(ctx)) {
2373 		tls_session_restore(ctx, &ctx->dtls_peer_addr,
2374 				    ctx->dtls_peer_addrlen);
2375 
2376 		/* TODO For simplicity, TLS handshake blocks the socket even for
2377 		 * non-blocking socket.
2378 		 */
2379 		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
2380 		if (ret < 0) {
2381 			goto error;
2382 		}
2383 
2384 		/* Client socket ready to use again. */
2385 		ctx->error = 0;
2386 
2387 		tls_session_store(ctx, &ctx->dtls_peer_addr,
2388 				  ctx->dtls_peer_addrlen);
2389 	}
2390 
2391 	return send_tls(ctx, buf, len, flags);
2392 
2393 error:
2394 	errno = -ret;
2395 	return -1;
2396 }
2397 
/* sendto() implementation for a DTLS server socket.
 *
 * A server can only send once a DTLS session is established, and only to the
 * peer of that session.
 *
 * Returns the result of send_tls(), or -1 with errno set to ENOTCONN (no
 * completed handshake) or EISCONN (dest_addr given but not matching the
 * connected peer).
 */
static ssize_t sendto_dtls_server(struct tls_context *ctx, const void *buf,
				  size_t len, int flags,
				  const struct sockaddr *dest_addr,
				  socklen_t addrlen)
{
	if (!is_handshake_complete(ctx)) {
		errno = ENOTCONN;
		return -1;
	}

	/* An explicit destination has to be the peer we are connected with. */
	if (dest_addr != NULL &&
	    !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen)) {
		errno = EISCONN;
		return -1;
	}

	return send_tls(ctx, buf, len, flags);
}
2420 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2421 
ztls_sendto_ctx(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2422 ssize_t ztls_sendto_ctx(struct tls_context *ctx, const void *buf, size_t len,
2423 			int flags, const struct sockaddr *dest_addr,
2424 			socklen_t addrlen)
2425 {
2426 	ctx->flags = flags;
2427 
2428 	/* TLS */
2429 	if (ctx->type == SOCK_STREAM) {
2430 		return send_tls(ctx, buf, len, flags);
2431 	}
2432 
2433 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2434 	/* DTLS */
2435 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2436 		return sendto_dtls_server(ctx, buf, len, flags,
2437 					  dest_addr, addrlen);
2438 	}
2439 
2440 	return sendto_dtls_client(ctx, buf, len, flags, dest_addr, addrlen);
2441 #else
2442 	errno = ENOTSUP;
2443 	return -1;
2444 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2445 }
2446 
dtls_sendmsg_merge_and_send(struct tls_context * ctx,const struct msghdr * msg,int flags)2447 static ssize_t dtls_sendmsg_merge_and_send(struct tls_context *ctx,
2448 					   const struct msghdr *msg,
2449 					   int flags)
2450 {
2451 	static K_MUTEX_DEFINE(sendmsg_lock);
2452 	static uint8_t sendmsg_buf[DTLS_SENDMSG_BUF_SIZE];
2453 	ssize_t len = 0;
2454 
2455 	k_mutex_lock(&sendmsg_lock, K_FOREVER);
2456 
2457 	for (int i = 0; i < msg->msg_iovlen; i++) {
2458 		struct iovec *vec = msg->msg_iov + i;
2459 
2460 		if (vec->iov_len > 0) {
2461 			if (len + vec->iov_len > sizeof(sendmsg_buf)) {
2462 				k_mutex_unlock(&sendmsg_lock);
2463 				errno = EMSGSIZE;
2464 				return -1;
2465 			}
2466 
2467 			memcpy(sendmsg_buf + len, vec->iov_base, vec->iov_len);
2468 			len += vec->iov_len;
2469 		}
2470 	}
2471 
2472 	if (len > 0) {
2473 		len = ztls_sendto_ctx(ctx, sendmsg_buf, len, flags,
2474 				      msg->msg_name, msg->msg_namelen);
2475 	}
2476 
2477 	k_mutex_unlock(&sendmsg_lock);
2478 
2479 	return len;
2480 }
2481 
/* Send every buffer of msg in order, one ztls_sendto_ctx() call at a time.
 *
 * Each buffer is pushed until it has been sent completely, so partial writes
 * are continued from where they stopped.
 *
 * Returns the total number of bytes sent, or the negative result of the
 * first failing ztls_sendto_ctx() call.
 */
static ssize_t tls_sendmsg_loop_and_send(struct tls_context *ctx,
					 const struct msghdr *msg,
					 int flags)
{
	ssize_t total = 0;

	for (int i = 0; i < msg->msg_iovlen; i++) {
		struct iovec *vec = &msg->msg_iov[i];
		size_t offset = 0;

		/* Empty buffers never enter the loop and add nothing. */
		while (offset < vec->iov_len) {
			ssize_t written;

			written = ztls_sendto_ctx(ctx,
						  (uint8_t *)vec->iov_base + offset,
						  vec->iov_len - offset,
						  flags, msg->msg_name,
						  msg->msg_namelen);
			if (written < 0) {
				return written;
			}

			offset += written;
		}

		total += offset;
	}

	return total;
}
2513 
ztls_sendmsg_ctx(struct tls_context * ctx,const struct msghdr * msg,int flags)2514 ssize_t ztls_sendmsg_ctx(struct tls_context *ctx, const struct msghdr *msg,
2515 			 int flags)
2516 {
2517 	if (msg == NULL) {
2518 		errno = EINVAL;
2519 		return -1;
2520 	}
2521 
2522 	if (IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS) &&
2523 	    ctx->type == SOCK_DGRAM) {
2524 		if (DTLS_SENDMSG_BUF_SIZE > 0) {
2525 			/* With one buffer only, there's no need to use
2526 			 * intermediate buffer.
2527 			 */
2528 			if (msghdr_non_empty_iov_count(msg) == 1) {
2529 				goto send_loop;
2530 			}
2531 
2532 			return dtls_sendmsg_merge_and_send(ctx, msg, flags);
2533 		}
2534 
2535 		/*
2536 		 * Current mbedTLS API (i.e. mbedtls_ssl_write()) allows only to send a single
2537 		 * contiguous buffer. This means that gather write using sendmsg() can only be
2538 		 * handled correctly if there is a single non-empty buffer in msg->msg_iov.
2539 		 */
2540 		if (msghdr_non_empty_iov_count(msg) > 1) {
2541 			errno = EMSGSIZE;
2542 			return -1;
2543 		}
2544 	}
2545 
2546 send_loop:
2547 	return tls_sendmsg_loop_and_send(ctx, msg, flags);
2548 }
2549 
/* Receive decrypted application data from a TLS (stream) session.
 *
 * Reads up to max_len bytes into buf using mbedtls_ssl_read(). Without
 * ZSOCK_MSG_WAITALL the function returns as soon as some data was read; with
 * it, reading continues until max_len bytes were collected, the read returns
 * 0, or the session is marked as closed.
 *
 * Returns the number of bytes stored in buf (0 if the session was already
 * marked closed on entry), or -1 with errno set: the stored context error,
 * EAGAIN when the call would block or the RX timeout expired, the error
 * reported by wait_for_reason(), or EIO for any other mbedTLS read error.
 */
static ssize_t recv_tls(struct tls_context *ctx, void *buf,
			size_t max_len, int flags)
{
	size_t recv_len = 0;
	const bool waitall = flags & ZSOCK_MSG_WAITALL;
	const bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;
	int ret;

	/* A previously recorded error is reported on every call. */
	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	/* Closed session is reported as end of stream. */
	if (ctx->session_closed) {
		return 0;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline, so that retries share one overall timeout. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t read_len = max_len - recv_len;

		ret = mbedtls_ssl_read(&ctx->ssl, (uint8_t *)buf + recv_len,
				       read_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
				/* Peer notified that it's closing the
				 * connection.
				 */
				ctx->session_closed = true;
				break;
			}

			if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
				/* Client reconnect on the same socket is not
				 * supported. See mbedtls_ssl_read API
				 * documentation.
				 */
				ctx->session_closed = true;
				break;
			}

			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
				int timeout_ms;

				if (!is_block) {
					ret = -EAGAIN;
					goto err;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					ret = -EAGAIN;
					goto err;
				}

				timeout_ms = timeout_to_ms(&timeout);

				/* Block. The context lock is dropped for the
				 * duration of the wait and taken again after.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				}
				/* Non-zero wait result falls through to the
				 * error exit below and becomes errno.
				 */
			} else {
				NET_ERR("TLS recv error: -%x", -ret);
				ret = -EIO;
			}

err:
			/* ret holds a negative errno value at this point. */
			errno = -ret;
			return -1;
		}

		/* Zero-length read result ends the loop. */
		if (ret == 0) {
			break;
		}

		recv_len += ret;
	} while ((recv_len == 0) || (waitall && (recv_len < max_len)));

	return recv_len;
}
2648 
2649 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Receive one DTLS datagram; shared by the client and server receive paths.
 *
 * Reads up to max_len bytes of the current datagram into buf, optionally
 * reports the peer address through src_addr/addrlen, and discards whatever
 * part of the datagram did not fit into buf.
 *
 * Returns the number of bytes read (with ZSOCK_MSG_TRUNC set and data left
 * over, increased by the number of left-over bytes), or a negative value. If
 * the context carries a stored error, -1 is returned with errno set to it.
 * All other failures are returned as raw mbedTLS error codes without touching
 * errno, including the "want read/write" style codes when the call would
 * block or the RX timeout has expired.
 */
static ssize_t recvfrom_dtls_common(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct sockaddr *src_addr,
				    socklen_t *addrlen)
{
	int ret;
	bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline, so that retries share one overall timeout. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t remaining;

		ret = mbedtls_ssl_read(&ctx->ssl, buf, max_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
				int timeout_dtls, timeout_sock, timeout_ms;

				if (!is_block) {
					return ret;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					return ret;
				}

				/* Wait for the shorter of the DTLS and socket
				 * timeouts. If either of them is "forever",
				 * MAX() selects the other one (or "forever"
				 * when both are).
				 */
				timeout_dtls = dtls_get_remaining_timeout(ctx);
				timeout_sock = timeout_to_ms(&timeout);
				if (timeout_dtls == SYS_FOREVER_MS ||
				    timeout_sock == SYS_FOREVER_MS) {
					timeout_ms = MAX(timeout_dtls, timeout_sock);
				} else {
					timeout_ms = MIN(timeout_dtls, timeout_sock);
				}

				/* Block. The context lock is dropped for the
				 * duration of the wait and taken again after.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				} else {
					return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				}
			} else {
				return ret;
			}
		}

		/* The peer address is only reported when both output
		 * arguments were provided.
		 */
		if (src_addr && addrlen) {
			dtls_peer_address_get(ctx, src_addr, addrlen);
		}

		/* mbedtls_ssl_get_bytes_avail() indicate the data length
		 * remaining in the current datagram.
		 */
		remaining = mbedtls_ssl_get_bytes_avail(&ctx->ssl);

		/* No more data in the datagram, or dummy read. */
		if ((remaining == 0) || (max_len == 0)) {
			return ret;
		}

		/* With ZSOCK_MSG_TRUNC the discarded bytes are counted in the
		 * returned length as well.
		 */
		if (flags & ZSOCK_MSG_TRUNC) {
			ret += remaining;
		}

		/* Drop the rest of the datagram, one byte at a time. */
		for (int i = 0; i < remaining; i++) {
			uint8_t byte;
			int err;

			err = mbedtls_ssl_read(&ctx->ssl, &byte, sizeof(byte));
			if (err <= 0) {
				NET_ERR("Error while flushing the rest of the"
					" datagram, err %d", err);
				ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				break;
			}
		}

		break;
	} while (true);


	return ret;
}
2756 
recvfrom_dtls_client(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2757 static ssize_t recvfrom_dtls_client(struct tls_context *ctx, void *buf,
2758 				    size_t max_len, int flags,
2759 				    struct sockaddr *src_addr,
2760 				    socklen_t *addrlen)
2761 {
2762 	int ret;
2763 
2764 	if (!is_handshake_complete(ctx)) {
2765 		ret = -ENOTCONN;
2766 		goto error;
2767 	}
2768 
2769 	ret = recvfrom_dtls_common(ctx, buf, max_len, flags, src_addr, addrlen);
2770 	if (ret >= 0) {
2771 		return ret;
2772 	}
2773 
2774 	switch (ret) {
2775 	case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2776 		/* Peer notified that it's closing the connection. */
2777 		ret = tls_mbedtls_reset(ctx);
2778 		if (ret == 0) {
2779 			ctx->error = ENOTCONN;
2780 			ret = -ENOTCONN;
2781 		} else {
2782 			ctx->error = ENOMEM;
2783 			ret = -ENOMEM;
2784 		}
2785 		break;
2786 
2787 	case MBEDTLS_ERR_SSL_TIMEOUT:
2788 		(void)mbedtls_ssl_close_notify(&ctx->ssl);
2789 		ctx->error = ETIMEDOUT;
2790 		ret = -ETIMEDOUT;
2791 		break;
2792 
2793 	case MBEDTLS_ERR_SSL_WANT_READ:
2794 	case MBEDTLS_ERR_SSL_WANT_WRITE:
2795 	case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2796 	case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2797 		ret = -EAGAIN;
2798 		break;
2799 
2800 	default:
2801 		NET_ERR("DTLS client recv error: -%x", -ret);
2802 
2803 		/* MbedTLS API documentation requires session to
2804 		 * be reset in other error cases
2805 		 */
2806 		ret = tls_mbedtls_reset(ctx);
2807 		if (ret != 0) {
2808 			ctx->error = ENOMEM;
2809 			errno = ENOMEM;
2810 		} else {
2811 			ctx->error = ECONNABORTED;
2812 			ret = -ECONNABORTED;
2813 		}
2814 
2815 		break;
2816 	}
2817 
2818 error:
2819 	errno = -ret;
2820 	return -1;
2821 }
2822 
recvfrom_dtls_server(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2823 static ssize_t recvfrom_dtls_server(struct tls_context *ctx, void *buf,
2824 				    size_t max_len, int flags,
2825 				    struct sockaddr *src_addr,
2826 				    socklen_t *addrlen)
2827 {
2828 	int ret;
2829 	bool repeat;
2830 	k_timeout_t timeout;
2831 
2832 	if (!ctx->is_initialized) {
2833 		ret = tls_mbedtls_init(ctx, true);
2834 		if (ret < 0) {
2835 			goto error;
2836 		}
2837 	}
2838 
2839 	if (is_blocking(ctx->sock, flags)) {
2840 		timeout = ctx->options.timeout_rx;
2841 	} else {
2842 		timeout = K_NO_WAIT;
2843 	}
2844 
2845 	/* Loop to enable DTLS reconnection for servers without closing
2846 	 * a socket.
2847 	 */
2848 	do {
2849 		repeat = false;
2850 
2851 		if (!is_handshake_complete(ctx)) {
2852 			ret = tls_mbedtls_handshake(ctx, timeout);
2853 			if (ret < 0) {
2854 				/* In case of EAGAIN, just exit. */
2855 				if (ret == -EAGAIN) {
2856 					break;
2857 				}
2858 
2859 				ret = tls_mbedtls_reset(ctx);
2860 				if (ret == 0) {
2861 					repeat = true;
2862 				} else {
2863 					ret = -ENOMEM;
2864 				}
2865 
2866 				continue;
2867 			}
2868 
2869 			/* Server socket ready to use again. */
2870 			ctx->error = 0;
2871 		}
2872 
2873 		ret = recvfrom_dtls_common(ctx, buf, max_len, flags,
2874 					   src_addr, addrlen);
2875 		if (ret >= 0) {
2876 			return ret;
2877 		}
2878 
2879 		switch (ret) {
2880 		case MBEDTLS_ERR_SSL_TIMEOUT:
2881 			(void)mbedtls_ssl_close_notify(&ctx->ssl);
2882 			__fallthrough;
2883 			/* fallthrough */
2884 
2885 		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2886 		case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
2887 			ret = tls_mbedtls_reset(ctx);
2888 			if (ret == 0) {
2889 				repeat = true;
2890 			} else {
2891 				ctx->error = ENOMEM;
2892 				ret = -ENOMEM;
2893 			}
2894 			break;
2895 
2896 		case MBEDTLS_ERR_SSL_WANT_READ:
2897 		case MBEDTLS_ERR_SSL_WANT_WRITE:
2898 		case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2899 		case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2900 			ret = -EAGAIN;
2901 			break;
2902 
2903 		default:
2904 			NET_ERR("DTLS server recv error: -%x", -ret);
2905 
2906 			ret = tls_mbedtls_reset(ctx);
2907 			if (ret != 0) {
2908 				ctx->error = ENOMEM;
2909 				errno = ENOMEM;
2910 			} else {
2911 				ctx->error = ECONNABORTED;
2912 				ret = -ECONNABORTED;
2913 			}
2914 
2915 			break;
2916 		}
2917 	} while (repeat);
2918 
2919 error:
2920 	errno = -ret;
2921 	return -1;
2922 }
2923 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2924 
ztls_recvfrom_ctx(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2925 ssize_t ztls_recvfrom_ctx(struct tls_context *ctx, void *buf, size_t max_len,
2926 			  int flags, struct sockaddr *src_addr,
2927 			  socklen_t *addrlen)
2928 {
2929 	if (flags & ZSOCK_MSG_PEEK) {
2930 		/* TODO mbedTLS does not support 'peeking' This could be
2931 		 * bypassed by having intermediate buffer for peeking
2932 		 */
2933 		errno = ENOTSUP;
2934 		return -1;
2935 	}
2936 
2937 	ctx->flags = flags;
2938 
2939 	/* TLS */
2940 	if (ctx->type == SOCK_STREAM) {
2941 		return recv_tls(ctx, buf, max_len, flags);
2942 	}
2943 
2944 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2945 	/* DTLS */
2946 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2947 		return recvfrom_dtls_server(ctx, buf, max_len, flags,
2948 					    src_addr, addrlen);
2949 	}
2950 
2951 	return recvfrom_dtls_client(ctx, buf, max_len, flags,
2952 				    src_addr, addrlen);
2953 #else
2954 	errno = ENOTSUP;
2955 	return -1;
2956 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2957 }
2958 
ztls_poll_prepare_pollin(struct tls_context * ctx)2959 static int ztls_poll_prepare_pollin(struct tls_context *ctx)
2960 {
2961 	/* If there already is mbedTLS data to read, there is no
2962 	 * need to set the k_poll_event object. Return EALREADY
2963 	 * so we won't block in the k_poll.
2964 	 */
2965 	if (!ctx->is_listening) {
2966 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
2967 			return -EALREADY;
2968 		}
2969 	}
2970 
2971 	return 0;
2972 }
2973 
ztls_poll_prepare_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)2974 static int ztls_poll_prepare_ctx(struct tls_context *ctx,
2975 				 struct zsock_pollfd *pfd,
2976 				 struct k_poll_event **pev,
2977 				 struct k_poll_event *pev_end)
2978 {
2979 	const struct fd_op_vtable *vtable;
2980 	struct k_mutex *lock;
2981 	void *obj;
2982 	int ret;
2983 	short events = pfd->events;
2984 
2985 	/* DTLS client should wait for the handshake to complete before
2986 	 * it actually starts to poll for data.
2987 	 */
2988 	if ((pfd->events & ZSOCK_POLLIN) && (ctx->type == SOCK_DGRAM) &&
2989 	    (ctx->options.role == MBEDTLS_SSL_IS_CLIENT) &&
2990 	    !is_handshake_complete(ctx)) {
2991 		(*pev)->obj = &ctx->tls_established;
2992 		(*pev)->type = K_POLL_TYPE_SEM_AVAILABLE;
2993 		(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
2994 		(*pev)->state = K_POLL_STATE_NOT_READY;
2995 		(*pev)++;
2996 
2997 		/* Since k_poll_event is configured by the TLS layer in this
2998 		 * case, do not forward ZSOCK_POLLIN to the underlying socket.
2999 		 */
3000 		pfd->events &= ~ZSOCK_POLLIN;
3001 	}
3002 
3003 	obj = zvfs_get_fd_obj_and_vtable(
3004 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
3005 	if (obj == NULL) {
3006 		ret = -EBADF;
3007 		goto exit;
3008 	}
3009 
3010 	(void)k_mutex_lock(lock, K_FOREVER);
3011 
3012 	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
3013 				   pfd, pev, pev_end);
3014 	if (ret != 0) {
3015 		goto exit;
3016 	}
3017 
3018 	if (pfd->events & ZSOCK_POLLIN) {
3019 		ret = ztls_poll_prepare_pollin(ctx);
3020 	}
3021 
3022 exit:
3023 	/* Restore original events. */
3024 	pfd->events = events;
3025 
3026 	k_mutex_unlock(lock);
3027 
3028 	return ret;
3029 }
3030 
3031 #include <zephyr/net/net_core.h>
3032 
/* Check, without blocking, whether decrypted data can be read from ctx.
 *
 * Used by the poll update path after the underlying socket reported input.
 * For DTLS sockets this also drives the handshake: the context is initialized
 * on demand and a non-blocking handshake step is performed while the session
 * is not established.
 *
 * Returns:
 *   > 0          number of decrypted bytes available in mbedTLS,
 *   0            nothing to read yet (including: DTLS handshake attempted
 *                in this call, or mbedTLS needs more data),
 *   -ENOTCONN    stream context not initialized, peer closed the session,
 *                or DTLS timeout,
 *   -ENOMEM      context initialization or session reset failed,
 *   -ECONNABORTED any other mbedTLS read error.
 */
static int ztls_socket_data_check(struct tls_context *ctx)
{
	int ret;

	if (ctx->type == SOCK_STREAM) {
		if (!ctx->is_initialized) {
			return -ENOTCONN;
		}
	}
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	else {
		/* DTLS contexts are initialized lazily, with the role taken
		 * from the socket options.
		 */
		if (!ctx->is_initialized) {
			bool is_server = ctx->options.role == MBEDTLS_SSL_IS_SERVER;

			ret = tls_mbedtls_init(ctx, is_server);
			if (ret < 0) {
				return -ENOMEM;
			}
		}

		if (!is_handshake_complete(ctx)) {
			/* Advance the handshake without waiting. Whatever the
			 * outcome, no application data is reported from this
			 * call.
			 */
			ret = tls_mbedtls_handshake(ctx, K_NO_WAIT);
			if (ret < 0) {
				if (ret == -EAGAIN) {
					return 0;
				}

				/* Handshake failed: reset the session so that
				 * a new handshake can be started.
				 */
				ret = tls_mbedtls_reset(ctx);
				if (ret != 0) {
					return -ENOMEM;
				}

				return 0;
			}

			/* Socket ready to use again. */
			ctx->error = 0;

			return 0;
		}
	}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	ctx->flags = ZSOCK_MSG_DONTWAIT;

	/* Zero-length read: lets mbedTLS process pending records without
	 * consuming application data.
	 */
	ret = mbedtls_ssl_read(&ctx->ssl, NULL, 0);
	if (ret < 0) {
		if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
			/* Don't reset the context for STREAM socket - the
			 * application needs to reopen the socket anyway, and
			 * resetting the context would result in an error instead
			 * of 0 in a consecutive recv() call.
			 */
			if (ctx->type == SOCK_DGRAM) {
				ret = tls_mbedtls_reset(ctx);
				if (ret != 0) {
					return -ENOMEM;
				}
			} else {
				ctx->session_closed = true;
			}

			return -ENOTCONN;
		}

		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
		    ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
			return 0;
		}

		NET_ERR("TLS data check error: -%x", -ret);

		/* MbedTLS API documentation requires session to
		 * be reset in other error cases
		 */
		if (tls_mbedtls_reset(ctx) != 0) {
			return -ENOMEM;
		}

		/* ret still holds the mbedtls_ssl_read() result here; the
		 * reset result above is deliberately not stored in it.
		 */
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		if (ret == MBEDTLS_ERR_SSL_TIMEOUT && ctx->type == SOCK_DGRAM) {
			/* DTLS timeout interpreted as closing of connection. */
			return -ENOTCONN;
		}
#endif
		return -ECONNABORTED;
	}

	return mbedtls_ssl_get_bytes_avail(&ctx->ssl);
}
3123 
ztls_poll_update_pollin(int fd,struct tls_context * ctx,struct zsock_pollfd * pfd)3124 static int ztls_poll_update_pollin(int fd, struct tls_context *ctx,
3125 				   struct zsock_pollfd *pfd)
3126 {
3127 	int ret;
3128 
3129 	if (!ctx->is_listening) {
3130 		/* Already had TLS data to read on socket. */
3131 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
3132 			pfd->revents |= ZSOCK_POLLIN;
3133 			goto next;
3134 		}
3135 	}
3136 
3137 	if (ctx->type == SOCK_STREAM) {
3138 		if (!(pfd->revents & ZSOCK_POLLIN)) {
3139 			/* No new data on a socket. */
3140 			goto next;
3141 		}
3142 
3143 		if (ctx->is_listening) {
3144 			goto next;
3145 		}
3146 	}
3147 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3148 	else {
3149 		/* Perform data check without incoming data for completed DTLS connections.
3150 		 * This allows the connections to timeout with CONFIG_NET_SOCKETS_DTLS_TIMEOUT.
3151 		 */
3152 		if (!is_handshake_complete(ctx) && !(pfd->revents & ZSOCK_POLLIN)) {
3153 			goto next;
3154 		}
3155 	}
3156 #endif
3157 	ret = ztls_socket_data_check(ctx);
3158 	if (ret == -ENOTCONN || (pfd->revents & ZSOCK_POLLHUP)) {
3159 		/* Datagram does not return 0 on consecutive recv, but an error
3160 		 * code, hence clear POLLIN.
3161 		 */
3162 		if (ctx->type == SOCK_DGRAM) {
3163 			pfd->revents &= ~ZSOCK_POLLIN;
3164 		}
3165 		pfd->revents |= ZSOCK_POLLHUP;
3166 		goto next;
3167 	} else if (ret < 0) {
3168 		ctx->error = -ret;
3169 		pfd->revents |= ZSOCK_POLLERR;
3170 		goto next;
3171 	} else if (ret == 0) {
3172 		goto again;
3173 	}
3174 
3175 next:
3176 	return 0;
3177 
3178 again:
3179 	/* Received encrypted data, but still not enough
3180 	 * to decrypt it and return data through socket,
3181 	 * ask for retry if no other events are set.
3182 	 */
3183 	pfd->revents &= ~ZSOCK_POLLIN;
3184 
3185 	return -EAGAIN;
3186 }
3187 
/* Update poll results for a TLS/DTLS context after k_poll returned.
 *
 * Counterpart of the poll prepare step: handles the event that waits for the
 * handshake (tls_established), forwards the update to the underlying socket's
 * ZFD_IOCTL_POLL_UPDATE handler and finally converts ZSOCK_POLLIN into its
 * TLS-level meaning. The underlying socket's lock is held throughout.
 *
 * Returns 0 when pfd->revents is final, -EAGAIN when poll should run another
 * iteration, -EBADF when the underlying socket cannot be resolved, or the
 * error of the underlying handler.
 */
static int ztls_poll_update_ctx(struct tls_context *ctx,
				struct zsock_pollfd *pfd,
				struct k_poll_event **pev)
{
	const struct fd_op_vtable *vtable;
	struct k_mutex *lock;
	void *obj;
	int ret;
	short events = pfd->events;

	obj = zvfs_get_fd_obj_and_vtable(
		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
	if (obj == NULL) {
		return -EBADF;
	}

	(void)k_mutex_lock(lock, K_FOREVER);

	/* Check if the socket was waiting for the handshake to complete. */
	if ((pfd->events & ZSOCK_POLLIN) &&
	    ((*pev)->obj == &ctx->tls_established)) {
		/* In case handshake is complete, reconfigure the k_poll_event
		 * to monitor the underlying socket now.
		 */
		if ((*pev)->state != K_POLL_STATE_NOT_READY) {
			/* The end bound of *pev + 1 limits the handler to
			 * the single event slot at *pev.
			 */
			ret = zvfs_fdtable_call_ioctl(vtable, obj,
						   ZFD_IOCTL_POLL_PREPARE,
						   pfd, pev, *pev + 1);
			if (ret != 0 && ret != -EALREADY) {
				goto out;
			}

			/* Return -EAGAIN to signal to poll() that it should
			 * make another iteration with the event reconfigured
			 * above (if needed).
			 */
			ret = -EAGAIN;
			goto out;
		}

		/* Handshake still not ready - skip ZSOCK_POLLIN verification
		 * for the underlying socket.
		 */
		(*pev)++;
		pfd->events &= ~ZSOCK_POLLIN;
	}

	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_UPDATE,
				   pfd, pev);
	if (ret != 0) {
		goto exit;
	}

	if (pfd->events & ZSOCK_POLLIN) {
		ret = ztls_poll_update_pollin(pfd->fd, ctx, pfd);
		if (ret == -EAGAIN && pfd->revents == 0) {
			/* Nothing to report yet: mark the event before the
			 * current cursor position as not ready again and
			 * keep -EAGAIN as the result.
			 */
			(*pev - 1)->state = K_POLL_STATE_NOT_READY;
			goto exit;
		} else {
			ret = 0;
		}
	}
exit:
	/* Restore original events. */
	pfd->events = events;

out:
	/* Paths jumping straight to 'out' have not modified pfd->events. */
	k_mutex_unlock(lock);

	return ret;
}
3259 
3260 /* Return true if needed to retry rightoff or false otherwise. */
poll_offload_dtls_client_retry(struct tls_context * ctx,struct zsock_pollfd * pfd)3261 static bool poll_offload_dtls_client_retry(struct tls_context *ctx,
3262 					   struct zsock_pollfd *pfd)
3263 {
3264 	/* DTLS client should wait for the handshake to complete before it
3265 	 * reports that data is ready.
3266 	 */
3267 	if ((ctx->type != SOCK_DGRAM) ||
3268 	    (ctx->options.role != MBEDTLS_SSL_IS_CLIENT)) {
3269 		return false;
3270 	}
3271 
3272 	if (ctx->handshake_in_progress) {
3273 		/* Add some sleep to allow lower priority threads to proceed
3274 		 * with handshake.
3275 		 */
3276 		k_msleep(10);
3277 
3278 		pfd->revents &= ~ZSOCK_POLLIN;
3279 		return true;
3280 	} else if (!is_handshake_complete(ctx)) {
3281 		uint8_t byte;
3282 		int ret;
3283 
3284 		/* Handshake didn't start yet - just drop the incoming data -
3285 		 * it's the client who should initiate the handshake.
3286 		 */
3287 		ret = zsock_recv(ctx->sock, &byte, sizeof(byte),
3288 				 ZSOCK_MSG_DONTWAIT);
3289 		if (ret < 0) {
3290 			pfd->revents |= ZSOCK_POLLERR;
3291 		}
3292 
3293 		pfd->revents &= ~ZSOCK_POLLIN;
3294 		return true;
3295 	}
3296 
3297 	/* Handshake complete, just proceed. */
3298 	return false;
3299 }
3300 
ztls_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)3301 static int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int timeout)
3302 {
3303 	int fd_backup[CONFIG_NET_SOCKETS_POLL_MAX];
3304 	const struct fd_op_vtable *vtable;
3305 	void *ctx;
3306 	int ret = 0;
3307 	int result;
3308 	int i;
3309 	bool retry;
3310 	int remaining;
3311 	uint32_t entry = k_uptime_get_32();
3312 
3313 	/* Overwrite TLS file descriptors with underlying ones. */
3314 	for (i = 0; i < nfds; i++) {
3315 		fd_backup[i] = fds[i].fd;
3316 
3317 		ctx = zvfs_get_fd_obj(fds[i].fd,
3318 				   (const struct fd_op_vtable *)
3319 						     &tls_sock_fd_op_vtable,
3320 				   0);
3321 		if (ctx == NULL) {
3322 			continue;
3323 		}
3324 
3325 		if (fds[i].events & ZSOCK_POLLIN) {
3326 			ret = ztls_poll_prepare_pollin(ctx);
3327 			/* In case data is already available in mbedtls,
3328 			 * do not wait in poll.
3329 			 */
3330 			if (ret == -EALREADY) {
3331 				timeout = 0;
3332 			}
3333 		}
3334 
3335 		fds[i].fd = ((struct tls_context *)ctx)->sock;
3336 	}
3337 
3338 	/* Get offloaded sockets vtable. */
3339 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
3340 				      (const struct fd_op_vtable **)&vtable,
3341 				      NULL);
3342 	if (ctx == NULL) {
3343 		errno = EINVAL;
3344 		goto exit;
3345 	}
3346 
3347 	remaining = timeout;
3348 
3349 	do {
3350 		for (i = 0; i < nfds; i++) {
3351 			fds[i].revents = 0;
3352 		}
3353 
3354 		ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
3355 					   fds, nfds, remaining);
3356 		if (ret < 0) {
3357 			goto exit;
3358 		}
3359 
3360 		retry = false;
3361 		ret = 0;
3362 
3363 		for (i = 0; i < nfds; i++) {
3364 			ctx = zvfs_get_fd_obj(fd_backup[i],
3365 					   (const struct fd_op_vtable *)
3366 							&tls_sock_fd_op_vtable,
3367 					   0);
3368 			if (ctx != NULL) {
3369 				if (fds[i].events & ZSOCK_POLLIN) {
3370 					if (poll_offload_dtls_client_retry(
3371 							ctx, &fds[i])) {
3372 						retry = true;
3373 						continue;
3374 					}
3375 
3376 					result = ztls_poll_update_pollin(
3377 						    fd_backup[i], ctx, &fds[i]);
3378 					if (result == -EAGAIN) {
3379 						retry = true;
3380 					}
3381 				}
3382 			}
3383 
3384 			if (fds[i].revents != 0) {
3385 				ret++;
3386 			}
3387 		}
3388 
3389 		if (retry) {
3390 			if (ret > 0 || timeout == 0) {
3391 				goto exit;
3392 			}
3393 
3394 			if (timeout > 0) {
3395 				remaining = time_left(entry, timeout);
3396 				if (remaining <= 0) {
3397 					goto exit;
3398 				}
3399 			}
3400 		}
3401 	} while (retry);
3402 
3403 exit:
3404 	/* Restore original fds. */
3405 	for (i = 0; i < nfds; i++) {
3406 		fds[i].fd = fd_backup[i];
3407 	}
3408 
3409 	return ret;
3410 }
3411 
/* Dispatch a SOL_TLS level option read to its dedicated getter.
 *
 * Returns the value returned by the getter, or -ENOPROTOOPT if the option is
 * unknown or write-only.
 */
static int ztls_sol_tls_opt_get(struct tls_context *ctx, int optname,
				void *optval, socklen_t *optlen)
{
	switch (optname) {
	case TLS_SEC_TAG_LIST:
		return tls_opt_sec_tag_list_get(ctx, optval, optlen);

	case TLS_CIPHERSUITE_LIST:
		return tls_opt_ciphersuite_list_get(ctx, optval, optlen);

	case TLS_CIPHERSUITE_USED:
		return tls_opt_ciphersuite_used_get(ctx, optval, optlen);

	case TLS_ALPN_LIST:
		return tls_opt_alpn_list_get(ctx, optval, optlen);

	case TLS_SESSION_CACHE:
		return tls_opt_session_cache_get(ctx, optval, optlen);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
		return tls_opt_dtls_handshake_timeout_get(ctx, optval,
							  optlen, false);

	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
		return tls_opt_dtls_handshake_timeout_get(ctx, optval,
							  optlen, true);

	case TLS_DTLS_CID_STATUS:
		return tls_opt_dtls_connection_id_status_get(ctx, optval,
							     optlen);

	case TLS_DTLS_CID_VALUE:
		return tls_opt_dtls_connection_id_value_get(ctx, optval,
							    optlen);

	case TLS_DTLS_PEER_CID_VALUE:
		return tls_opt_dtls_peer_connection_id_value_get(ctx, optval,
								 optlen);

	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
		return tls_opt_dtls_handshake_on_connect_get(ctx, optval,
							     optlen);
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	default:
		/* Unknown or write-only option. */
		return -ENOPROTOOPT;
	}
}

/* Read a socket option of a TLS socket.
 *
 * SO_PROTOCOL and a pending TLS layer SO_ERROR are answered here, SOL_TLS
 * options are served by the TLS option getters, and every other request is
 * forwarded to the underlying socket.
 *
 * Returns -1 with errno set on failure. On success returns 0, except for
 * SO_PROTOCOL, where the non-negative result of sock_opt_protocol_get() is
 * returned, and for forwarded requests, where the result of
 * zsock_getsockopt() is returned.
 */
int ztls_getsockopt_ctx(struct tls_context *ctx, int level, int optname,
			void *optval, socklen_t *optlen)
{
	int rc;

	if (optval == NULL || optlen == NULL) {
		errno = EINVAL;
		return -1;
	}

	if (level == SOL_SOCKET) {
		if (optname == SO_PROTOCOL) {
			/* Protocol type is overridden during socket creation.
			 * Its value is restored here to return current value.
			 */
			rc = sock_opt_protocol_get(ctx, optval, optlen);
			if (rc < 0) {
				errno = -rc;
				return -1;
			}

			return rc;
		}

		/* An error set at the TLS layer (for example due to a received
		 * TLS alert) is reported here. Without one, SO_ERROR falls
		 * through and is handled by the underlying TCP/UDP socket.
		 */
		if (optname == SO_ERROR && ctx->error != 0) {
			if (*optlen != sizeof(int)) {
				errno = EINVAL;
				return -1;
			}

			*(int *)optval = ctx->error;

			return 0;
		}
	}

	if (level != SOL_TLS) {
		return zsock_getsockopt(ctx->sock, level, optname,
					optval, optlen);
	}

	rc = ztls_sol_tls_opt_get(ctx, optname, optval, optlen);
	if (rc < 0) {
		errno = -rc;
		return -1;
	}

	return 0;
}
3519 
/* Convert a SO_SNDTIMEO/SO_RCVTIMEO option value into a kernel timeout.
 *
 * timeout - output, written only on success
 * optval  - pointer to a struct zsock_timeval
 * optlen  - size of the value, has to be sizeof(struct zsock_timeval)
 *
 * A timeval with both fields equal to zero is stored as K_FOREVER, any other
 * value is stored as the equivalent number of microseconds.
 *
 * Returns 0 on success, -EINVAL if optval is NULL or optlen does not match,
 * -EDOM if any of the timeval fields is negative.
 */
static int set_timeout_opt(k_timeout_t *timeout, const void *optval,
			   socklen_t optlen)
{
	const struct zsock_timeval *tval = optval;

	if (tval == NULL || optlen != sizeof(struct zsock_timeval)) {
		return -EINVAL;
	}

	/* A negative field would turn into a huge value in the unsigned
	 * microsecond arithmetic below.
	 */
	if (tval->tv_sec < 0 || tval->tv_usec < 0) {
		return -EDOM;
	}

	if (tval->tv_sec == 0 && tval->tv_usec == 0) {
		*timeout = K_FOREVER;
	} else {
		*timeout = K_USEC(tval->tv_sec * 1000000ULL + tval->tv_usec);
	}

	return 0;
}
3537 
/* Dispatch a SOL_TLS level option write to its dedicated setter.
 *
 * Returns the value returned by the setter, 0 for TLS_NATIVE, or
 * -ENOPROTOOPT if the option is unknown or read-only.
 */
static int ztls_sol_tls_opt_set(struct tls_context *ctx, int optname,
				const void *optval, socklen_t optlen)
{
	switch (optname) {
	case TLS_SEC_TAG_LIST:
		return tls_opt_sec_tag_list_set(ctx, optval, optlen);

	case TLS_HOSTNAME:
		return tls_opt_hostname_set(ctx, optval, optlen);

	case TLS_CIPHERSUITE_LIST:
		return tls_opt_ciphersuite_list_set(ctx, optval, optlen);

	case TLS_PEER_VERIFY:
		return tls_opt_peer_verify_set(ctx, optval, optlen);

	case TLS_CERT_NOCOPY:
		return tls_opt_cert_nocopy_set(ctx, optval, optlen);

	case TLS_DTLS_ROLE:
		return tls_opt_dtls_role_set(ctx, optval, optlen);

	case TLS_ALPN_LIST:
		return tls_opt_alpn_list_set(ctx, optval, optlen);

	case TLS_SESSION_CACHE:
		return tls_opt_session_cache_set(ctx, optval, optlen);

	case TLS_SESSION_CACHE_PURGE:
		return tls_opt_session_cache_purge_set(ctx, optval, optlen);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
		return tls_opt_dtls_handshake_timeout_set(ctx, optval,
							  optlen, false);

	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
		return tls_opt_dtls_handshake_timeout_set(ctx, optval,
							  optlen, true);

	case TLS_DTLS_CID:
		return tls_opt_dtls_connection_id_set(ctx, optval, optlen);

	case TLS_DTLS_CID_VALUE:
		return tls_opt_dtls_connection_id_value_set(ctx, optval,
							    optlen);

	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
		return tls_opt_dtls_handshake_on_connect_set(ctx, optval,
							     optlen);
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

	case TLS_NATIVE:
		/* Option handled at the socket dispatcher level. */
		return 0;

	default:
		/* Unknown or read-only option. */
		return -ENOPROTOOPT;
	}
}

/* Write a socket option of a TLS socket.
 *
 * Send/receive timeouts are kept at the TLS socket level, SOL_TLS options are
 * served by the TLS option setters, and every other request is forwarded to
 * the underlying socket.
 *
 * Returns 0 on success or -1 with errno set on failure; forwarded requests
 * return the result of zsock_setsockopt().
 */
int ztls_setsockopt_ctx(struct tls_context *ctx, int level, int optname,
			const void *optval, socklen_t optlen)
{
	int rc;

	/* The underlying socket is used in non-blocking mode, which is why
	 * the timeouts are implemented at the TLS socket level.
	 */
	if ((level == SOL_SOCKET) && (optname == SO_SNDTIMEO)) {
		rc = set_timeout_opt(&ctx->options.timeout_tx, optval, optlen);
	} else if ((level == SOL_SOCKET) && (optname == SO_RCVTIMEO)) {
		rc = set_timeout_opt(&ctx->options.timeout_rx, optval, optlen);
	} else if (level != SOL_TLS) {
		return zsock_setsockopt(ctx->sock, level, optname,
					optval, optlen);
	} else {
		rc = ztls_sol_tls_opt_set(ctx, optname, optval, optlen);
	}

	if (rc < 0) {
		errno = -rc;
		return -1;
	}

	return 0;
}
3642 
#if defined(CONFIG_NET_TEST)
/* Look up the TLS context registered for fd and return a pointer to its
 * mbedtls SSL context, or NULL if the lookup fails. Only built when
 * CONFIG_NET_TEST is enabled.
 */
mbedtls_ssl_context *ztls_get_mbedtls_ssl_context(int fd)
{
	struct tls_context *tls =
		zvfs_get_fd_obj(fd,
				(const struct fd_op_vtable *)
					&tls_sock_fd_op_vtable,
				EBADF);

	return (tls != NULL) ? &tls->ssl : NULL;
}
#endif /* CONFIG_NET_TEST */
3657 
/* read() entry of the TLS socket vtable: a receive with no flags and no
 * source address.
 */
static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count)
{
	struct tls_context *tls = obj;

	return ztls_recvfrom_ctx(tls, buffer, count, 0, NULL, NULL);
}
3662 
/* write() entry of the TLS socket vtable: a send with no flags and no
 * destination address.
 */
static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer,
				    size_t count)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buffer, count, 0, NULL, 0);
}
3668 
tls_sock_ioctl_vmeth(void * obj,unsigned int request,va_list args)3669 static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
3670 {
3671 	struct tls_context *ctx = obj;
3672 
3673 	switch (request) {
3674 	/* fcntl() commands */
3675 	case F_GETFL:
3676 	case F_SETFL: {
3677 		const struct fd_op_vtable *vtable;
3678 		struct k_mutex *lock;
3679 		void *fd_obj;
3680 		int ret;
3681 
3682 		fd_obj = zvfs_get_fd_obj_and_vtable(ctx->sock,
3683 				(const struct fd_op_vtable **)&vtable, &lock);
3684 		if (fd_obj == NULL) {
3685 			errno = EBADF;
3686 			return -1;
3687 		}
3688 
3689 		(void)k_mutex_lock(lock, K_FOREVER);
3690 
3691 		/* Pass the call to the core socket implementation. */
3692 		ret = vtable->ioctl(fd_obj, request, args);
3693 
3694 		k_mutex_unlock(lock);
3695 
3696 		return ret;
3697 	}
3698 
3699 	case ZFD_IOCTL_SET_LOCK: {
3700 		struct k_mutex *lock;
3701 
3702 		lock = va_arg(args, struct k_mutex *);
3703 
3704 		ctx_set_lock(obj, lock);
3705 
3706 		return 0;
3707 	}
3708 
3709 	case ZFD_IOCTL_POLL_PREPARE: {
3710 		struct zsock_pollfd *pfd;
3711 		struct k_poll_event **pev;
3712 		struct k_poll_event *pev_end;
3713 
3714 		pfd = va_arg(args, struct zsock_pollfd *);
3715 		pev = va_arg(args, struct k_poll_event **);
3716 		pev_end = va_arg(args, struct k_poll_event *);
3717 
3718 		return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end);
3719 	}
3720 
3721 	case ZFD_IOCTL_POLL_UPDATE: {
3722 		struct zsock_pollfd *pfd;
3723 		struct k_poll_event **pev;
3724 
3725 		pfd = va_arg(args, struct zsock_pollfd *);
3726 		pev = va_arg(args, struct k_poll_event **);
3727 
3728 		return ztls_poll_update_ctx(obj, pfd, pev);
3729 	}
3730 
3731 	case ZFD_IOCTL_POLL_OFFLOAD: {
3732 		struct zsock_pollfd *fds;
3733 		int nfds;
3734 		int timeout;
3735 
3736 		fds = va_arg(args, struct zsock_pollfd *);
3737 		nfds = va_arg(args, int);
3738 		timeout = va_arg(args, int);
3739 
3740 		return ztls_poll_offload(fds, nfds, timeout);
3741 	}
3742 
3743 	default:
3744 		errno = EOPNOTSUPP;
3745 		return -1;
3746 	}
3747 }
3748 
tls_sock_shutdown_vmeth(void * obj,int how)3749 static int tls_sock_shutdown_vmeth(void *obj, int how)
3750 {
3751 	struct tls_context *ctx = obj;
3752 
3753 	return zsock_shutdown(ctx->sock, how);
3754 }
3755 
/* bind() entry of the TLS socket vtable: applied directly to the underlying
 * socket.
 */
static int tls_sock_bind_vmeth(void *obj, const struct sockaddr *addr,
			       socklen_t addrlen)
{
	return zsock_bind(((struct tls_context *)obj)->sock, addr, addrlen);
}
3763 
/* connect() entry of the TLS socket vtable, handled by ztls_connect_ctx(). */
static int tls_sock_connect_vmeth(void *obj, const struct sockaddr *addr,
				  socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_connect_ctx(tls, addr, addrlen);
}
3769 
tls_sock_listen_vmeth(void * obj,int backlog)3770 static int tls_sock_listen_vmeth(void *obj, int backlog)
3771 {
3772 	struct tls_context *ctx = obj;
3773 
3774 	ctx->is_listening = true;
3775 
3776 	return zsock_listen(ctx->sock, backlog);
3777 }
3778 
/* accept() entry of the TLS socket vtable, handled by ztls_accept_ctx(). */
static int tls_sock_accept_vmeth(void *obj, struct sockaddr *addr,
				 socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_accept_ctx(tls, addr, addrlen);
}
3784 
/* sendto() entry of the TLS socket vtable, handled by ztls_sendto_ctx(). */
static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len,
				     int flags,
				     const struct sockaddr *dest_addr,
				     socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buf, len, flags, dest_addr, addrlen);
}
3792 
/* sendmsg() entry of the TLS socket vtable, handled by ztls_sendmsg_ctx(). */
static ssize_t tls_sock_sendmsg_vmeth(void *obj, const struct msghdr *msg,
				      int flags)
{
	struct tls_context *tls = obj;

	return ztls_sendmsg_ctx(tls, msg, flags);
}
3798 
/* recvfrom() entry of the TLS socket vtable, handled by
 * ztls_recvfrom_ctx().
 */
static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len,
				       int flags, struct sockaddr *src_addr,
				       socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_recvfrom_ctx(tls, buf, max_len, flags, src_addr, addrlen);
}
3806 
/* getsockopt() entry of the TLS socket vtable, handled by
 * ztls_getsockopt_ctx().
 */
static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname,
				     void *optval, socklen_t *optlen)
{
	struct tls_context *tls = obj;

	return ztls_getsockopt_ctx(tls, level, optname, optval, optlen);
}
3812 
/* setsockopt() entry of the TLS socket vtable, handled by
 * ztls_setsockopt_ctx().
 */
static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname,
				     const void *optval, socklen_t optlen)
{
	struct tls_context *tls = obj;

	return ztls_setsockopt_ctx(tls, level, optname, optval, optlen);
}
3818 
/* close() entry of the TLS socket vtable, handled by ztls_close_ctx(). */
static int tls_sock_close_vmeth(void *obj)
{
	struct tls_context *tls = obj;

	return ztls_close_ctx(tls);
}
3823 
/* getpeername() entry of the TLS socket vtable: answered by the underlying
 * socket.
 */
static int tls_sock_getpeername_vmeth(void *obj, struct sockaddr *addr,
				      socklen_t *addrlen)
{
	return zsock_getpeername(((struct tls_context *)obj)->sock,
				 addr, addrlen);
}
3831 
/* getsockname() entry of the TLS socket vtable: answered by the underlying
 * socket.
 */
static int tls_sock_getsockname_vmeth(void *obj, struct sockaddr *addr,
				      socklen_t *addrlen)
{
	return zsock_getsockname(((struct tls_context *)obj)->sock,
				 addr, addrlen);
}
3839 
3840 static const struct socket_op_vtable tls_sock_fd_op_vtable = {
3841 	.fd_vtable = {
3842 		.read = tls_sock_read_vmeth,
3843 		.write = tls_sock_write_vmeth,
3844 		.close = tls_sock_close_vmeth,
3845 		.ioctl = tls_sock_ioctl_vmeth,
3846 	},
3847 	.shutdown = tls_sock_shutdown_vmeth,
3848 	.bind = tls_sock_bind_vmeth,
3849 	.connect = tls_sock_connect_vmeth,
3850 	.listen = tls_sock_listen_vmeth,
3851 	.accept = tls_sock_accept_vmeth,
3852 	.sendto = tls_sock_sendto_vmeth,
3853 	.sendmsg = tls_sock_sendmsg_vmeth,
3854 	.recvfrom = tls_sock_recvfrom_vmeth,
3855 	.getsockopt = tls_sock_getsockopt_vmeth,
3856 	.setsockopt = tls_sock_setsockopt_vmeth,
3857 	.getpeername = tls_sock_getpeername_vmeth,
3858 	.getsockname = tls_sock_getsockname_vmeth,
3859 };
3860 
/* Report whether the family/type/proto combination is accepted by
 * protocol_check(), i.e. whether that call returns 0.
 */
static bool tls_is_supported(int family, int type, int proto)
{
	return protocol_check(family, type, &proto) == 0;
}
3869 
/* TLS sockets and regular sockets fall under the same address families, so
 * the TLS implementation has to be processed first in order to capture the
 * socket calls which create sockets for secure protocols. Every other call
 * for AF_INET/AF_INET6 will be forwarded to the regular socket implementation.
 *
 * The build fails unless the TLS priority value is smaller than the default
 * socket priority value.
 */
BUILD_ASSERT(CONFIG_NET_SOCKETS_TLS_PRIORITY < CONFIG_NET_SOCKETS_PRIORITY_DEFAULT,
	     "CONFIG_NET_SOCKETS_TLS_PRIORITY have to be smaller than CONFIG_NET_SOCKETS_PRIORITY_DEFAULT");

/* Register the TLS socket implementation for AF_UNSPEC at the TLS priority,
 * handing tls_is_supported() and ztls_socket() to the registration macro.
 */
NET_SOCKET_REGISTER(tls, CONFIG_NET_SOCKETS_TLS_PRIORITY, AF_UNSPEC,
		    tls_is_supported, ztls_socket);
3880