1 /*
2  * Copyright (c) 2018 Intel Corporation
3  * Copyright (c) 2018 Nordic Semiconductor ASA
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <stdbool.h>
9 #include <zephyr/posix/fcntl.h>
10 
11 #include <zephyr/logging/log.h>
12 LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL);
13 
14 #include <zephyr/init.h>
15 #include <zephyr/sys/util.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/random/random.h>
18 #include <zephyr/internal/syscall_handler.h>
19 #include <zephyr/sys/fdtable.h>
20 
21 /* TODO: Remove all direct access to private fields.
 * According to the Mbed TLS migration guide:
23  *
24  * Direct access to fields of structures
25  * (`struct` types) declared in public headers is no longer
26  * supported. In Mbed TLS 3, the layout of structures is not
27  * considered part of the stable API, and minor versions (3.1, 3.2,
28  * etc.) may add, remove, rename, reorder or change the type of
29  * structure fields.
30  */
31 #if !defined(MBEDTLS_ALLOW_PRIVATE_ACCESS)
32 #define MBEDTLS_ALLOW_PRIVATE_ACCESS
33 #endif
34 
35 #if defined(CONFIG_MBEDTLS)
36 #if !defined(CONFIG_MBEDTLS_CFG_FILE)
37 #include "mbedtls/config.h"
38 #else
39 #include CONFIG_MBEDTLS_CFG_FILE
40 #endif /* CONFIG_MBEDTLS_CFG_FILE */
41 
42 #include <mbedtls/net_sockets.h>
43 #include <mbedtls/x509.h>
44 #include <mbedtls/x509_crt.h>
45 #include <mbedtls/ssl.h>
46 #include <mbedtls/ssl_cookie.h>
47 #include <mbedtls/error.h>
48 #include <mbedtls/platform.h>
49 #include <mbedtls/ssl_cache.h>
50 #endif /* CONFIG_MBEDTLS */
51 
52 #include "sockets_internal.h"
53 #include "tls_internal.h"
54 
55 #if defined(CONFIG_MBEDTLS_DEBUG)
56 #include <zephyr_mbedtls_priv.h>
57 #endif
58 
59 #if defined(CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS)
60 #define ALPN_MAX_PROTOCOLS (CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS + 1)
61 #else
62 #define ALPN_MAX_PROTOCOLS 0
63 #endif /* CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS */
64 
65 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
66 #define DTLS_SENDMSG_BUF_SIZE (CONFIG_NET_SOCKETS_DTLS_SENDMSG_BUF_SIZE)
67 #else
68 #define DTLS_SENDMSG_BUF_SIZE 0
69 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
70 
71 static const struct socket_op_vtable tls_sock_fd_op_vtable;
72 
73 #ifndef MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED
74 #define MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE
75 #endif
76 
77 /** A list of secure tags that TLS context should use. */
78 struct sec_tag_list {
79 	/** An array of secure tags referencing TLS credentials. */
80 	sec_tag_t sec_tags[CONFIG_NET_SOCKETS_TLS_MAX_CREDENTIALS];
81 
82 	/** Number of configured secure tags. */
83 	int sec_tag_count;
84 };
85 
/** State of the DTLS retransmission timer handed to mbedTLS. */
struct dtls_timing_context {
	/** Uptime (32-bit, ms) captured when the timer was last armed. */
	uint32_t snapshot;

	/** Intermediate delay in milliseconds. See the mbedTLS API
	 *  documentation of mbedtls_ssl_set_timer_t for its meaning.
	 */
	uint32_t int_ms;

	/** Final delay in milliseconds; 0 means the timer is cancelled.
	 *  See the mbedTLS API documentation of mbedtls_ssl_set_timer_t.
	 */
	uint32_t fin_ms;
};
101 
102 /** TLS peer address/session ID mapping. */
103 struct tls_session_cache {
104 	/** Creation time. */
105 	int64_t timestamp;
106 
107 	/** Peer address. */
108 	struct sockaddr peer_addr;
109 
110 	/** Session buffer. */
111 	uint8_t *session;
112 
113 	/** Session length. */
114 	size_t session_len;
115 };
116 
117 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
118 struct tls_dtls_cid {
119 	bool enabled;
120 	unsigned char cid[MAX(MBEDTLS_SSL_CID_OUT_LEN_MAX,
121 			      MBEDTLS_SSL_CID_IN_LEN_MAX)];
122 	size_t cid_len;
123 };
124 #endif
125 
126 /** TLS context information. */
127 __net_socket struct tls_context {
128 	/** Underlying TCP/UDP socket. */
129 	int sock;
130 
131 	/** Information whether TLS context is used. */
132 	bool is_used : 1;
133 
134 	/** Information whether TLS context was initialized. */
135 	bool is_initialized : 1;
136 
137 	/** Information whether underlying socket is listening. */
138 	bool is_listening : 1;
139 
140 	/** Information whether TLS handshake is currently in progress. */
141 	bool handshake_in_progress : 1;
142 
143 	/** Session ended at the TLS/DTLS level. */
144 	bool session_closed : 1;
145 
146 	/** Socket type. */
147 	enum net_sock_type type;
148 
149 	/** Secure protocol version running on TLS context. */
150 	enum net_ip_protocol_secure tls_version;
151 
152 	/** Socket flags passed to a socket call. */
153 	int flags;
154 
155 	/* Indicates whether socket is in error state at TLS/DTLS level. */
156 	int error;
157 
158 	/** Information whether TLS handshake is complete or not. */
159 	struct k_sem tls_established;
160 
161 	/* TLS socket mutex lock. */
162 	struct k_mutex *lock;
163 
164 	/** TLS specific option values. */
165 	struct {
166 		/** Select which credentials to use with TLS. */
167 		struct sec_tag_list sec_tag_list;
168 
169 		/** 0-terminated list of allowed ciphersuites (mbedTLS format).
170 		 */
171 		int ciphersuites[CONFIG_NET_SOCKETS_TLS_MAX_CIPHERSUITES + 1];
172 
173 		/** Information if hostname was explicitly set on a socket. */
174 		bool is_hostname_set;
175 
176 		/** Peer verification level. */
177 		int8_t verify_level;
178 
179 		/** Indicating on whether DER certificates should not be copied
180 		 * to the heap.
181 		 */
182 		int8_t cert_nocopy;
183 
184 		/** DTLS role, client by default. */
185 		int8_t role;
186 
187 		/** NULL-terminated list of allowed application layer
188 		 * protocols.
189 		 */
190 		const char *alpn_list[ALPN_MAX_PROTOCOLS];
191 
192 		/** Session cache enabled on a socket. */
193 		bool cache_enabled;
194 
195 		/** Socket TX timeout */
196 		k_timeout_t timeout_tx;
197 
198 		/** Socket RX timeout */
199 		k_timeout_t timeout_rx;
200 
201 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
202 		/* DTLS handshake timeout */
203 		uint32_t dtls_handshake_timeout_min;
204 		uint32_t dtls_handshake_timeout_max;
205 
206 		struct tls_dtls_cid dtls_cid;
207 
208 		bool dtls_handshake_on_connect;
209 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
210 	} options;
211 
212 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
213 	/** Context information for DTLS timing. */
214 	struct dtls_timing_context dtls_timing;
215 
216 	/** mbedTLS cookie context for DTLS */
217 	mbedtls_ssl_cookie_ctx cookie;
218 
219 	/** DTLS peer address. */
220 	struct sockaddr dtls_peer_addr;
221 
222 	/** DTLS peer address length. */
223 	socklen_t dtls_peer_addrlen;
224 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
225 
226 #if defined(CONFIG_MBEDTLS)
227 	/** mbedTLS context. */
228 	mbedtls_ssl_context ssl;
229 
230 	/** mbedTLS configuration. */
231 	mbedtls_ssl_config config;
232 
233 #if defined(MBEDTLS_X509_CRT_PARSE_C)
234 	/** mbedTLS structure for CA chain. */
235 	mbedtls_x509_crt ca_chain;
236 
237 	/** mbedTLS structure for own certificate. */
238 	mbedtls_x509_crt own_cert;
239 
240 	/** mbedTLS structure for own private key. */
241 	mbedtls_pk_context priv_key;
242 #endif /* MBEDTLS_X509_CRT_PARSE_C */
243 
244 #endif /* CONFIG_MBEDTLS */
245 };
246 
247 
248 /* A global pool of TLS contexts. */
249 static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS];
250 
251 static struct tls_session_cache client_cache[CONFIG_NET_SOCKETS_TLS_MAX_CLIENT_SESSION_COUNT];
252 
253 #if defined(MBEDTLS_SSL_CACHE_C)
254 static mbedtls_ssl_cache_context server_cache;
255 #endif
256 
257 /* A mutex for protecting TLS context allocation. */
258 static struct k_mutex context_lock;
259 
260 /* Arbitrary delay value to wait if mbedTLS reports it cannot proceed for
261  * reasons other than TX/RX block.
262  */
263 #define TLS_WAIT_MS 100
264 
tls_session_cache_reset(void)265 static void tls_session_cache_reset(void)
266 {
267 	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
268 		if (client_cache[i].session != NULL) {
269 			mbedtls_free(client_cache[i].session);
270 		}
271 	}
272 
273 	(void)memset(client_cache, 0, sizeof(client_cache));
274 }
275 
net_socket_is_tls(void * obj)276 bool net_socket_is_tls(void *obj)
277 {
278 	return PART_OF_ARRAY(tls_contexts, (struct tls_context *)obj);
279 }
280 
/* Random-number callback handed to mbedTLS.
 *
 * Fills @p buf with @p len random bytes. @p ctx is ignored. With
 * CONFIG_CSPRNG_ENABLED the result of sys_csrand_get() is returned;
 * otherwise sys_rand_get() is used and 0 is always returned.
 */
static int tls_ctr_drbg_random(void *ctx, unsigned char *buf, size_t len)
{
	ARG_UNUSED(ctx);

#if !defined(CONFIG_CSPRNG_ENABLED)
	sys_rand_get(buf, len);

	return 0;
#else
	return sys_csrand_get(buf, len);
#endif
}
293 
294 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
295 /* mbedTLS-defined function for setting timer. */
dtls_timing_set_delay(void * data,uint32_t int_ms,uint32_t fin_ms)296 static void dtls_timing_set_delay(void *data, uint32_t int_ms, uint32_t fin_ms)
297 {
298 	struct dtls_timing_context *ctx = data;
299 
300 	ctx->int_ms = int_ms;
301 	ctx->fin_ms = fin_ms;
302 
303 	if (fin_ms != 0U) {
304 		ctx->snapshot = k_uptime_get_32();
305 	}
306 }
307 
308 /* mbedTLS-defined function for getting timer status.
309  * The return values are specified by mbedTLS. The callback must return:
310  *   -1 if cancelled (fin_ms == 0),
311  *    0 if none of the delays have passed,
312  *    1 if only the intermediate delay has passed,
313  *    2 if the final delay has passed.
314  */
dtls_timing_get_delay(void * data)315 static int dtls_timing_get_delay(void *data)
316 {
317 	struct dtls_timing_context *timing = data;
318 	unsigned long elapsed_ms;
319 
320 	NET_ASSERT(timing);
321 
322 	if (timing->fin_ms == 0U) {
323 		return -1;
324 	}
325 
326 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
327 
328 	if (elapsed_ms >= timing->fin_ms) {
329 		return 2;
330 	}
331 
332 	if (elapsed_ms >= timing->int_ms) {
333 		return 1;
334 	}
335 
336 	return 0;
337 }
338 
dtls_get_remaining_timeout(struct tls_context * ctx)339 static int dtls_get_remaining_timeout(struct tls_context *ctx)
340 {
341 	struct dtls_timing_context *timing = &ctx->dtls_timing;
342 	uint32_t elapsed_ms;
343 
344 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
345 
346 	if (timing->fin_ms == 0U) {
347 		return SYS_FOREVER_MS;
348 	}
349 
350 	if (elapsed_ms >= timing->fin_ms) {
351 		return 0;
352 	}
353 
354 	return timing->fin_ms - elapsed_ms;
355 }
356 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
357 
358 /* Initialize TLS internals. */
tls_init(void)359 static int tls_init(void)
360 {
361 
362 #if !defined(CONFIG_ENTROPY_HAS_DRIVER)
363 	NET_WARN("No entropy device on the system, "
364 		 "TLS communication is insecure!");
365 #endif
366 
367 	(void)memset(tls_contexts, 0, sizeof(tls_contexts));
368 	(void)memset(client_cache, 0, sizeof(client_cache));
369 
370 	k_mutex_init(&context_lock);
371 
372 #if defined(MBEDTLS_SSL_CACHE_C)
373 	mbedtls_ssl_cache_init(&server_cache);
374 #endif
375 
376 	return 0;
377 }
378 
379 SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT);
380 
is_handshake_complete(struct tls_context * ctx)381 static inline bool is_handshake_complete(struct tls_context *ctx)
382 {
383 	return k_sem_count_get(&ctx->tls_established) != 0;
384 }
385 
/*
 * Mirrors the definition in include/mbedtls/ssl_internal.h.
 *
 * Largest content length that can be advertised for RFC 6066
 * max_fragment_length extension negotiation: the smaller of the incoming
 * and outgoing content lengths when the two differ.
 */
#define MBEDTLS_TLS_EXT_ADV_CONTENT_LEN (                            \
	(MBEDTLS_SSL_OUT_CONTENT_LEN < MBEDTLS_SSL_IN_CONTENT_LEN)   \
	? (MBEDTLS_SSL_OUT_CONTENT_LEN)				     \
	: (MBEDTLS_SSL_IN_CONTENT_LEN)				     \
	)
398 
399 #if defined(CONFIG_NET_SOCKETS_TLS_SET_MAX_FRAGMENT_LENGTH) &&	\
400 	defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) &&		\
401 	(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN < 16384)
402 
403 BUILD_ASSERT(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN >= 512,
404 	     "Too small content length!");
405 
tls_mfl_code_from_content_len(size_t len)406 static inline unsigned char tls_mfl_code_from_content_len(size_t len)
407 {
408 	if (len >= 4096) {
409 		return MBEDTLS_SSL_MAX_FRAG_LEN_4096;
410 	} else if (len >= 2048) {
411 		return MBEDTLS_SSL_MAX_FRAG_LEN_2048;
412 	} else if (len >= 1024) {
413 		return MBEDTLS_SSL_MAX_FRAG_LEN_1024;
414 	} else if (len >= 512) {
415 		return MBEDTLS_SSL_MAX_FRAG_LEN_512;
416 	} else {
417 		return MBEDTLS_SSL_MAX_FRAG_LEN_INVALID;
418 	}
419 }
420 
/* Configure the max_fragment_length extension on @p config.
 *
 * Starts from the advertisable content length and, for datagram sockets
 * with DTLS enabled, caps it at CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH
 * before translating it to an mbedTLS fragment-length code.
 */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	size_t adv_len = MBEDTLS_TLS_EXT_ADV_CONTENT_LEN;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (type == SOCK_DGRAM) {
		if (adv_len > CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH) {
			adv_len = CONFIG_NET_SOCKETS_DTLS_MAX_FRAGMENT_LENGTH;
		}
	}
#endif

	mbedtls_ssl_conf_max_frag_len(config,
				      tls_mfl_code_from_content_len(adv_len));
}
435 #else
/* No-op variant used when max fragment length negotiation is not configured. */
static inline void tls_set_max_frag_len(mbedtls_ssl_config *config, enum net_sock_type type)
{
	ARG_UNUSED(config);
	ARG_UNUSED(type);
}
437 #endif
438 
439 /* Allocate TLS context. */
tls_alloc(void)440 static struct tls_context *tls_alloc(void)
441 {
442 	int i;
443 	struct tls_context *tls = NULL;
444 
445 	k_mutex_lock(&context_lock, K_FOREVER);
446 
447 	for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) {
448 		if (!tls_contexts[i].is_used) {
449 			tls = &tls_contexts[i];
450 			(void)memset(tls, 0, sizeof(*tls));
451 			tls->is_used = true;
452 			tls->options.verify_level = -1;
453 			tls->options.timeout_tx = K_FOREVER;
454 			tls->options.timeout_rx = K_FOREVER;
455 			tls->sock = -1;
456 
457 			NET_DBG("Allocated TLS context, %p", tls);
458 			break;
459 		}
460 	}
461 
462 	k_mutex_unlock(&context_lock);
463 
464 	if (tls) {
465 		k_sem_init(&tls->tls_established, 0, 1);
466 
467 		mbedtls_ssl_init(&tls->ssl);
468 		mbedtls_ssl_config_init(&tls->config);
469 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
470 		mbedtls_ssl_cookie_init(&tls->cookie);
471 		tls->options.dtls_handshake_timeout_min =
472 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN;
473 		tls->options.dtls_handshake_timeout_max =
474 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX;
475 		tls->options.dtls_cid.cid_len = 0;
476 		tls->options.dtls_cid.enabled = false;
477 		tls->options.dtls_handshake_on_connect = true;
478 #endif
479 #if defined(MBEDTLS_X509_CRT_PARSE_C)
480 		mbedtls_x509_crt_init(&tls->ca_chain);
481 		mbedtls_x509_crt_init(&tls->own_cert);
482 		mbedtls_pk_init(&tls->priv_key);
483 #endif
484 
485 #if defined(CONFIG_MBEDTLS_DEBUG)
486 		mbedtls_ssl_conf_dbg(&tls->config, zephyr_mbedtls_debug, NULL);
487 #endif
488 	} else {
489 		NET_WARN("Failed to allocate TLS context");
490 	}
491 
492 	return tls;
493 }
494 
495 /* Allocate new TLS context and copy the content from the source context. */
tls_clone(struct tls_context * source_tls)496 static struct tls_context *tls_clone(struct tls_context *source_tls)
497 {
498 	struct tls_context *target_tls;
499 
500 	target_tls = tls_alloc();
501 	if (!target_tls) {
502 		return NULL;
503 	}
504 
505 	target_tls->tls_version = source_tls->tls_version;
506 	target_tls->type = source_tls->type;
507 
508 	memcpy(&target_tls->options, &source_tls->options,
509 	       sizeof(target_tls->options));
510 
511 #if defined(MBEDTLS_X509_CRT_PARSE_C)
512 	if (target_tls->options.is_hostname_set) {
513 		mbedtls_ssl_set_hostname(&target_tls->ssl,
514 					 source_tls->ssl.hostname);
515 	}
516 #endif
517 
518 	return target_tls;
519 }
520 
521 /* Release TLS context. */
tls_release(struct tls_context * tls)522 static int tls_release(struct tls_context *tls)
523 {
524 	if (!PART_OF_ARRAY(tls_contexts, tls)) {
525 		NET_ERR("Invalid TLS context");
526 		return -EBADF;
527 	}
528 
529 	if (!tls->is_used) {
530 		NET_ERR("Deallocating unused TLS context");
531 		return -EBADF;
532 	}
533 
534 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
535 	mbedtls_ssl_cookie_free(&tls->cookie);
536 #endif
537 	mbedtls_ssl_config_free(&tls->config);
538 	mbedtls_ssl_free(&tls->ssl);
539 #if defined(MBEDTLS_X509_CRT_PARSE_C)
540 	mbedtls_x509_crt_free(&tls->ca_chain);
541 	mbedtls_x509_crt_free(&tls->own_cert);
542 	mbedtls_pk_free(&tls->priv_key);
543 #endif
544 
545 	tls->is_used = false;
546 
547 	return 0;
548 }
549 
peer_addr_cmp(const struct sockaddr * addr,const struct sockaddr * peer_addr)550 static bool peer_addr_cmp(const struct sockaddr *addr,
551 			  const struct sockaddr *peer_addr)
552 {
553 	if (addr->sa_family != peer_addr->sa_family) {
554 		return false;
555 	}
556 
557 	if (IS_ENABLED(CONFIG_NET_IPV6) && peer_addr->sa_family == AF_INET6) {
558 		struct sockaddr_in6 *addr1 = net_sin6(peer_addr);
559 		struct sockaddr_in6 *addr2 = net_sin6(addr);
560 
561 		return (addr1->sin6_port == addr2->sin6_port) &&
562 			net_ipv6_addr_cmp(&addr1->sin6_addr, &addr2->sin6_addr);
563 	} else if (IS_ENABLED(CONFIG_NET_IPV4) && peer_addr->sa_family == AF_INET) {
564 		struct sockaddr_in *addr1 = net_sin(peer_addr);
565 		struct sockaddr_in *addr2 = net_sin(addr);
566 
567 		return (addr1->sin_port == addr2->sin_port) &&
568 			net_ipv4_addr_cmp(&addr1->sin_addr, &addr2->sin_addr);
569 	}
570 
571 	return false;
572 }
573 
/* Serialize @p session and store it in the client cache under @p peer_addr.
 *
 * Slot selection, in order of preference:
 *   1. the slot already holding a session for @p peer_addr,
 *   2. a free slot,
 *   3. the occupied slot with the oldest timestamp (evicted).
 *
 * Any session previously held by the chosen slot is freed before the new
 * one is stored, so on failure the slot ends up empty.
 *
 * Returns 0 on success, -ENOMEM when the session buffer cannot be allocated
 * or the session cannot be serialized.
 */
static int tls_session_save(const struct sockaddr *peer_addr,
			    mbedtls_ssl_session *session)
{
	struct tls_session_cache *entry = NULL;
	/* Zero-initialized so a failed size query cannot leave it undefined. */
	size_t session_len = 0;
	int ret;

	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
		if (client_cache[i].session == NULL) {
			/* Free slot - preferred over any occupied candidate. */
			if (entry == NULL || entry->session != NULL) {
				entry = &client_cache[i];
			}
		} else {
			if (peer_addr_cmp(&client_cache[i].peer_addr, peer_addr)) {
				/* Reuse old entry for given address. */
				entry = &client_cache[i];
				break;
			}

			/* Remember the oldest entry and reuse if needed:
			 * replace the occupied candidate only when this slot
			 * was stored earlier (smaller timestamp).
			 */
			if (entry == NULL ||
			    (entry->session != NULL &&
			     client_cache[i].timestamp < entry->timestamp)) {
				entry = &client_cache[i];
			}
		}
	}

	/* Allocate session and save */

	if (entry->session != NULL) {
		mbedtls_free(entry->session);
		entry->session = NULL;
	}

	/* Size query: a NULL buffer makes mbedTLS report the needed length.
	 * The return value is deliberately ignored, only session_len is used.
	 */
	(void)mbedtls_ssl_session_save(session, NULL, 0, &session_len);

	entry->session = mbedtls_calloc(1, session_len);
	if (entry->session == NULL) {
		NET_ERR("Failed to allocate session buffer.");
		return -ENOMEM;
	}

	ret = mbedtls_ssl_session_save(session, entry->session, session_len,
				       &session_len);
	if (ret < 0) {
		NET_ERR("Failed to serialize session, err: -0x%x.", -ret);
		mbedtls_free(entry->session);
		entry->session = NULL;
		return -ENOMEM;
	}

	entry->session_len = session_len;
	entry->timestamp = k_uptime_get();
	memcpy(&entry->peer_addr, peer_addr, sizeof(*peer_addr));

	return 0;
}
633 
/* Look up the cached session for @p peer_addr and load it into @p session.
 *
 * Returns 0 on success, -ENOENT when no session is cached for the address,
 * and -EIO when the stored data cannot be loaded (the slot is then emptied).
 */
static int tls_session_get(const struct sockaddr *peer_addr,
			   mbedtls_ssl_session *session)
{
	struct tls_session_cache *match = NULL;
	int err;

	for (size_t idx = 0; idx < ARRAY_SIZE(client_cache); idx++) {
		struct tls_session_cache *slot = &client_cache[idx];

		if (slot->session == NULL) {
			continue;
		}

		if (peer_addr_cmp(&slot->peer_addr, peer_addr)) {
			match = slot;
			break;
		}
	}

	if (match == NULL) {
		return -ENOENT;
	}

	err = mbedtls_ssl_session_load(session, match->session,
				       match->session_len);
	if (err >= 0) {
		return 0;
	}

	/* Stored data is unusable - drop it so it is not retried. */
	mbedtls_free(match->session);
	match->session = NULL;
	NET_ERR("Failed to load TLS session %d", err);

	return -EIO;
}
664 
/* Save the current session of @p context in the client cache, keyed by the
 * peer address @p addr.
 *
 * Does nothing unless session caching is enabled on the socket. Failures
 * are logged and otherwise ignored (best effort).
 */
static void tls_session_store(struct tls_context *context,
			      const struct sockaddr *addr,
			      socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Bound the copy by the local buffer: a caller-supplied addrlen
	 * larger than struct sockaddr must not overflow the stack.
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = mbedtls_ssl_get_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to obtain session for %p", context);
		goto exit;
	}

	ret = tls_session_save(&peer_addr, &session);
	if (ret < 0) {
		NET_ERR("Failed to save session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
694 
/* Try to resume a cached session for the peer at @p addr on @p context.
 *
 * Does nothing unless session caching is enabled on the socket. A missing
 * cache entry is only logged at debug level; a failure to apply the session
 * is logged as an error. Neither is reported to the caller.
 */
static void tls_session_restore(struct tls_context *context,
				const struct sockaddr *addr,
				socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Bound the copy by the local buffer: a caller-supplied addrlen
	 * larger than struct sockaddr must not overflow the stack.
	 */
	memcpy(&peer_addr, addr, MIN(addrlen, sizeof(peer_addr)));
	mbedtls_ssl_session_init(&session);

	ret = tls_session_get(&peer_addr, &session);
	if (ret < 0) {
		NET_DBG("Session not found for %p", context);
		goto exit;
	}

	ret = mbedtls_ssl_set_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to set session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
724 
/* Drop all cached sessions: empties the client cache and, when mbedTLS
 * provides a server cache, frees and re-initializes it.
 */
static void tls_session_purge(void)
{
	tls_session_cache_reset();

#ifdef MBEDTLS_SSL_CACHE_C
	mbedtls_ssl_cache_free(&server_cache);
	mbedtls_ssl_cache_init(&server_cache);
#endif
}
734 
/* Milliseconds remaining of @p timeout, measured from the uptime value
 * @p start. The difference is computed in 32-bit unsigned arithmetic and
 * returned as int.
 */
static inline int time_left(uint32_t start, uint32_t timeout)
{
	uint32_t now = k_uptime_get_32();

	return timeout - (now - start);
}
741 
wait(int sock,int timeout,int event)742 static int wait(int sock, int timeout, int event)
743 {
744 	struct zsock_pollfd fds = {
745 		.fd = sock,
746 		.events = event,
747 	};
748 	int ret;
749 
750 	ret = zsock_poll(&fds, 1, timeout);
751 	if (ret < 0) {
752 		return ret;
753 	}
754 
755 	if (ret == 1) {
756 		if (fds.revents & ZSOCK_POLLNVAL) {
757 			return -EBADF;
758 		}
759 
760 		if (fds.revents & ZSOCK_POLLERR) {
761 			int optval;
762 			socklen_t optlen = sizeof(optval);
763 
764 			if (zsock_getsockopt(fds.fd, SOL_SOCKET, SO_ERROR,
765 					     &optval, &optlen) == 0) {
766 				NET_ERR("TLS underlying socket poll error %d",
767 					-optval);
768 				return -optval;
769 			}
770 
771 			return -EIO;
772 		}
773 	}
774 
775 	return 0;
776 }
777 
wait_for_reason(int sock,int timeout,int reason)778 static int wait_for_reason(int sock, int timeout, int reason)
779 {
780 	if (reason == MBEDTLS_ERR_SSL_WANT_READ) {
781 		return wait(sock, timeout, ZSOCK_POLLIN);
782 	}
783 
784 	if (reason == MBEDTLS_ERR_SSL_WANT_WRITE) {
785 		return wait(sock, timeout, ZSOCK_POLLOUT);
786 	}
787 
788 	/* Any other reason - no way to monitor, just wait for some time. */
789 	k_msleep(TLS_WAIT_MS);
790 
791 	return 0;
792 }
793 
is_blocking(int sock,int flags)794 static bool is_blocking(int sock, int flags)
795 {
796 	int sock_flags = zsock_fcntl(sock, F_GETFL, 0);
797 
798 	if (sock_flags == -1) {
799 		return false;
800 	}
801 
802 	return !((flags & ZSOCK_MSG_DONTWAIT) || (sock_flags & O_NONBLOCK));
803 }
804 
/* Convert a kernel timeout to milliseconds: K_NO_WAIT maps to 0, K_FOREVER
 * to SYS_FOREVER_MS, anything else to its tick count converted to
 * milliseconds (rounded down).
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
815 
ctx_set_lock(struct tls_context * ctx,struct k_mutex * lock)816 static void ctx_set_lock(struct tls_context *ctx, struct k_mutex *lock)
817 {
818 	ctx->lock = lock;
819 }
820 
821 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Check that @p peer_addr / @p addrlen identify the DTLS peer recorded in
 * @p context: the lengths must be equal and the addresses must compare
 * equal.
 */
static bool dtls_is_peer_addr_valid(struct tls_context *context,
				    const struct sockaddr *peer_addr,
				    socklen_t addrlen)
{
	return (context->dtls_peer_addrlen == addrlen) &&
	       peer_addr_cmp(&context->dtls_peer_addr, peer_addr);
}
832 
/* Record @p peer_addr as the DTLS peer of @p context. Addresses longer than
 * the storage in the context are silently ignored.
 */
static void dtls_peer_address_set(struct tls_context *context,
				  const struct sockaddr *peer_addr,
				  socklen_t addrlen)
{
	if (addrlen > sizeof(context->dtls_peer_addr)) {
		return;
	}

	memcpy(&context->dtls_peer_addr, peer_addr, addrlen);
	context->dtls_peer_addrlen = addrlen;
}
842 
/* Copy the recorded DTLS peer address into @p peer_addr.
 *
 * On entry *@p addrlen is the capacity of @p peer_addr; on return it holds
 * the number of bytes copied (the smaller of capacity and stored length).
 */
static void dtls_peer_address_get(struct tls_context *context,
				  struct sockaddr *peer_addr,
				  socklen_t *addrlen)
{
	socklen_t copy_len = *addrlen;

	if (context->dtls_peer_addrlen < copy_len) {
		copy_len = context->dtls_peer_addrlen;
	}

	memcpy(peer_addr, &context->dtls_peer_addr, copy_len);
	*addrlen = copy_len;
}
852 
dtls_tx(void * ctx,const unsigned char * buf,size_t len)853 static int dtls_tx(void *ctx, const unsigned char *buf, size_t len)
854 {
855 	struct tls_context *tls_ctx = ctx;
856 	ssize_t sent;
857 
858 	sent = zsock_sendto(tls_ctx->sock, buf, len, ZSOCK_MSG_DONTWAIT,
859 			    &tls_ctx->dtls_peer_addr,
860 			    tls_ctx->dtls_peer_addrlen);
861 	if (sent < 0) {
862 		if (errno == EAGAIN) {
863 			return MBEDTLS_ERR_SSL_WANT_WRITE;
864 		}
865 
866 		return MBEDTLS_ERR_NET_SEND_FAILED;
867 	}
868 
869 	return sent;
870 }
871 
dtls_rx(void * ctx,unsigned char * buf,size_t len)872 static int dtls_rx(void *ctx, unsigned char *buf, size_t len)
873 {
874 	struct tls_context *tls_ctx = ctx;
875 	socklen_t addrlen = sizeof(struct sockaddr);
876 	struct sockaddr addr;
877 	int err;
878 	ssize_t received;
879 
880 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
881 				  ZSOCK_MSG_DONTWAIT, &addr, &addrlen);
882 	if (received < 0) {
883 		if (errno == EAGAIN) {
884 			return MBEDTLS_ERR_SSL_WANT_READ;
885 		}
886 
887 		return MBEDTLS_ERR_NET_RECV_FAILED;
888 	}
889 
890 	if (tls_ctx->dtls_peer_addrlen == 0) {
891 		/* Only allow to store peer address for DTLS servers. */
892 		if (tls_ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
893 			dtls_peer_address_set(tls_ctx, &addr, addrlen);
894 
895 			err = mbedtls_ssl_set_client_transport_id(
896 				&tls_ctx->ssl,
897 				(const unsigned char *)&addr, addrlen);
898 			if (err < 0) {
899 				return err;
900 			}
901 		} else {
902 			/* For clients it's incorrect to receive when
903 			 * no peer has been set up.
904 			 */
905 			return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED;
906 		}
907 	} else if (!dtls_is_peer_addr_valid(tls_ctx, &addr, addrlen)) {
908 		return MBEDTLS_ERR_SSL_WANT_READ;
909 	}
910 
911 	return received;
912 }
913 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
914 
tls_tx(void * ctx,const unsigned char * buf,size_t len)915 static int tls_tx(void *ctx, const unsigned char *buf, size_t len)
916 {
917 	struct tls_context *tls_ctx = ctx;
918 	ssize_t sent;
919 
920 	sent = zsock_sendto(tls_ctx->sock, buf, len,
921 			    ZSOCK_MSG_DONTWAIT, NULL, 0);
922 	if (sent < 0) {
923 		if (errno == EAGAIN) {
924 			return MBEDTLS_ERR_SSL_WANT_WRITE;
925 		}
926 
927 		return MBEDTLS_ERR_NET_SEND_FAILED;
928 	}
929 
930 	return sent;
931 }
932 
tls_rx(void * ctx,unsigned char * buf,size_t len)933 static int tls_rx(void *ctx, unsigned char *buf, size_t len)
934 {
935 	struct tls_context *tls_ctx = ctx;
936 	ssize_t received;
937 
938 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
939 				  ZSOCK_MSG_DONTWAIT, NULL, 0);
940 	if (received < 0) {
941 		if (errno == EAGAIN) {
942 			return MBEDTLS_ERR_SSL_WANT_READ;
943 		}
944 
945 		return MBEDTLS_ERR_NET_RECV_FAILED;
946 	}
947 
948 	return received;
949 }
950 
951 #if defined(MBEDTLS_X509_CRT_PARSE_C)
/* Heuristic PEM detection: the buffer must be non-empty, end with a NUL
 * byte, and contain the "-----BEGIN CERTIFICATE-----" marker.
 */
static bool crt_is_pem(const unsigned char *buf, size_t buflen)
{
	if (buflen == 0) {
		return false;
	}

	/* Without a trailing NUL the buffer cannot be scanned as a string. */
	if (buf[buflen - 1] != '\0') {
		return false;
	}

	return strstr((const char *)buf, "-----BEGIN CERTIFICATE-----") != NULL;
}
957 #endif
958 
/* Parse @p ca_cert and append it to the CA chain of @p tls.
 *
 * The copying parser is used when the no-copy option is off or the buffer
 * looks like PEM; otherwise the DER data is referenced in place.
 * Returns 0 on success, -EINVAL when parsing fails, and -ENOTSUP when
 * X.509 parsing is not compiled in.
 */
static int tls_add_ca_certificate(struct tls_context *tls,
				  struct tls_credential *ca_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool copy_needed = tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE ||
			   crt_is_pem(ca_cert->buf, ca_cert->len);
	int rc;

	if (copy_needed) {
		rc = mbedtls_x509_crt_parse(&tls->ca_chain, ca_cert->buf,
					    ca_cert->len);
	} else {
		rc = mbedtls_x509_crt_parse_der_nocopy(&tls->ca_chain,
						       ca_cert->buf,
						       ca_cert->len);
	}

	if (rc == 0) {
		return 0;
	}

	NET_ERR("Failed to parse CA certificate, err: -0x%x", -rc);

	return -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
985 
/* Attach the context's CA chain (without a CRL) to its mbedTLS configuration
 * and select the default X.509 certificate profile. Does nothing when X.509
 * parsing is not compiled in.
 */
static void tls_set_ca_chain(struct tls_context *tls)
{
#ifdef MBEDTLS_X509_CRT_PARSE_C
	mbedtls_ssl_config *conf = &tls->config;

	mbedtls_ssl_conf_ca_chain(conf, &tls->ca_chain, NULL);
	mbedtls_ssl_conf_cert_profile(conf, &mbedtls_x509_crt_profile_default);
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
994 
/* Parse @p own_cert into the own-certificate chain of @p tls.
 *
 * The copying parser is used when the no-copy option is off or the buffer
 * looks like PEM; otherwise the DER data is referenced in place.
 * Returns 0 on success, -EINVAL when parsing fails, and -ENOTSUP when
 * X.509 parsing is not compiled in.
 */
static int tls_add_own_cert(struct tls_context *tls,
			    struct tls_credential *own_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool copy_needed = tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE ||
			   crt_is_pem(own_cert->buf, own_cert->len);
	int rc;

	if (copy_needed) {
		rc = mbedtls_x509_crt_parse(&tls->own_cert,
					    own_cert->buf, own_cert->len);
	} else {
		rc = mbedtls_x509_crt_parse_der_nocopy(&tls->own_cert,
						       own_cert->buf,
						       own_cert->len);
	}

	return (rc == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1020 
/* Register the context's own certificate and private key with its mbedTLS
 * configuration. Returns 0 on success, -ENOMEM when mbedTLS reports any
 * error, and -ENOTSUP when X.509 parsing is not compiled in.
 */
static int tls_set_own_cert(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_conf_own_cert(&tls->config, &tls->own_cert,
				      &tls->priv_key) != 0) {
		return -ENOMEM;
	}

	return 0;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1035 
/* Parse the (password-less) private key in @p priv_key into @p tls.
 * Returns 0 on success, -EINVAL when parsing fails, and -ENOTSUP when
 * X.509 parsing is not compiled in.
 */
static int tls_set_private_key(struct tls_context *tls,
			       struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int rc = mbedtls_pk_parse_key(&tls->priv_key, priv_key->buf,
				      priv_key->len, NULL, 0,
				      tls_ctr_drbg_random, NULL);

	return (rc == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1054 
/* Configure the pre-shared key @p psk together with its identity @p psk_id
 * on the mbedTLS configuration of @p tls. Returns 0 on success, -EINVAL
 * when mbedTLS rejects the pair, and -ENOTSUP when PSK handshakes are not
 * compiled in.
 */
static int tls_set_psk(struct tls_context *tls,
		       struct tls_credential *psk,
		       struct tls_credential *psk_id)
{
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	const unsigned char *identity = (const unsigned char *)psk_id->buf;
	int rc = mbedtls_ssl_conf_psk(&tls->config, psk->buf, psk->len,
				      identity, psk_id->len);

	return (rc == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif
}
1073 
tls_set_credential(struct tls_context * tls,struct tls_credential * cred)1074 static int tls_set_credential(struct tls_context *tls,
1075 			      struct tls_credential *cred)
1076 {
1077 	switch (cred->type) {
1078 	case TLS_CREDENTIAL_CA_CERTIFICATE:
1079 		return tls_add_ca_certificate(tls, cred);
1080 
1081 	case TLS_CREDENTIAL_SERVER_CERTIFICATE:
1082 		return tls_add_own_cert(tls, cred);
1083 
1084 	case TLS_CREDENTIAL_PRIVATE_KEY:
1085 		return tls_set_private_key(tls, cred);
1086 	break;
1087 
1088 	case TLS_CREDENTIAL_PSK:
1089 	{
1090 		struct tls_credential *psk_id =
1091 			credential_get(cred->tag, TLS_CREDENTIAL_PSK_ID);
1092 		if (!psk_id) {
1093 			return -ENOENT;
1094 		}
1095 
1096 		return tls_set_psk(tls, cred, psk_id);
1097 	}
1098 
1099 	case TLS_CREDENTIAL_PSK_ID:
1100 		/* Ignore PSK ID - it will be used together
1101 		 * with PSK
1102 		 */
1103 		break;
1104 
1105 	default:
1106 		return -EINVAL;
1107 	}
1108 
1109 	return 0;
1110 }
1111 
tls_mbedtls_set_credentials(struct tls_context * tls)1112 static int tls_mbedtls_set_credentials(struct tls_context *tls)
1113 {
1114 	struct tls_credential *cred;
1115 	sec_tag_t tag;
1116 	int i, err = 0;
1117 	bool tag_found, ca_cert_present = false, own_cert_present = false;
1118 
1119 	credentials_lock();
1120 
1121 	for (i = 0; i < tls->options.sec_tag_list.sec_tag_count; i++) {
1122 		tag = tls->options.sec_tag_list.sec_tags[i];
1123 		cred = NULL;
1124 		tag_found = false;
1125 
1126 		while ((cred = credential_next_get(tag, cred)) != NULL) {
1127 			tag_found = true;
1128 
1129 			err = tls_set_credential(tls, cred);
1130 			if (err != 0) {
1131 				goto exit;
1132 			}
1133 
1134 			if (cred->type == TLS_CREDENTIAL_CA_CERTIFICATE) {
1135 				ca_cert_present = true;
1136 			} else if (cred->type == TLS_CREDENTIAL_SERVER_CERTIFICATE) {
1137 				own_cert_present = true;
1138 			}
1139 		}
1140 
1141 		if (!tag_found) {
1142 			err = -ENOENT;
1143 			goto exit;
1144 		}
1145 	}
1146 
1147 exit:
1148 	credentials_unlock();
1149 
1150 	if (err == 0) {
1151 		if (ca_cert_present) {
1152 			tls_set_ca_chain(tls);
1153 		}
1154 		if (own_cert_present) {
1155 			err = tls_set_own_cert(tls);
1156 		}
1157 	}
1158 
1159 	return err;
1160 }
1161 
tls_mbedtls_reset(struct tls_context * context)1162 static int tls_mbedtls_reset(struct tls_context *context)
1163 {
1164 	int ret;
1165 
1166 	ret = mbedtls_ssl_session_reset(&context->ssl);
1167 	if (ret != 0) {
1168 		return ret;
1169 	}
1170 
1171 	k_sem_reset(&context->tls_established);
1172 
1173 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1174 	/* Server role: reset the address so that a new
1175 	 *              client can connect w/o a need to reopen a socket
1176 	 * Client role: keep peer addr so socket can continue to be used
1177 	 *              even on handshake timeout
1178 	 */
1179 	if (context->options.role == MBEDTLS_SSL_IS_SERVER) {
1180 		(void)memset(&context->dtls_peer_addr, 0,
1181 			     sizeof(context->dtls_peer_addr));
1182 		context->dtls_peer_addrlen = 0;
1183 	}
1184 #endif
1185 
1186 	return 0;
1187 }
1188 
/* Drive the mbedTLS handshake of a context until it completes, fails, or no
 * waiting time is left.
 *
 * context->handshake_in_progress is set for the duration of the call. When
 * mbedTLS asks to wait (WANT_READ, WANT_WRITE, ASYNC_IN_PROGRESS or
 * CRYPTO_IN_PROGRESS) the function blocks in wait_for_reason() on the
 * underlying socket for the time remaining until the deadline derived from
 * timeout.
 *
 * Returns 0 on success (context->tls_established is given), -EAGAIN when no
 * waiting time is left, -ETIMEDOUT when mbedTLS reports
 * MBEDTLS_ERR_SSL_TIMEOUT, -ECONNABORTED on any other handshake failure or
 * when the session reset itself fails, or the non-zero value returned by
 * wait_for_reason().
 */
static int tls_mbedtls_handshake(struct tls_context *context,
				 k_timeout_t timeout)
{
	k_timepoint_t end;
	int ret;

	context->handshake_in_progress = true;

	/* Absolute deadline; the remaining time is recomputed on every wait. */
	end = sys_timepoint_calc(timeout);

	while ((ret = mbedtls_ssl_handshake(&context->ssl)) != 0) {
		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
		    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
			int timeout_ms;

			/* Blocking timeout. */
			timeout = sys_timepoint_timeout(end);
			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
				ret = -EAGAIN;
				break;
			}

			/* Block. */
			timeout_ms = timeout_to_ms(&timeout);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
			if (context->type == SOCK_DGRAM) {
				int timeout_dtls =
					dtls_get_remaining_timeout(context);

				/* Wait no longer than the remaining DTLS
				 * timeout, when there is one.
				 */
				if (timeout_dtls != SYS_FOREVER_MS) {
					if (timeout_ms == SYS_FOREVER_MS) {
						timeout_ms = timeout_dtls;
					} else {
						timeout_ms = MIN(timeout_dtls,
								 timeout_ms);
					}
				}
			}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

			ret = wait_for_reason(context->sock, timeout_ms, ret);
			if (ret != 0) {
				break;
			}

			continue;
		} else if (ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED) {
			/* Reset the session and retry the handshake, unless
			 * no waiting time is left.
			 */
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				if (!K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					continue;
				}

				ret = -EAGAIN;
				break;
			}
		} else if (ret == MBEDTLS_ERR_SSL_TIMEOUT) {
			/* MbedTLS API documentation requires session to
			 * be reset in this case
			 */
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				NET_ERR("TLS handshake timeout");
				context->error = ETIMEDOUT;
				ret = -ETIMEDOUT;
				break;
			}
		} else {
			/* MbedTLS API documentation requires session to
			 * be reset in other error cases
			 */
			NET_ERR("TLS handshake error: -0x%x", -ret);
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				context->error = ECONNABORTED;
				ret = -ECONNABORTED;
				break;
			}
		}

		/* Reached only when tls_mbedtls_reset() failed in one of the
		 * branches above.
		 */
		/* Avoid constant loop if tls_mbedtls_reset fails */
		NET_ERR("TLS reset error: -0x%x", -ret);
		context->error = ECONNABORTED;
		ret = -ECONNABORTED;
		break;
	}

	if (ret == 0) {
		k_sem_give(&context->tls_established);
	}

	context->handshake_in_progress = false;

	return ret;
}
1287 
/* Configure the mbedTLS SSL and configuration objects of a TLS context from
 * its socket options and credentials.
 *
 * The role is taken from is_server and the transport (stream/datagram) from
 * context->type. The function installs the BIO callbacks, loads the default
 * configuration, applies DTLS specific settings (timer callbacks, handshake
 * timeouts, connection ID, server cookies), the verification level, the RNG,
 * the credentials, and the optional ciphersuite/ALPN/session cache/early
 * data settings, and finally calls mbedtls_ssl_setup(). On success
 * context->is_initialized is set.
 *
 * Returns 0 on success, -ENOTSUP for a datagram context when
 * CONFIG_NET_SOCKETS_ENABLE_DTLS is not defined, -ENOMEM when
 * mbedtls_ssl_config_defaults(), mbedtls_ssl_cookie_setup() or
 * mbedtls_ssl_setup() fails, -EINVAL when the connection ID or ALPN
 * configuration is rejected, or the error returned by
 * tls_mbedtls_set_credentials().
 */
static int tls_mbedtls_init(struct tls_context *context, bool is_server)
{
	int role, type, ret;

	role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;

	type = (context->type == SOCK_STREAM) ?
		MBEDTLS_SSL_TRANSPORT_STREAM :
		MBEDTLS_SSL_TRANSPORT_DATAGRAM;

	/* Install the send/receive callbacks matching the transport. */
	if (type == MBEDTLS_SSL_TRANSPORT_STREAM) {
		mbedtls_ssl_set_bio(&context->ssl, context,
				    tls_tx, tls_rx, NULL);
	} else {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		mbedtls_ssl_set_bio(&context->ssl, context,
				    dtls_tx, dtls_rx, NULL);
#else
		return -ENOTSUP;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	ret = mbedtls_ssl_config_defaults(&context->config, role, type,
					  MBEDTLS_SSL_PRESET_DEFAULT);
	if (ret != 0) {
		/* According to mbedTLS API documentation,
		 * mbedtls_ssl_config_defaults can fail due to memory
		 * allocation failure
		 */
		return -ENOMEM;
	}
	tls_set_max_frag_len(&context->config, context->type);

#if defined(MBEDTLS_SSL_RENEGOTIATION)
	mbedtls_ssl_conf_legacy_renegotiation(&context->config,
					   MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE);
	mbedtls_ssl_conf_renegotiation(&context->config,
				       MBEDTLS_SSL_RENEGOTIATION_ENABLED);
#endif

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
		/* DTLS requires timer callbacks to operate */
		mbedtls_ssl_set_timer_cb(&context->ssl,
					 &context->dtls_timing,
					 dtls_timing_set_delay,
					 dtls_timing_get_delay);
		mbedtls_ssl_conf_handshake_timeout(&context->config,
				context->options.dtls_handshake_timeout_min,
				context->options.dtls_handshake_timeout_max);

#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
		if (context->options.dtls_cid.enabled) {
			ret = mbedtls_ssl_conf_cid(
					&context->config,
					context->options.dtls_cid.cid_len,
					MBEDTLS_SSL_UNEXPECTED_CID_IGNORE);
			if (ret != 0) {
				return -EINVAL;
			}
		}
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */

		/* Configure cookie for DTLS server */
		if (role == MBEDTLS_SSL_IS_SERVER) {
			ret = mbedtls_ssl_cookie_setup(&context->cookie,
						       tls_ctr_drbg_random,
						       NULL);
			if (ret != 0) {
				return -ENOMEM;
			}

			mbedtls_ssl_conf_dtls_cookies(&context->config,
						      mbedtls_ssl_cookie_write,
						      mbedtls_ssl_cookie_check,
						      &context->cookie);

			mbedtls_ssl_conf_read_timeout(
					&context->config,
					CONFIG_NET_SOCKETS_DTLS_TIMEOUT);
		}
	}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	/* For TLS clients, set hostname to empty string to enforce its
	 * verification - only if hostname option was not set. Otherwise
	 * depend on user configuration.
	 */
	if (!is_server && !context->options.is_hostname_set) {
		mbedtls_ssl_set_hostname(&context->ssl, "");
	}
#endif

	/* If verification level was specified explicitly, set it. Otherwise,
	 * use mbedTLS default values (required for client, none for server)
	 */
	if (context->options.verify_level != -1) {
		mbedtls_ssl_conf_authmode(&context->config,
					  context->options.verify_level);
	}

	mbedtls_ssl_conf_rng(&context->config,
			     tls_ctr_drbg_random,
			     NULL);

	ret = tls_mbedtls_set_credentials(context);
	if (ret != 0) {
		return ret;
	}

	/* The ciphersuite list is 0-terminated; a leading 0 means that no
	 * list was configured.
	 */
	if (context->options.ciphersuites[0] != 0) {
		/* Specific ciphersuites configured, so use them */
		NET_DBG("Using user-specified ciphersuites");
		mbedtls_ssl_conf_ciphersuites(&context->config,
					      context->options.ciphersuites);
	}

#if defined(CONFIG_MBEDTLS_SSL_ALPN)
	if (ALPN_MAX_PROTOCOLS && context->options.alpn_list[0] != NULL) {
		ret = mbedtls_ssl_conf_alpn_protocols(&context->config,
				context->options.alpn_list);
		if (ret != 0) {
			return -EINVAL;
		}
	}
#endif /* CONFIG_MBEDTLS_SSL_ALPN */

#if defined(MBEDTLS_SSL_CACHE_C)
	/* Session caching is only configured for servers that enabled it. */
	if (is_server && context->options.cache_enabled) {
		mbedtls_ssl_conf_session_cache(&context->config, &server_cache,
					       mbedtls_ssl_cache_get,
					       mbedtls_ssl_cache_set);
	}
#endif

#if defined(MBEDTLS_SSL_EARLY_DATA)
	mbedtls_ssl_conf_early_data(&context->config, MBEDTLS_SSL_EARLY_DATA_ENABLED);
#endif

	ret = mbedtls_ssl_setup(&context->ssl,
				&context->config);
	if (ret != 0) {
		/* According to mbedTLS API documentation,
		 * mbedtls_ssl_setup can fail due to memory allocation failure
		 */
		return -ENOMEM;
	}

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS) && defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	/* The connection ID value is applied to the SSL context, so this is
	 * done after mbedtls_ssl_setup().
	 */
	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
		if (context->options.dtls_cid.enabled) {
			ret = mbedtls_ssl_set_cid(&context->ssl, MBEDTLS_SSL_CID_ENABLED,
						  context->options.dtls_cid.cid,
						  context->options.dtls_cid.cid_len);
			if (ret != 0) {
				return -EINVAL;
			}
		}
	}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS && CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */

	context->is_initialized = true;

	return 0;
}
1454 
/* Store the security tag list option of a TLS context.
 *
 * optval is an array of sec_tag_t and optlen its size in bytes.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL, optlen is not a
 * multiple of sizeof(sec_tag_t), or the list does not fit into the context.
 */
static int tls_opt_sec_tag_list_set(struct tls_context *context,
				    const void *optval, socklen_t optlen)
{
	int tag_count;

	if (optval == NULL || (optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	tag_count = optlen / sizeof(sec_tag_t);
	if (tag_count > ARRAY_SIZE(context->options.sec_tag_list.sec_tags)) {
		return -EINVAL;
	}

	context->options.sec_tag_list.sec_tag_count = tag_count;
	memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);

	return 0;
}
1479 
/* Report the TLS/DTLS protocol value stored in the context as an int.
 *
 * Returns 0 on success or -EINVAL when *optlen is not sizeof(int).
 */
static int sock_opt_protocol_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*out = (int)context->tls_version;

	return 0;
}
1493 
/* Copy the configured security tags into the caller's buffer.
 *
 * At most *optlen bytes are copied; on return *optlen holds the number of
 * bytes actually written.
 *
 * Returns 0 on success, or -EINVAL when *optlen is zero or not a multiple of
 * sizeof(sec_tag_t).
 */
static int tls_opt_sec_tag_list_get(struct tls_context *context,
				    void *optval, socklen_t *optlen)
{
	int copy_len;

	if (*optlen == 0 || (*optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	/* Truncate to whichever is smaller: the stored list or the buffer. */
	copy_len = MIN(context->options.sec_tag_list.sec_tag_count *
		       sizeof(sec_tag_t), *optlen);

	memcpy(optval, context->options.sec_tag_list.sec_tags, copy_len);
	*optlen = copy_len;

	return 0;
}
1511 
/* Set the hostname option of a TLS context.
 *
 * optval is handed to mbedtls_ssl_set_hostname() as is; optlen is not used.
 * On success the context remembers that a hostname was set explicitly.
 *
 * Returns 0 on success, -EINVAL when mbedTLS rejects the hostname, or
 * -ENOPROTOOPT when MBEDTLS_X509_CRT_PARSE_C is not defined.
 */
static int tls_opt_hostname_set(struct tls_context *context,
				const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_set_hostname(&context->ssl, optval) != 0) {
		return -EINVAL;
	}

	context->options.is_hostname_set = true;

	return 0;
#else
	ARG_UNUSED(context);
	ARG_UNUSED(optval);

	return -ENOPROTOOPT;
#endif
}
1529 
/* Store a user-selected ciphersuite list and apply it to the mbedTLS
 * configuration.
 *
 * optval is an array of int ciphersuite identifiers, optlen its size in
 * bytes. The stored copy is 0-terminated.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL, optlen is not a
 * multiple of sizeof(int), or the list plus terminator does not fit.
 */
static int tls_opt_ciphersuite_list_set(struct tls_context *context,
					const void *optval, socklen_t optlen)
{
	int suite_count;

	if (optval == NULL || (optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	suite_count = optlen / sizeof(int);

	/* One extra slot is needed for the 0 terminator. */
	if (suite_count + 1 > ARRAY_SIZE(context->options.ciphersuites)) {
		return -EINVAL;
	}

	memcpy(context->options.ciphersuites, optval, optlen);
	context->options.ciphersuites[suite_count] = 0;

	mbedtls_ssl_conf_ciphersuites(&context->config,
				      context->options.ciphersuites);

	return 0;
}
1557 
/* Copy the selected ciphersuite identifiers into the caller's buffer.
 *
 * When no list was configured on the context, the list returned by
 * mbedtls_ssl_list_ciphersuites() is reported instead. Copying stops at the
 * 0 terminator of the source list or when the buffer is full; on return
 * *optlen holds the number of bytes written.
 *
 * Returns 0 on success, or -EINVAL when *optlen is zero or not a multiple of
 * sizeof(int).
 */
static int tls_opt_ciphersuite_list_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const int *source;
	int *dest = optval;
	int capacity;
	int copied;

	if (*optlen == 0 || (*optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	/* No specific ciphersuites configured, return all available. */
	source = (context->options.ciphersuites[0] == 0) ?
		 mbedtls_ssl_list_ciphersuites() :
		 context->options.ciphersuites;

	capacity = *optlen / sizeof(int);

	for (copied = 0; source[copied] != 0; ) {
		dest[copied] = source[copied];
		copied++;

		if (copied == capacity) {
			break;
		}
	}

	*optlen = copied * sizeof(int);

	return 0;
}
1589 
/* Report the identifier of the ciphersuite currently in use.
 *
 * Returns 0 on success, -EINVAL when *optlen is not sizeof(int), or
 * -ENOTCONN when mbedTLS reports no ciphersuite name for the session.
 */
static int tls_opt_ciphersuite_used_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const char *suite_name;
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	suite_name = mbedtls_ssl_get_ciphersuite(&context->ssl);
	if (suite_name == NULL) {
		return -ENOTCONN;
	}

	/* mbedTLS reports the suite by name; translate it to its id. */
	*out = mbedtls_ssl_get_ciphersuite_id(suite_name);

	return 0;
}
1608 
/* Store the ALPN protocol list option of a TLS context.
 *
 * optval is an array of string pointers, optlen its size in bytes. Only the
 * pointers are copied, and the stored list is NULL-terminated.
 *
 * Returns 0 on success, or -EINVAL when ALPN support is compiled out
 * (ALPN_MAX_PROTOCOLS is 0), optval is NULL, optlen is not a multiple of the
 * pointer size, or the list plus terminator does not fit.
 */
static int tls_opt_alpn_list_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int proto_count;

	if (!ALPN_MAX_PROTOCOLS || optval == NULL ||
	    (optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	proto_count = optlen / sizeof(const char *);

	/* One extra slot is needed for the NULL terminator. */
	if (proto_count + 1 > ARRAY_SIZE(context->options.alpn_list)) {
		return -EINVAL;
	}

	memcpy(context->options.alpn_list, optval, optlen);
	context->options.alpn_list[proto_count] = NULL;

	return 0;
}
1637 
1638 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Report the minimum or maximum DTLS handshake timeout option.
 *
 * is_max selects the maximum value, otherwise the minimum is reported.
 *
 * Returns 0 on success or -EINVAL when *optlen is not sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_get(struct tls_context *context,
					      void *optval, socklen_t *optlen,
					      bool is_max)
{
	uint32_t *out = optval;

	if (*optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	*out = is_max ? context->options.dtls_handshake_timeout_max :
			context->options.dtls_handshake_timeout_min;

	return 0;
}
1657 
/* Set the minimum or maximum DTLS handshake timeout option.
 *
 * is_max selects which of the two stored values is replaced. Both stored
 * values are then pushed to the mbedTLS configuration.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL or optlen is not
 * sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_set(struct tls_context *context,
					      const void *optval,
					      socklen_t optlen, bool is_max)
{
	const uint32_t *new_timeout = optval;

	if (new_timeout == NULL || optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	/* If mbedTLS context not inited, it will
	 * use these values upon init.
	 */
	if (is_max) {
		context->options.dtls_handshake_timeout_max = *new_timeout;
	} else {
		context->options.dtls_handshake_timeout_min = *new_timeout;
	}

	/* If mbedTLS context already inited, we need to
	 * update mbedTLS config for it to take effect
	 */
	mbedtls_ssl_conf_handshake_timeout(&context->config,
			context->options.dtls_handshake_timeout_min,
			context->options.dtls_handshake_timeout_max);

	return 0;
}
1690 
/* Set the DTLS connection ID mode of a context.
 *
 * optval points to an int holding one of TLS_DTLS_CID_DISABLED,
 * TLS_DTLS_CID_SUPPORTED or TLS_DTLS_CID_ENABLED:
 *  - DISABLED:  CID use is turned off and the stored CID length cleared.
 *  - SUPPORTED: CID use is turned on with a zero-length own CID.
 *  - ENABLED:   CID use is turned on; when no own CID is stored yet, a
 *               random one of MBEDTLS_SSL_CID_OUT_LEN_MAX bytes is created.
 *
 * NOTE(review): the generated CID uses MBEDTLS_SSL_CID_OUT_LEN_MAX while the
 * explicit CID value setter bounds the length by MBEDTLS_SSL_CID_IN_LEN_MAX;
 * confirm the two limits are meant to differ.
 *
 * Returns 0 on success, -EINVAL for a NULL optval, a wrong optlen or an
 * unknown mode, or -ENOPROTOOPT when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID
 * is not defined.
 */
static int tls_opt_dtls_connection_id_set(struct tls_context *context,
					  const void *optval, socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	int mode;

	if (optlen != sizeof(int) || optval == NULL) {
		return -EINVAL;
	}

	mode = *((int *)optval);

	if (mode == TLS_DTLS_CID_DISABLED || mode == TLS_DTLS_CID_SUPPORTED) {
		context->options.dtls_cid.enabled =
			(mode == TLS_DTLS_CID_SUPPORTED);
		context->options.dtls_cid.cid_len = 0;

		return 0;
	}

	if (mode != TLS_DTLS_CID_ENABLED) {
		return -EINVAL;
	}

	context->options.dtls_cid.enabled = true;

	if (context->options.dtls_cid.cid_len == 0) {
		/* generate random self cid */
#if defined(CONFIG_CSPRNG_ENABLED)
		sys_csrand_get(context->options.dtls_cid.cid,
			       MBEDTLS_SSL_CID_OUT_LEN_MAX);
#else
		sys_rand_get(context->options.dtls_cid.cid,
			     MBEDTLS_SSL_CID_OUT_LEN_MAX);
#endif
		context->options.dtls_cid.cid_len = MBEDTLS_SSL_CID_OUT_LEN_MAX;
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1739 
/* Store an explicit own DTLS connection ID value on the context.
 *
 * optval holds the CID bytes and optlen their count. A zero-length CID is
 * accepted, in which case optval may be NULL and only the stored length is
 * cleared.
 *
 * Returns 0 on success, -EINVAL when optval is NULL for a non-empty CID or
 * optlen exceeds MBEDTLS_SSL_CID_IN_LEN_MAX, or -ENOPROTOOPT when
 * CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_connection_id_value_set(struct tls_context *context,
						const void *optval,
						socklen_t optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (optlen > 0 && optval == NULL) {
		return -EINVAL;
	}

	if (optlen > MBEDTLS_SSL_CID_IN_LEN_MAX) {
		return -EINVAL;
	}

	context->options.dtls_cid.cid_len = optlen;

	/* memcpy() requires valid pointers even for a zero length, and
	 * optval may legitimately be NULL when optlen is 0.
	 */
	if (optlen > 0) {
		memcpy(context->options.dtls_cid.cid, optval, optlen);
	}

	return 0;
#else
	return -ENOPROTOOPT;
#endif /* CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID */
}
1761 
/* Copy the context's own DTLS connection ID value into the caller's buffer.
 *
 * On success *optlen is set to the stored CID length.
 *
 * Returns 0 on success, -EINVAL when the buffer is smaller than the stored
 * CID, or -ENOPROTOOPT when CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not
 * defined.
 */
static int tls_opt_dtls_connection_id_value_get(struct tls_context *context,
						void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	if (context->options.dtls_cid.cid_len > *optlen) {
		return -EINVAL;
	}

	memcpy(optval, context->options.dtls_cid.cid,
	       context->options.dtls_cid.cid_len);
	*optlen = context->options.dtls_cid.cid_len;

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1779 
/* Copy the DTLS connection ID chosen by the peer into the caller's buffer.
 *
 * The CID is first fetched into a local buffer of
 * MBEDTLS_SSL_CID_OUT_LEN_MAX bytes, because mbedtls_ssl_get_peer_cid()
 * requires a buffer of that size and is not told how large the caller's
 * buffer is. It is copied out only once it is known to fit into *optlen
 * bytes. When optval is NULL only the length is reported.
 *
 * On return *optlen holds the peer CID length, or 0 when the CID is not in
 * use or mbedTLS reported an error.
 *
 * Returns 0 on success, -ENOTCONN when the context is not initialized,
 * -EINVAL when the caller's buffer is too small, the non-zero value
 * returned by mbedtls_ssl_get_peer_cid(), or -ENOPROTOOPT when
 * CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_peer_connection_id_value_get(struct tls_context *context,
						     void *optval,
						     socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	unsigned char peer_cid[MBEDTLS_SSL_CID_OUT_LEN_MAX];
	size_t peer_cid_len = 0;
	int enabled = false;
	int ret;

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       peer_cid, &peer_cid_len);
	if (ret != 0 || !enabled) {
		*optlen = 0;
		return ret;
	}

	if (optval != NULL) {
		if (*optlen < peer_cid_len) {
			return -EINVAL;
		}

		memcpy(optval, peer_cid, peer_cid_len);
	}

	*optlen = peer_cid_len;

	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1801 
/* Report the DTLS connection ID status of an established session.
 *
 * The status is derived from whether CID use is enabled on the context,
 * whether an own CID is stored and whether the peer supplied a CID:
 *  - CID disabled, or neither side has a CID: TLS_DTLS_CID_STATUS_DISABLED
 *  - both sides have a CID:                   TLS_DTLS_CID_STATUS_BIDIRECTIONAL
 *  - only an own CID:                         TLS_DTLS_CID_STATUS_DOWNLINK
 *  - only a peer CID:                         TLS_DTLS_CID_STATUS_UPLINK
 *
 * Returns 0 on success, -EINVAL when *optlen is not sizeof(int), -ENOTCONN
 * when the context is not initialized, -EAGAIN when
 * mbedtls_ssl_get_peer_cid() fails, or -ENOPROTOOPT when
 * CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID is not defined.
 */
static int tls_opt_dtls_connection_id_status_get(struct tls_context *context,
					  void *optval, socklen_t *optlen)
{
#if defined(CONFIG_MBEDTLS_SSL_DTLS_CONNECTION_ID)
	/* Zero-initialized: mbedTLS only stores the peer CID length when
	 * the CID extension is in use, so cid_len must have a defined value
	 * before it is read below.
	 */
	struct tls_dtls_cid cid = { 0 };
	int enabled = MBEDTLS_SSL_CID_DISABLED;
	int ret;
	int val;
	bool have_self_cid;
	bool have_peer_cid;

	if (sizeof(int) != *optlen) {
		return -EINVAL;
	}

	if (!context->is_initialized) {
		return -ENOTCONN;
	}

	ret = mbedtls_ssl_get_peer_cid(&context->ssl, &enabled,
				       cid.cid,
				       &cid.cid_len);
	if (ret) {
		/* Handshake is not complete */
		return -EAGAIN;
	}

	cid.enabled = (enabled == MBEDTLS_SSL_CID_ENABLED);
	have_self_cid = (context->options.dtls_cid.cid_len != 0);
	have_peer_cid = cid.enabled && (cid.cid_len != 0);

	if (!context->options.dtls_cid.enabled) {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	} else if (have_self_cid && have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_BIDIRECTIONAL;
	} else if (have_self_cid) {
		val = TLS_DTLS_CID_STATUS_DOWNLINK;
	} else if (have_peer_cid) {
		val = TLS_DTLS_CID_STATUS_UPLINK;
	} else {
		val = TLS_DTLS_CID_STATUS_DISABLED;
	}

	*((int *)optval) = val;
	return 0;
#else
	return -ENOPROTOOPT;
#endif
}
1851 
/* Set the "DTLS handshake on connect" option from an int flag.
 *
 * Any non-zero value turns the option on.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL or optlen is not
 * sizeof(int).
 */
static int tls_opt_dtls_handshake_on_connect_set(struct tls_context *context,
						 const void *optval,
						 socklen_t optlen)
{
	const int *flag = optval;

	if (flag == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.dtls_handshake_on_connect = (*flag != 0);

	return 0;
}
1870 
/* Report the "DTLS handshake on connect" option as an int.
 *
 * Returns 0 on success or -EINVAL when *optlen is not sizeof(int).
 */
static int tls_opt_dtls_handshake_on_connect_get(struct tls_context *context,
						 void *optval,
						 socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*out = context->options.dtls_handshake_on_connect;

	return 0;
}
1883 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1884 
/* Copy the configured ALPN protocol pointers into the caller's buffer.
 *
 * Copying stops at the NULL terminator of the stored list or when the buffer
 * is full; on return *optlen holds the number of bytes written.
 *
 * Returns 0 on success, or -EINVAL when ALPN support is compiled out
 * (ALPN_MAX_PROTOCOLS is 0) or *optlen is zero or not a multiple of the
 * pointer size.
 */
static int tls_opt_alpn_list_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	const char **stored = context->options.alpn_list;
	const char **dest = optval;
	int capacity;
	int copied;

	if (!ALPN_MAX_PROTOCOLS) {
		return -EINVAL;
	}

	if (*optlen == 0 || (*optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	capacity = *optlen / sizeof(const char *);

	for (copied = 0; stored[copied] != NULL; ) {
		dest[copied] = stored[copied];
		copied++;

		if (copied == capacity) {
			break;
		}
	}

	*optlen = copied * sizeof(const char *);

	return 0;
}
1913 
/* Enable or disable session caching for a TLS context.
 *
 * Caching is enabled only when the supplied int equals
 * TLS_SESSION_CACHE_ENABLED; any other value disables it.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL or optlen is not
 * sizeof(int).
 */
static int tls_opt_session_cache_set(struct tls_context *context,
				     const void *optval, socklen_t optlen)
{
	const int *requested = optval;

	if (requested == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.cache_enabled =
		(*requested == TLS_SESSION_CACHE_ENABLED);

	return 0;
}
1931 
/* Report whether session caching is enabled for a TLS context, as
 * TLS_SESSION_CACHE_ENABLED or TLS_SESSION_CACHE_DISABLED.
 *
 * Returns 0 on success or -EINVAL when *optlen is not sizeof(int).
 */
static int tls_opt_session_cache_get(struct tls_context *context,
				     void *optval, socklen_t *optlen)
{
	int *out = optval;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	if (context->options.cache_enabled) {
		*out = TLS_SESSION_CACHE_ENABLED;
	} else {
		*out = TLS_SESSION_CACHE_DISABLED;
	}

	return 0;
}
1947 
/* Purge the TLS session cache.
 *
 * None of the arguments are used; the option acts as a trigger for
 * tls_session_purge().
 *
 * Always returns 0.
 */
static int tls_opt_session_cache_purge_set(struct tls_context *context,
					   const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(context);

	tls_session_purge();

	return 0;
}
1959 
/* Set the peer verification level of a TLS context.
 *
 * Accepted values are MBEDTLS_SSL_VERIFY_NONE, MBEDTLS_SSL_VERIFY_OPTIONAL
 * and MBEDTLS_SSL_VERIFY_REQUIRED.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL, optlen is not
 * sizeof(int), or the value is not one of the accepted levels.
 */
static int tls_opt_peer_verify_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int level;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	level = *(int *)optval;

	switch (level) {
	case MBEDTLS_SSL_VERIFY_NONE:
	case MBEDTLS_SSL_VERIFY_OPTIONAL:
	case MBEDTLS_SSL_VERIFY_REQUIRED:
		context->options.verify_level = level;
		return 0;

	default:
		return -EINVAL;
	}
}
1985 
/* Set the certificate no-copy option of a TLS context.
 *
 * Accepted values are TLS_CERT_NOCOPY_NONE and TLS_CERT_NOCOPY_OPTIONAL.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL, optlen is not
 * sizeof(int), or the value is not one of the accepted ones.
 */
static int tls_opt_cert_nocopy_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int requested;
	bool is_known;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	requested = *(int *)optval;
	is_known = (requested == TLS_CERT_NOCOPY_NONE) ||
		   (requested == TLS_CERT_NOCOPY_OPTIONAL);

	if (!is_known) {
		return -EINVAL;
	}

	context->options.cert_nocopy = requested;

	return 0;
}
2010 
/* Set the DTLS role (client or server) of a TLS context.
 *
 * Accepted values are MBEDTLS_SSL_IS_CLIENT and MBEDTLS_SSL_IS_SERVER.
 *
 * Returns 0 on success, or -EINVAL when optval is NULL, optlen is not
 * sizeof(int), or the value is not one of the two roles.
 */
static int tls_opt_dtls_role_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int requested_role;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	requested_role = *(int *)optval;

	if (requested_role == MBEDTLS_SSL_IS_CLIENT ||
	    requested_role == MBEDTLS_SSL_IS_SERVER) {
		context->options.role = requested_role;
		return 0;
	}

	return -EINVAL;
}
2034 
protocol_check(int family,int type,int * proto)2035 static int protocol_check(int family, int type, int *proto)
2036 {
2037 	if (family != AF_INET && family != AF_INET6) {
2038 		return -EAFNOSUPPORT;
2039 	}
2040 
2041 	if (*proto >= IPPROTO_TLS_1_0 && *proto <= IPPROTO_TLS_1_3) {
2042 		if (type != SOCK_STREAM) {
2043 			return -EPROTOTYPE;
2044 		}
2045 
2046 		*proto = IPPROTO_TCP;
2047 	} else if (*proto >= IPPROTO_DTLS_1_0 && *proto <= IPPROTO_DTLS_1_2) {
2048 		if (!IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS)) {
2049 			return -EPROTONOSUPPORT;
2050 		}
2051 
2052 		if (type != SOCK_DGRAM) {
2053 			return -EPROTOTYPE;
2054 		}
2055 
2056 		*proto = IPPROTO_UDP;
2057 	} else {
2058 		return -EPROTONOSUPPORT;
2059 	}
2060 
2061 	return 0;
2062 }
2063 
ztls_socket(int family,int type,int proto)2064 static int ztls_socket(int family, int type, int proto)
2065 {
2066 	enum net_ip_protocol_secure tls_proto = proto;
2067 	int fd = zvfs_reserve_fd();
2068 	int sock = -1;
2069 	int ret;
2070 	struct tls_context *ctx;
2071 
2072 	if (fd < 0) {
2073 		return -1;
2074 	}
2075 
2076 	ret = protocol_check(family, type, &proto);
2077 	if (ret < 0) {
2078 		errno = -ret;
2079 		goto free_fd;
2080 	}
2081 
2082 	ctx = tls_alloc();
2083 	if (ctx == NULL) {
2084 		errno = ENOMEM;
2085 		goto free_fd;
2086 	}
2087 
2088 	sock = zsock_socket(family, type, proto);
2089 	if (sock < 0) {
2090 		goto release_tls;
2091 	}
2092 
2093 	ctx->tls_version = tls_proto;
2094 	ctx->type = (proto == IPPROTO_TCP) ? SOCK_STREAM : SOCK_DGRAM;
2095 	ctx->sock = sock;
2096 
2097 	zvfs_finalize_typed_fd(fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
2098 			    ZVFS_MODE_IFSOCK);
2099 
2100 	return fd;
2101 
2102 release_tls:
2103 	(void)tls_release(ctx);
2104 
2105 free_fd:
2106 	zvfs_free_fd(fd);
2107 
2108 	return -1;
2109 }
2110 
ztls_close_ctx(struct tls_context * ctx,int sock)2111 int ztls_close_ctx(struct tls_context *ctx, int sock)
2112 {
2113 	int ret, err = 0;
2114 
2115 	/* Try to send close notification. */
2116 	ctx->flags = 0;
2117 
2118 	(void)mbedtls_ssl_close_notify(&ctx->ssl);
2119 
2120 	err = tls_release(ctx);
2121 	ret = zsock_close(ctx->sock);
2122 
2123 	if (ret == 0) {
2124 		(void)sock_obj_core_dealloc(sock);
2125 	}
2126 
2127 	/* In case close fails, we propagate errno value set by close.
2128 	 * In case close succeeds, but tls_release fails, set errno
2129 	 * according to tls_release return value.
2130 	 */
2131 	if (ret == 0 && err < 0) {
2132 		errno = -err;
2133 		ret = -1;
2134 	}
2135 
2136 	return ret;
2137 }
2138 
/* Connect the underlying socket of a TLS context and, for stream sockets
 * (or datagram sockets with the handshake-on-connect option), perform the
 * TLS/DTLS handshake as a client.
 *
 * The underlying connect is always performed in blocking mode: O_NONBLOCK is
 * cleared for the duration of zsock_connect() and restored afterwards,
 * whether or not the connect succeeded. For datagram sockets the peer
 * address is recorded on the context. The session is restored before and
 * stored after a successful handshake.
 *
 * Returns 0 on success, or -1 with errno set on failure: EIO when the socket
 * flags cannot be read, the errno left by zsock_connect(), the negated error
 * of tls_mbedtls_init(), or the negated handshake error (an -EAGAIN from the
 * handshake is reported as ETIMEDOUT for blocking sockets).
 */
int ztls_connect_ctx(struct tls_context *ctx, const struct sockaddr *addr,
		     socklen_t addrlen)
{
	int ret;
	int sock_flags;
	bool is_non_block;

	sock_flags = zsock_fcntl(ctx->sock, F_GETFL, 0);
	if (sock_flags < 0) {
		/* Report through errno like every other failure path. */
		ret = -EIO;
		goto error;
	}

	is_non_block = sock_flags & O_NONBLOCK;
	if (is_non_block) {
		(void)zsock_fcntl(ctx->sock, F_SETFL,
				  sock_flags & ~O_NONBLOCK);
	}

	ret = zsock_connect(ctx->sock, addr, addrlen);

	if (is_non_block) {
		/* Restore the caller's flags on both success and failure,
		 * keeping the errno value left by zsock_connect().
		 */
		int connect_errno = errno;

		(void)zsock_fcntl(ctx->sock, F_SETFL, sock_flags);
		errno = connect_errno;
	}

	if (ret < 0) {
		return ret;
	}

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	if (ctx->type == SOCK_DGRAM) {
		dtls_peer_address_set(ctx, addr, addrlen);
	}
#endif

	if (ctx->type == SOCK_STREAM
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	    || (ctx->type == SOCK_DGRAM && ctx->options.dtls_handshake_on_connect)
#endif
	    ) {
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}

		/* Do not use any socket flags during the handshake. */
		ctx->flags = 0;

		tls_session_restore(ctx, addr, addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket
		 * even for non-blocking socket.
		 */
		ret = tls_mbedtls_handshake(
			ctx, K_MSEC(CONFIG_NET_SOCKETS_CONNECT_TIMEOUT));
		if (ret < 0) {
			if ((ret == -EAGAIN) && !is_non_block) {
				ret = -ETIMEDOUT;
			}

			goto error;
		}

		tls_session_store(ctx, addr, addrlen);
	}

	return 0;

error:
	errno = -ret;
	return -1;
}
2209 
/* Accept a connection on a listening TLS socket.
 *
 * Reserves a file descriptor, accepts a connection on the parent's underlying
 * socket, clones the parent TLS context for it and runs the server side
 * handshake on the new context.
 *
 * Returns the new file descriptor on success and -1 on failure. errno is set
 * by this function on every failure path except a failed descriptor
 * reservation.
 */
int ztls_accept_ctx(struct tls_context *parent, struct sockaddr *addr,
		    socklen_t *addrlen)
{
	struct tls_context *conn = NULL;
	int new_fd, new_sock;
	int status, err;

	new_fd = zvfs_reserve_fd();
	if (new_fd < 0) {
		return -1;
	}

	/* The parent lock is dropped for the duration of the accept call. */
	k_mutex_unlock(parent->lock);
	new_sock = zsock_accept(parent->sock, addr, addrlen);
	k_mutex_lock(parent->lock, K_FOREVER);
	if (new_sock < 0) {
		status = -errno;
		goto fail;
	}

	conn = tls_clone(parent);
	if (conn == NULL) {
		status = -ENOMEM;
		goto fail;
	}

	zvfs_finalize_typed_fd(new_fd, conn,
			       (const struct fd_op_vtable *)&tls_sock_fd_op_vtable,
			       ZVFS_MODE_IFSOCK);

	conn->sock = new_sock;

	status = tls_mbedtls_init(conn, true);
	if (status >= 0) {
		/* Do not use any socket flags during the handshake. */
		conn->flags = 0;

		/* TODO For simplicity, TLS handshake blocks the socket even
		 * for non-blocking socket.
		 */
		status = tls_mbedtls_handshake(
			conn, K_MSEC(CONFIG_NET_SOCKETS_CONNECT_TIMEOUT));
		if (status >= 0) {
			return new_fd;
		}
	}

fail:
	/* Undo whatever was set up before the failure. */
	if (conn != NULL) {
		err = tls_release(conn);
		__ASSERT(err == 0, "TLS context release failed");
	}

	if (new_sock >= 0) {
		err = zsock_close(new_sock);
		__ASSERT(err == 0, "Child socket close failed");
	}

	zvfs_free_fd(new_fd);

	errno = -status;
	return -1;
}
2276 
send_tls(struct tls_context * ctx,const void * buf,size_t len,int flags)2277 static ssize_t send_tls(struct tls_context *ctx, const void *buf,
2278 			size_t len, int flags)
2279 {
2280 	const bool is_block = is_blocking(ctx->sock, flags);
2281 	k_timeout_t timeout;
2282 	k_timepoint_t end;
2283 	int ret;
2284 
2285 	if (ctx->error != 0) {
2286 		errno = ctx->error;
2287 		return -1;
2288 	}
2289 
2290 	if (ctx->session_closed) {
2291 		errno = ECONNABORTED;
2292 		return -1;
2293 	}
2294 
2295 	if (!is_block) {
2296 		timeout = K_NO_WAIT;
2297 	} else {
2298 		timeout = ctx->options.timeout_tx;
2299 	}
2300 
2301 	end = sys_timepoint_calc(timeout);
2302 
2303 	do {
2304 		ret = mbedtls_ssl_write(&ctx->ssl, buf, len);
2305 		if (ret >= 0) {
2306 			return ret;
2307 		}
2308 
2309 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2310 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2311 		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2312 		    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2313 			int timeout_ms;
2314 
2315 			if (!is_block) {
2316 				errno = EAGAIN;
2317 				break;
2318 			}
2319 
2320 			/* Blocking timeout. */
2321 			timeout = sys_timepoint_timeout(end);
2322 			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2323 				errno = EAGAIN;
2324 				break;
2325 			}
2326 
2327 			/* Block. */
2328 			timeout_ms = timeout_to_ms(&timeout);
2329 			ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2330 			if (ret != 0) {
2331 				errno = -ret;
2332 				break;
2333 			}
2334 		} else {
2335 			NET_ERR("TLS send error: -%x", -ret);
2336 
2337 			/* MbedTLS API documentation requires session to
2338 			 * be reset in other error cases
2339 			 */
2340 			ret = tls_mbedtls_reset(ctx);
2341 			if (ret != 0) {
2342 				ctx->error = ENOMEM;
2343 				errno = ENOMEM;
2344 			} else {
2345 				ctx->error = ECONNABORTED;
2346 				errno = ECONNABORTED;
2347 			}
2348 
2349 			break;
2350 		}
2351 	} while (true);
2352 
2353 	return -1;
2354 }
2355 
2356 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
sendto_dtls_client(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2357 static ssize_t sendto_dtls_client(struct tls_context *ctx, const void *buf,
2358 				  size_t len, int flags,
2359 				  const struct sockaddr *dest_addr,
2360 				  socklen_t addrlen)
2361 {
2362 	int ret;
2363 
2364 	if (!dest_addr) {
2365 		/* No address provided, check if we have stored one,
2366 		 * otherwise return error.
2367 		 */
2368 		if (ctx->dtls_peer_addrlen == 0) {
2369 			ret = -EDESTADDRREQ;
2370 			goto error;
2371 		}
2372 	} else if (ctx->dtls_peer_addrlen == 0) {
2373 		/* Address provided and no peer address stored. */
2374 		dtls_peer_address_set(ctx, dest_addr, addrlen);
2375 	} else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) {
2376 		/* Address provided but it does not match stored one */
2377 		ret = -EISCONN;
2378 		goto error;
2379 	}
2380 
2381 	if (!ctx->is_initialized) {
2382 		ret = tls_mbedtls_init(ctx, false);
2383 		if (ret < 0) {
2384 			goto error;
2385 		}
2386 	}
2387 
2388 	if (!is_handshake_complete(ctx)) {
2389 		tls_session_restore(ctx, &ctx->dtls_peer_addr,
2390 				    ctx->dtls_peer_addrlen);
2391 
2392 		/* TODO For simplicity, TLS handshake blocks the socket even for
2393 		 * non-blocking socket.
2394 		 * DTLS handshake timeout/retransmissions are limited by
2395 		 * mbed TLS, so K_FOREVER is fine here, the function will not
2396 		 * block indefinitely.
2397 		 */
2398 		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
2399 		if (ret < 0) {
2400 			goto error;
2401 		}
2402 
2403 		/* Client socket ready to use again. */
2404 		ctx->error = 0;
2405 
2406 		tls_session_store(ctx, &ctx->dtls_peer_addr,
2407 				  ctx->dtls_peer_addrlen);
2408 	}
2409 
2410 	return send_tls(ctx, buf, len, flags);
2411 
2412 error:
2413 	errno = -ret;
2414 	return -1;
2415 }
2416 
sendto_dtls_server(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2417 static ssize_t sendto_dtls_server(struct tls_context *ctx, const void *buf,
2418 				  size_t len, int flags,
2419 				  const struct sockaddr *dest_addr,
2420 				  socklen_t addrlen)
2421 {
2422 	/* For DTLS server, require to have established DTLS connection
2423 	 * in order to send data.
2424 	 */
2425 	if (!is_handshake_complete(ctx)) {
2426 		errno = ENOTCONN;
2427 		return -1;
2428 	}
2429 
2430 	/* Verify we are sending to a peer that we have connection with. */
2431 	if (dest_addr &&
2432 	    !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) {
2433 		errno = EISCONN;
2434 		return -1;
2435 	}
2436 
2437 	return send_tls(ctx, buf, len, flags);
2438 }
2439 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2440 
ztls_sendto_ctx(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2441 ssize_t ztls_sendto_ctx(struct tls_context *ctx, const void *buf, size_t len,
2442 			int flags, const struct sockaddr *dest_addr,
2443 			socklen_t addrlen)
2444 {
2445 	ctx->flags = flags;
2446 
2447 	/* TLS */
2448 	if (ctx->type == SOCK_STREAM) {
2449 		return send_tls(ctx, buf, len, flags);
2450 	}
2451 
2452 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2453 	/* DTLS */
2454 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2455 		return sendto_dtls_server(ctx, buf, len, flags,
2456 					  dest_addr, addrlen);
2457 	}
2458 
2459 	return sendto_dtls_client(ctx, buf, len, flags, dest_addr, addrlen);
2460 #else
2461 	errno = ENOTSUP;
2462 	return -1;
2463 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2464 }
2465 
dtls_sendmsg_merge_and_send(struct tls_context * ctx,const struct msghdr * msg,int flags)2466 static ssize_t dtls_sendmsg_merge_and_send(struct tls_context *ctx,
2467 					   const struct msghdr *msg,
2468 					   int flags)
2469 {
2470 	static K_MUTEX_DEFINE(sendmsg_lock);
2471 	static uint8_t sendmsg_buf[DTLS_SENDMSG_BUF_SIZE];
2472 	ssize_t len = 0;
2473 
2474 	k_mutex_lock(&sendmsg_lock, K_FOREVER);
2475 
2476 	for (int i = 0; i < msg->msg_iovlen; i++) {
2477 		struct iovec *vec = msg->msg_iov + i;
2478 
2479 		if (vec->iov_len > 0) {
2480 			if (len + vec->iov_len > sizeof(sendmsg_buf)) {
2481 				k_mutex_unlock(&sendmsg_lock);
2482 				errno = EMSGSIZE;
2483 				return -1;
2484 			}
2485 
2486 			memcpy(sendmsg_buf + len, vec->iov_base, vec->iov_len);
2487 			len += vec->iov_len;
2488 		}
2489 	}
2490 
2491 	if (len > 0) {
2492 		len = ztls_sendto_ctx(ctx, sendmsg_buf, len, flags,
2493 				      msg->msg_name, msg->msg_namelen);
2494 	}
2495 
2496 	k_mutex_unlock(&sendmsg_lock);
2497 
2498 	return len;
2499 }
2500 
tls_sendmsg_loop_and_send(struct tls_context * ctx,const struct msghdr * msg,int flags)2501 static ssize_t tls_sendmsg_loop_and_send(struct tls_context *ctx,
2502 					 const struct msghdr *msg,
2503 					 int flags)
2504 {
2505 	ssize_t len = 0;
2506 	ssize_t ret;
2507 
2508 	for (int i = 0; i < msg->msg_iovlen; i++) {
2509 		struct iovec *vec = msg->msg_iov + i;
2510 		size_t sent = 0;
2511 
2512 		if (vec->iov_len == 0) {
2513 			continue;
2514 		}
2515 
2516 		while (sent < vec->iov_len) {
2517 			uint8_t *ptr = (uint8_t *)vec->iov_base + sent;
2518 
2519 			ret = ztls_sendto_ctx(ctx, ptr, vec->iov_len - sent,
2520 					      flags, msg->msg_name,
2521 					      msg->msg_namelen);
2522 			if (ret < 0) {
2523 				return ret;
2524 			}
2525 			sent += ret;
2526 		}
2527 		len += sent;
2528 	}
2529 
2530 	return len;
2531 }
2532 
ztls_sendmsg_ctx(struct tls_context * ctx,const struct msghdr * msg,int flags)2533 ssize_t ztls_sendmsg_ctx(struct tls_context *ctx, const struct msghdr *msg,
2534 			 int flags)
2535 {
2536 	if (msg == NULL) {
2537 		errno = EINVAL;
2538 		return -1;
2539 	}
2540 
2541 	if (IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS) &&
2542 	    ctx->type == SOCK_DGRAM) {
2543 		if (DTLS_SENDMSG_BUF_SIZE > 0) {
2544 			/* With one buffer only, there's no need to use
2545 			 * intermediate buffer.
2546 			 */
2547 			if (msghdr_non_empty_iov_count(msg) == 1) {
2548 				goto send_loop;
2549 			}
2550 
2551 			return dtls_sendmsg_merge_and_send(ctx, msg, flags);
2552 		}
2553 
2554 		/*
2555 		 * Current mbedTLS API (i.e. mbedtls_ssl_write()) allows only to send a single
2556 		 * contiguous buffer. This means that gather write using sendmsg() can only be
2557 		 * handled correctly if there is a single non-empty buffer in msg->msg_iov.
2558 		 */
2559 		if (msghdr_non_empty_iov_count(msg) > 1) {
2560 			errno = EMSGSIZE;
2561 			return -1;
2562 		}
2563 	}
2564 
2565 send_loop:
2566 	return tls_sendmsg_loop_and_send(ctx, msg, flags);
2567 }
2568 
/* Read application data from a stream (TLS) session.
 *
 * Fails immediately with the stored error if the context has one, and
 * returns 0 if the session was already marked closed. Otherwise
 * mbedtls_ssl_read() is called until at least one byte was read, or, with
 * ZSOCK_MSG_WAITALL, until max_len bytes were read.
 *
 * Returns the number of bytes placed in buf (0 when the peer closed the
 * session before any data was read), or -1 with errno set.
 *
 * NOTE(review): when ZSOCK_MSG_WAITALL is used and a timeout or error occurs
 * after part of the data was already copied into buf, -1 is returned and the
 * partial byte count is not reported - confirm this is intended.
 */
static ssize_t recv_tls(struct tls_context *ctx, void *buf,
			size_t max_len, int flags)
{
	size_t recv_len = 0;
	const bool waitall = flags & ZSOCK_MSG_WAITALL;
	const bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;
	int ret;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	if (ctx->session_closed) {
		return 0;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline, so that retries share one overall RX timeout. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t read_len = max_len - recv_len;

		ret = mbedtls_ssl_read(&ctx->ssl, (uint8_t *)buf + recv_len,
				       read_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
				/* Peer notified that it's closing the
				 * connection.
				 */
				ctx->session_closed = true;
				break;
			}

			if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
				/* Client reconnect on the same socket is not
				 * supported. See mbedtls_ssl_read API
				 * documentation.
				 */
				ctx->session_closed = true;
				break;
			}

			/* Transient conditions: retry after waiting, unless
			 * the socket is non-blocking or the deadline passed.
			 */
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
			    ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET) {
				int timeout_ms;

				if (!is_block) {
					ret = -EAGAIN;
					goto err;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					ret = -EAGAIN;
					goto err;
				}

				timeout_ms = timeout_to_ms(&timeout);

				/* Block. The context lock is released while
				 * waiting and re-acquired afterwards.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				}
				/* A failed wait falls through to err with its
				 * return value in ret.
				 */
			} else {
				NET_ERR("TLS recv error: -%x", -ret);
				ret = -EIO;
			}

err:
			errno = -ret;
			return -1;
		}

		/* Zero bytes read: stop and report what was collected. */
		if (ret == 0) {
			break;
		}

		recv_len += ret;
	} while ((recv_len == 0) || (waitall && (recv_len < max_len)));

	return recv_len;
}
2668 
2669 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* Read one datagram from a DTLS session; shared by client and server paths.
 *
 * Reads up to max_len bytes of the current datagram into buf and discards
 * whatever is left of it. When src_addr and addrlen are both given, they are
 * filled in through dtls_peer_address_get().
 *
 * Returns the number of bytes read (with ZSOCK_MSG_TRUNC the discarded
 * remainder is added to the count). On failure, a negative mbedTLS error code
 * is returned, which the callers translate. The one exception is a pending
 * ctx->error, which yields -1 with errno set.
 */
static ssize_t recvfrom_dtls_common(struct tls_context *ctx, void *buf,
				    size_t max_len, int flags,
				    struct sockaddr *src_addr,
				    socklen_t *addrlen)
{
	int ret;
	bool is_block = is_blocking(ctx->sock, flags);
	k_timeout_t timeout;
	k_timepoint_t end;

	if (ctx->error != 0) {
		errno = ctx->error;
		return -1;
	}

	if (!is_block) {
		timeout = K_NO_WAIT;
	} else {
		timeout = ctx->options.timeout_rx;
	}

	/* Absolute deadline, so that retries share one overall RX timeout. */
	end = sys_timepoint_calc(timeout);

	do {
		size_t remaining;

		ret = mbedtls_ssl_read(&ctx->ssl, buf, max_len);
		if (ret < 0) {
			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
				int timeout_dtls, timeout_sock, timeout_ms;

				/* Non-blocking: hand the transient code back
				 * to the caller.
				 */
				if (!is_block) {
					return ret;
				}

				/* Blocking timeout. */
				timeout = sys_timepoint_timeout(end);
				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					return ret;
				}

				/* Wait for the shorter of the DTLS timer and
				 * the socket timeout. If one of them is
				 * SYS_FOREVER_MS, MAX() is used instead -
				 * presumably SYS_FOREVER_MS is negative, so
				 * this picks the other value; confirm.
				 */
				timeout_dtls = dtls_get_remaining_timeout(ctx);
				timeout_sock = timeout_to_ms(&timeout);
				if (timeout_dtls == SYS_FOREVER_MS ||
				    timeout_sock == SYS_FOREVER_MS) {
					timeout_ms = MAX(timeout_dtls, timeout_sock);
				} else {
					timeout_ms = MIN(timeout_dtls, timeout_sock);
				}

				/* Block. The context lock is released while
				 * waiting and re-acquired afterwards.
				 */
				k_mutex_unlock(ctx->lock);
				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
				k_mutex_lock(ctx->lock, K_FOREVER);

				if (ret == 0) {
					/* Retry. */
					continue;
				} else {
					return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				}
			} else {
				return ret;
			}
		}

		if (src_addr && addrlen) {
			dtls_peer_address_get(ctx, src_addr, addrlen);
		}

		/* mbedtls_ssl_get_bytes_avail() indicate the data length
		 * remaining in the current datagram.
		 */
		remaining = mbedtls_ssl_get_bytes_avail(&ctx->ssl);

		/* No more data in the datagram, or dummy read. */
		if ((remaining == 0) || (max_len == 0)) {
			return ret;
		}

		/* With ZSOCK_MSG_TRUNC the discarded bytes are counted too. */
		if (flags & ZSOCK_MSG_TRUNC) {
			ret += remaining;
		}

		/* Drain the rest of the datagram one byte at a time. */
		for (int i = 0; i < remaining; i++) {
			uint8_t byte;
			int err;

			err = mbedtls_ssl_read(&ctx->ssl, &byte, sizeof(byte));
			if (err <= 0) {
				NET_ERR("Error while flushing the rest of the"
					" datagram, err %d", err);
				ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
				break;
			}
		}

		break;
	} while (true);


	return ret;
}
2776 
recvfrom_dtls_client(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2777 static ssize_t recvfrom_dtls_client(struct tls_context *ctx, void *buf,
2778 				    size_t max_len, int flags,
2779 				    struct sockaddr *src_addr,
2780 				    socklen_t *addrlen)
2781 {
2782 	int ret;
2783 
2784 	if (!is_handshake_complete(ctx)) {
2785 		ret = -ENOTCONN;
2786 		goto error;
2787 	}
2788 
2789 	ret = recvfrom_dtls_common(ctx, buf, max_len, flags, src_addr, addrlen);
2790 	if (ret >= 0) {
2791 		return ret;
2792 	}
2793 
2794 	switch (ret) {
2795 	case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2796 		/* Peer notified that it's closing the connection. */
2797 		ret = tls_mbedtls_reset(ctx);
2798 		if (ret == 0) {
2799 			ctx->error = ENOTCONN;
2800 			ret = -ENOTCONN;
2801 		} else {
2802 			ctx->error = ENOMEM;
2803 			ret = -ENOMEM;
2804 		}
2805 		break;
2806 
2807 	case MBEDTLS_ERR_SSL_TIMEOUT:
2808 		(void)mbedtls_ssl_close_notify(&ctx->ssl);
2809 		ctx->error = ETIMEDOUT;
2810 		ret = -ETIMEDOUT;
2811 		break;
2812 
2813 	case MBEDTLS_ERR_SSL_WANT_READ:
2814 	case MBEDTLS_ERR_SSL_WANT_WRITE:
2815 	case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2816 	case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2817 		ret = -EAGAIN;
2818 		break;
2819 
2820 	default:
2821 		NET_ERR("DTLS client recv error: -%x", -ret);
2822 
2823 		/* MbedTLS API documentation requires session to
2824 		 * be reset in other error cases
2825 		 */
2826 		ret = tls_mbedtls_reset(ctx);
2827 		if (ret != 0) {
2828 			ctx->error = ENOMEM;
2829 			errno = ENOMEM;
2830 		} else {
2831 			ctx->error = ECONNABORTED;
2832 			ret = -ECONNABORTED;
2833 		}
2834 
2835 		break;
2836 	}
2837 
2838 error:
2839 	errno = -ret;
2840 	return -1;
2841 }
2842 
recvfrom_dtls_server(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2843 static ssize_t recvfrom_dtls_server(struct tls_context *ctx, void *buf,
2844 				    size_t max_len, int flags,
2845 				    struct sockaddr *src_addr,
2846 				    socklen_t *addrlen)
2847 {
2848 	int ret;
2849 	bool repeat;
2850 	k_timeout_t timeout;
2851 
2852 	if (!ctx->is_initialized) {
2853 		ret = tls_mbedtls_init(ctx, true);
2854 		if (ret < 0) {
2855 			goto error;
2856 		}
2857 	}
2858 
2859 	if (is_blocking(ctx->sock, flags)) {
2860 		timeout = ctx->options.timeout_rx;
2861 	} else {
2862 		timeout = K_NO_WAIT;
2863 	}
2864 
2865 	/* Loop to enable DTLS reconnection for servers without closing
2866 	 * a socket.
2867 	 */
2868 	do {
2869 		repeat = false;
2870 
2871 		if (!is_handshake_complete(ctx)) {
2872 			ret = tls_mbedtls_handshake(ctx, timeout);
2873 			if (ret < 0) {
2874 				/* In case of EAGAIN, just exit. */
2875 				if (ret == -EAGAIN) {
2876 					break;
2877 				}
2878 
2879 				ret = tls_mbedtls_reset(ctx);
2880 				if (ret == 0) {
2881 					repeat = true;
2882 				} else {
2883 					ret = -ENOMEM;
2884 				}
2885 
2886 				continue;
2887 			}
2888 
2889 			/* Server socket ready to use again. */
2890 			ctx->error = 0;
2891 		}
2892 
2893 		ret = recvfrom_dtls_common(ctx, buf, max_len, flags,
2894 					   src_addr, addrlen);
2895 		if (ret >= 0) {
2896 			return ret;
2897 		}
2898 
2899 		switch (ret) {
2900 		case MBEDTLS_ERR_SSL_TIMEOUT:
2901 			(void)mbedtls_ssl_close_notify(&ctx->ssl);
2902 			__fallthrough;
2903 			/* fallthrough */
2904 
2905 		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2906 		case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
2907 			ret = tls_mbedtls_reset(ctx);
2908 			if (ret == 0) {
2909 				repeat = true;
2910 			} else {
2911 				ctx->error = ENOMEM;
2912 				ret = -ENOMEM;
2913 			}
2914 			break;
2915 
2916 		case MBEDTLS_ERR_SSL_WANT_READ:
2917 		case MBEDTLS_ERR_SSL_WANT_WRITE:
2918 		case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2919 		case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2920 			ret = -EAGAIN;
2921 			break;
2922 
2923 		default:
2924 			NET_ERR("DTLS server recv error: -%x", -ret);
2925 
2926 			ret = tls_mbedtls_reset(ctx);
2927 			if (ret != 0) {
2928 				ctx->error = ENOMEM;
2929 				errno = ENOMEM;
2930 			} else {
2931 				ctx->error = ECONNABORTED;
2932 				ret = -ECONNABORTED;
2933 			}
2934 
2935 			break;
2936 		}
2937 	} while (repeat);
2938 
2939 error:
2940 	errno = -ret;
2941 	return -1;
2942 }
2943 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2944 
ztls_recvfrom_ctx(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2945 ssize_t ztls_recvfrom_ctx(struct tls_context *ctx, void *buf, size_t max_len,
2946 			  int flags, struct sockaddr *src_addr,
2947 			  socklen_t *addrlen)
2948 {
2949 	if (flags & ZSOCK_MSG_PEEK) {
2950 		/* TODO mbedTLS does not support 'peeking' This could be
2951 		 * bypassed by having intermediate buffer for peeking
2952 		 */
2953 		errno = ENOTSUP;
2954 		return -1;
2955 	}
2956 
2957 	ctx->flags = flags;
2958 
2959 	/* TLS */
2960 	if (ctx->type == SOCK_STREAM) {
2961 		return recv_tls(ctx, buf, max_len, flags);
2962 	}
2963 
2964 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2965 	/* DTLS */
2966 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2967 		return recvfrom_dtls_server(ctx, buf, max_len, flags,
2968 					    src_addr, addrlen);
2969 	}
2970 
2971 	return recvfrom_dtls_client(ctx, buf, max_len, flags,
2972 				    src_addr, addrlen);
2973 #else
2974 	errno = ENOTSUP;
2975 	return -1;
2976 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2977 }
2978 
ztls_poll_prepare_pollin(struct tls_context * ctx)2979 static int ztls_poll_prepare_pollin(struct tls_context *ctx)
2980 {
2981 	/* If there already is mbedTLS data to read, there is no
2982 	 * need to set the k_poll_event object. Return EALREADY
2983 	 * so we won't block in the k_poll.
2984 	 */
2985 	if (!ctx->is_listening) {
2986 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
2987 			return -EALREADY;
2988 		}
2989 	}
2990 
2991 	return 0;
2992 }
2993 
ztls_poll_prepare_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)2994 static int ztls_poll_prepare_ctx(struct tls_context *ctx,
2995 				 struct zsock_pollfd *pfd,
2996 				 struct k_poll_event **pev,
2997 				 struct k_poll_event *pev_end)
2998 {
2999 	const struct fd_op_vtable *vtable;
3000 	struct k_mutex *lock;
3001 	void *obj;
3002 	int ret;
3003 	short events = pfd->events;
3004 
3005 	/* DTLS client should wait for the handshake to complete before
3006 	 * it actually starts to poll for data.
3007 	 */
3008 	if ((pfd->events & ZSOCK_POLLIN) && (ctx->type == SOCK_DGRAM) &&
3009 	    (ctx->options.role == MBEDTLS_SSL_IS_CLIENT) &&
3010 	    !is_handshake_complete(ctx)) {
3011 		(*pev)->obj = &ctx->tls_established;
3012 		(*pev)->type = K_POLL_TYPE_SEM_AVAILABLE;
3013 		(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
3014 		(*pev)->state = K_POLL_STATE_NOT_READY;
3015 		(*pev)++;
3016 
3017 		/* Since k_poll_event is configured by the TLS layer in this
3018 		 * case, do not forward ZSOCK_POLLIN to the underlying socket.
3019 		 */
3020 		pfd->events &= ~ZSOCK_POLLIN;
3021 	}
3022 
3023 	obj = zvfs_get_fd_obj_and_vtable(
3024 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
3025 	if (obj == NULL) {
3026 		ret = -EBADF;
3027 		goto exit;
3028 	}
3029 
3030 	(void)k_mutex_lock(lock, K_FOREVER);
3031 
3032 	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
3033 				   pfd, pev, pev_end);
3034 	if (ret != 0) {
3035 		goto exit;
3036 	}
3037 
3038 	if (pfd->events & ZSOCK_POLLIN) {
3039 		ret = ztls_poll_prepare_pollin(ctx);
3040 	}
3041 
3042 exit:
3043 	/* Restore original events. */
3044 	pfd->events = events;
3045 
3046 	k_mutex_unlock(lock);
3047 
3048 	return ret;
3049 }
3050 
3051 #include <zephyr/net/net_core.h>
3052 
ztls_socket_data_check(struct tls_context * ctx)3053 static int ztls_socket_data_check(struct tls_context *ctx)
3054 {
3055 	int ret;
3056 
3057 	if (ctx->type == SOCK_STREAM) {
3058 		if (!ctx->is_initialized) {
3059 			return -ENOTCONN;
3060 		}
3061 	}
3062 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3063 	else {
3064 		if (!ctx->is_initialized) {
3065 			bool is_server = ctx->options.role == MBEDTLS_SSL_IS_SERVER;
3066 
3067 			ret = tls_mbedtls_init(ctx, is_server);
3068 			if (ret < 0) {
3069 				return -ENOMEM;
3070 			}
3071 		}
3072 
3073 		if (!is_handshake_complete(ctx)) {
3074 			ret = tls_mbedtls_handshake(ctx, K_NO_WAIT);
3075 			if (ret < 0) {
3076 				if (ret == -EAGAIN) {
3077 					return 0;
3078 				}
3079 
3080 				ret = tls_mbedtls_reset(ctx);
3081 				if (ret != 0) {
3082 					return -ENOMEM;
3083 				}
3084 
3085 				return 0;
3086 			}
3087 
3088 			/* Socket ready to use again. */
3089 			ctx->error = 0;
3090 
3091 			return 0;
3092 		}
3093 	}
3094 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3095 
3096 	ctx->flags = ZSOCK_MSG_DONTWAIT;
3097 
3098 	ret = mbedtls_ssl_read(&ctx->ssl, NULL, 0);
3099 	if (ret < 0) {
3100 		if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
3101 			/* Don't reset the context for STREAM socket - the
3102 			 * application needs to reopen the socket anyway, and
3103 			 * resetting the context would result in an error instead
3104 			 * of 0 in a consecutive recv() call.
3105 			 */
3106 			if (ctx->type == SOCK_DGRAM) {
3107 				ret = tls_mbedtls_reset(ctx);
3108 				if (ret != 0) {
3109 					return -ENOMEM;
3110 				}
3111 			} else {
3112 				ctx->session_closed = true;
3113 			}
3114 
3115 			return -ENOTCONN;
3116 		}
3117 
3118 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
3119 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
3120 			return 0;
3121 		}
3122 
3123 		NET_ERR("TLS data check error: -%x", -ret);
3124 
3125 		/* MbedTLS API documentation requires session to
3126 		 * be reset in other error cases
3127 		 */
3128 		if (tls_mbedtls_reset(ctx) != 0) {
3129 			return -ENOMEM;
3130 		}
3131 
3132 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3133 		if (ret == MBEDTLS_ERR_SSL_TIMEOUT && ctx->type == SOCK_DGRAM) {
3134 			/* DTLS timeout interpreted as closing of connection. */
3135 			return -ENOTCONN;
3136 		}
3137 #endif
3138 		return -ECONNABORTED;
3139 	}
3140 
3141 	return mbedtls_ssl_get_bytes_avail(&ctx->ssl);
3142 }
3143 
ztls_poll_update_pollin(int fd,struct tls_context * ctx,struct zsock_pollfd * pfd)3144 static int ztls_poll_update_pollin(int fd, struct tls_context *ctx,
3145 				   struct zsock_pollfd *pfd)
3146 {
3147 	int ret;
3148 
3149 	if (!ctx->is_listening) {
3150 		/* Already had TLS data to read on socket. */
3151 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
3152 			pfd->revents |= ZSOCK_POLLIN;
3153 			goto next;
3154 		}
3155 	}
3156 
3157 	if (ctx->type == SOCK_STREAM) {
3158 		if (!(pfd->revents & ZSOCK_POLLIN)) {
3159 			/* No new data on a socket. */
3160 			goto next;
3161 		}
3162 
3163 		if (ctx->is_listening) {
3164 			goto next;
3165 		}
3166 	}
3167 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3168 	else {
3169 		/* Perform data check without incoming data for completed DTLS connections.
3170 		 * This allows the connections to timeout with CONFIG_NET_SOCKETS_DTLS_TIMEOUT.
3171 		 */
3172 		if (!is_handshake_complete(ctx) && !(pfd->revents & ZSOCK_POLLIN)) {
3173 			goto next;
3174 		}
3175 	}
3176 #endif
3177 	ret = ztls_socket_data_check(ctx);
3178 	if (ret == -ENOTCONN || (pfd->revents & ZSOCK_POLLHUP)) {
3179 		/* Datagram does not return 0 on consecutive recv, but an error
3180 		 * code, hence clear POLLIN.
3181 		 */
3182 		if (ctx->type == SOCK_DGRAM) {
3183 			pfd->revents &= ~ZSOCK_POLLIN;
3184 		}
3185 		pfd->revents |= ZSOCK_POLLHUP;
3186 		goto next;
3187 	} else if (ret < 0) {
3188 		ctx->error = -ret;
3189 		pfd->revents |= ZSOCK_POLLERR;
3190 		goto next;
3191 	} else if (ret == 0) {
3192 		goto again;
3193 	}
3194 
3195 next:
3196 	return 0;
3197 
3198 again:
3199 	/* Received encrypted data, but still not enough
3200 	 * to decrypt it and return data through socket,
3201 	 * ask for retry if no other events are set.
3202 	 */
3203 	pfd->revents &= ~ZSOCK_POLLIN;
3204 
3205 	return -EAGAIN;
3206 }
3207 
/* Poll update step for a TLS socket.
 *
 * Counterpart of the prepare step: handles the event that waits for the DTLS
 * client handshake (the one whose object is ctx->tls_established), lets the
 * underlying socket update pfd->revents through its ZFD_IOCTL_POLL_UPDATE
 * ioctl, and finally refines ZSOCK_POLLIN with the TLS level data check.
 *
 * Returns 0 when done, -EAGAIN when the caller should run another poll
 * iteration, -EBADF when the underlying socket cannot be resolved, or the
 * error reported by the underlying ioctl.
 */
static int ztls_poll_update_ctx(struct tls_context *ctx,
				struct zsock_pollfd *pfd,
				struct k_poll_event **pev)
{
	const struct fd_op_vtable *vtable;
	struct k_mutex *lock;
	void *obj;
	int ret;
	short events = pfd->events;

	obj = zvfs_get_fd_obj_and_vtable(
		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
	if (obj == NULL) {
		return -EBADF;
	}

	(void)k_mutex_lock(lock, K_FOREVER);

	/* Check if the socket was waiting for the handshake to complete. */
	if ((pfd->events & ZSOCK_POLLIN) &&
	    ((*pev)->obj == &ctx->tls_established)) {
		/* In case handshake is complete, reconfigure the k_poll_event
		 * to monitor the underlying socket now.
		 */
		if ((*pev)->state != K_POLL_STATE_NOT_READY) {
			/* "*pev + 1" as the end marker limits the underlying
			 * socket to this single event slot.
			 */
			ret = zvfs_fdtable_call_ioctl(vtable, obj,
						   ZFD_IOCTL_POLL_PREPARE,
						   pfd, pev, *pev + 1);
			if (ret != 0 && ret != -EALREADY) {
				goto out;
			}

			/* Return -EAGAIN to signal to poll() that it should
			 * make another iteration with the event reconfigured
			 * above (if needed).
			 */
			ret = -EAGAIN;
			goto out;
		}

		/* Handshake still not ready - skip ZSOCK_POLLIN verification
		 * for the underlying socket.
		 */
		(*pev)++;
		pfd->events &= ~ZSOCK_POLLIN;
	}

	ret = zvfs_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_UPDATE,
				   pfd, pev);
	if (ret != 0) {
		goto exit;
	}

	if (pfd->events & ZSOCK_POLLIN) {
		ret = ztls_poll_update_pollin(pfd->fd, ctx, pfd);
		if (ret == -EAGAIN && pfd->revents == 0) {
			/* Nothing to report yet: re-arm the previous event
			 * so that the next poll iteration waits on it again.
			 */
			(*pev - 1)->state = K_POLL_STATE_NOT_READY;
			goto exit;
		} else {
			ret = 0;
		}
	}
exit:
	/* Restore original events. */
	pfd->events = events;

	/* The "out" paths above jump here directly; pfd->events has not been
	 * modified yet at those points, so there is nothing to restore.
	 */
out:
	k_mutex_unlock(lock);

	return ret;
}
3279 
3280 /* Return true if needed to retry rightoff or false otherwise. */
poll_offload_dtls_client_retry(struct tls_context * ctx,struct zsock_pollfd * pfd)3281 static bool poll_offload_dtls_client_retry(struct tls_context *ctx,
3282 					   struct zsock_pollfd *pfd)
3283 {
3284 	/* DTLS client should wait for the handshake to complete before it
3285 	 * reports that data is ready.
3286 	 */
3287 	if ((ctx->type != SOCK_DGRAM) ||
3288 	    (ctx->options.role != MBEDTLS_SSL_IS_CLIENT)) {
3289 		return false;
3290 	}
3291 
3292 	if (ctx->handshake_in_progress) {
3293 		/* Add some sleep to allow lower priority threads to proceed
3294 		 * with handshake.
3295 		 */
3296 		k_msleep(10);
3297 
3298 		pfd->revents &= ~ZSOCK_POLLIN;
3299 		return true;
3300 	} else if (!is_handshake_complete(ctx)) {
3301 		uint8_t byte;
3302 		int ret;
3303 
3304 		/* Handshake didn't start yet - just drop the incoming data -
3305 		 * it's the client who should initiate the handshake.
3306 		 */
3307 		ret = zsock_recv(ctx->sock, &byte, sizeof(byte),
3308 				 ZSOCK_MSG_DONTWAIT);
3309 		if (ret < 0) {
3310 			pfd->revents |= ZSOCK_POLLERR;
3311 		}
3312 
3313 		pfd->revents &= ~ZSOCK_POLLIN;
3314 		return true;
3315 	}
3316 
3317 	/* Handshake complete, just proceed. */
3318 	return false;
3319 }
3320 
ztls_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)3321 static int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int timeout)
3322 {
3323 	int fd_backup[CONFIG_ZVFS_POLL_MAX];
3324 	const struct fd_op_vtable *vtable;
3325 	void *ctx;
3326 	int ret = 0;
3327 	int result;
3328 	int i;
3329 	bool retry;
3330 	int remaining;
3331 	uint32_t entry = k_uptime_get_32();
3332 
3333 	/* Overwrite TLS file descriptors with underlying ones. */
3334 	for (i = 0; i < nfds; i++) {
3335 		fd_backup[i] = fds[i].fd;
3336 
3337 		ctx = zvfs_get_fd_obj(fds[i].fd,
3338 				   (const struct fd_op_vtable *)
3339 						     &tls_sock_fd_op_vtable,
3340 				   0);
3341 		if (ctx == NULL) {
3342 			continue;
3343 		}
3344 
3345 		if (fds[i].events & ZSOCK_POLLIN) {
3346 			ret = ztls_poll_prepare_pollin(ctx);
3347 			/* In case data is already available in mbedtls,
3348 			 * do not wait in poll.
3349 			 */
3350 			if (ret == -EALREADY) {
3351 				timeout = 0;
3352 			}
3353 		}
3354 
3355 		fds[i].fd = ((struct tls_context *)ctx)->sock;
3356 	}
3357 
3358 	/* Get offloaded sockets vtable. */
3359 	ctx = zvfs_get_fd_obj_and_vtable(fds[0].fd,
3360 				      (const struct fd_op_vtable **)&vtable,
3361 				      NULL);
3362 	if (ctx == NULL) {
3363 		errno = EINVAL;
3364 		goto exit;
3365 	}
3366 
3367 	remaining = timeout;
3368 
3369 	do {
3370 		for (i = 0; i < nfds; i++) {
3371 			fds[i].revents = 0;
3372 		}
3373 
3374 		ret = zvfs_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
3375 					   fds, nfds, remaining);
3376 		if (ret < 0) {
3377 			goto exit;
3378 		}
3379 
3380 		retry = false;
3381 		ret = 0;
3382 
3383 		for (i = 0; i < nfds; i++) {
3384 			ctx = zvfs_get_fd_obj(fd_backup[i],
3385 					   (const struct fd_op_vtable *)
3386 							&tls_sock_fd_op_vtable,
3387 					   0);
3388 			if (ctx != NULL) {
3389 				if (fds[i].events & ZSOCK_POLLIN) {
3390 					if (poll_offload_dtls_client_retry(
3391 							ctx, &fds[i])) {
3392 						retry = true;
3393 						continue;
3394 					}
3395 
3396 					result = ztls_poll_update_pollin(
3397 						    fd_backup[i], ctx, &fds[i]);
3398 					if (result == -EAGAIN) {
3399 						retry = true;
3400 					}
3401 				}
3402 			}
3403 
3404 			if (fds[i].revents != 0) {
3405 				ret++;
3406 			}
3407 		}
3408 
3409 		if (retry) {
3410 			if (ret > 0 || timeout == 0) {
3411 				goto exit;
3412 			}
3413 
3414 			if (timeout > 0) {
3415 				remaining = time_left(entry, timeout);
3416 				if (remaining <= 0) {
3417 					goto exit;
3418 				}
3419 			}
3420 		}
3421 	} while (retry);
3422 
3423 exit:
3424 	/* Restore original fds. */
3425 	for (i = 0; i < nfds; i++) {
3426 		fds[i].fd = fd_backup[i];
3427 	}
3428 
3429 	return ret;
3430 }
3431 
ztls_getsockopt_ctx(struct tls_context * ctx,int level,int optname,void * optval,socklen_t * optlen)3432 int ztls_getsockopt_ctx(struct tls_context *ctx, int level, int optname,
3433 			void *optval, socklen_t *optlen)
3434 {
3435 	int err;
3436 
3437 	if (!optval || !optlen) {
3438 		errno = EINVAL;
3439 		return -1;
3440 	}
3441 
3442 	if ((level == SOL_SOCKET) && (optname == SO_PROTOCOL)) {
3443 		/* Protocol type is overridden during socket creation. Its
3444 		 * value is restored here to return current value.
3445 		 */
3446 		err = sock_opt_protocol_get(ctx, optval, optlen);
3447 		if (err < 0) {
3448 			errno = -err;
3449 			return -1;
3450 		}
3451 		return err;
3452 	}
3453 
3454 	/* In case error was set on a socket at the TLS layer (for example due
3455 	 * to receiving TLS alert), handle SO_ERROR here, and report that error.
3456 	 * Otherwise, forward the SO_ERROR option request to the underlying
3457 	 * TCP/UDP socket to handle.
3458 	 */
3459 	if ((level == SOL_SOCKET) && (optname == SO_ERROR) && ctx->error != 0) {
3460 		if (*optlen != sizeof(int)) {
3461 			errno = EINVAL;
3462 			return -1;
3463 		}
3464 
3465 		*(int *)optval = ctx->error;
3466 
3467 		return 0;
3468 	}
3469 
3470 	if (level != SOL_TLS) {
3471 		return zsock_getsockopt(ctx->sock, level, optname,
3472 					optval, optlen);
3473 	}
3474 
3475 	switch (optname) {
3476 	case TLS_SEC_TAG_LIST:
3477 		err =  tls_opt_sec_tag_list_get(ctx, optval, optlen);
3478 		break;
3479 
3480 	case TLS_CIPHERSUITE_LIST:
3481 		err = tls_opt_ciphersuite_list_get(ctx, optval, optlen);
3482 		break;
3483 
3484 	case TLS_CIPHERSUITE_USED:
3485 		err = tls_opt_ciphersuite_used_get(ctx, optval, optlen);
3486 		break;
3487 
3488 	case TLS_ALPN_LIST:
3489 		err = tls_opt_alpn_list_get(ctx, optval, optlen);
3490 		break;
3491 
3492 	case TLS_SESSION_CACHE:
3493 		err = tls_opt_session_cache_get(ctx, optval, optlen);
3494 		break;
3495 
3496 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3497 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3498 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3499 							 optlen, false);
3500 		break;
3501 
3502 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3503 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3504 							 optlen, true);
3505 		break;
3506 
3507 	case TLS_DTLS_CID_STATUS:
3508 		err = tls_opt_dtls_connection_id_status_get(ctx, optval,
3509 							    optlen);
3510 		break;
3511 
3512 	case TLS_DTLS_CID_VALUE:
3513 		err = tls_opt_dtls_connection_id_value_get(ctx, optval, optlen);
3514 		break;
3515 
3516 	case TLS_DTLS_PEER_CID_VALUE:
3517 		err = tls_opt_dtls_peer_connection_id_value_get(ctx, optval,
3518 								optlen);
3519 		break;
3520 
3521 	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
3522 		err = tls_opt_dtls_handshake_on_connect_get(ctx, optval, optlen);
3523 		break;
3524 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3525 
3526 	default:
3527 		/* Unknown or write-only option. */
3528 		err = -ENOPROTOOPT;
3529 		break;
3530 	}
3531 
3532 	if (err < 0) {
3533 		errno = -err;
3534 		return -1;
3535 	}
3536 
3537 	return 0;
3538 }
3539 
/* Convert a struct zsock_timeval socket option value into a k_timeout_t.
 *
 * A value with both tv_sec and tv_usec equal to zero is stored as K_FOREVER,
 * anything else is converted to a microsecond based timeout.
 *
 * Returns 0 on success, -EINVAL if optlen does not match the size of
 * struct zsock_timeval.
 */
static int set_timeout_opt(k_timeout_t *timeout, const void *optval,
			   socklen_t optlen)
{
	const struct zsock_timeval *tv;

	if (optlen != sizeof(*tv)) {
		return -EINVAL;
	}

	tv = optval;

	if ((tv->tv_sec != 0) || (tv->tv_usec != 0)) {
		*timeout = K_USEC(tv->tv_sec * 1000000ULL + tv->tv_usec);
	} else {
		*timeout = K_FOREVER;
	}

	return 0;
}
3557 
ztls_setsockopt_ctx(struct tls_context * ctx,int level,int optname,const void * optval,socklen_t optlen)3558 int ztls_setsockopt_ctx(struct tls_context *ctx, int level, int optname,
3559 			const void *optval, socklen_t optlen)
3560 {
3561 	int err;
3562 
3563 	/* Underlying socket is used in non-blocking mode, hence implement
3564 	 * timeout at the TLS socket level.
3565 	 */
3566 	if ((level == SOL_SOCKET) && (optname == SO_SNDTIMEO)) {
3567 		err = set_timeout_opt(&ctx->options.timeout_tx, optval, optlen);
3568 		goto out;
3569 	}
3570 
3571 	if ((level == SOL_SOCKET) && (optname == SO_RCVTIMEO)) {
3572 		err = set_timeout_opt(&ctx->options.timeout_rx, optval, optlen);
3573 		goto out;
3574 	}
3575 
3576 	if (level != SOL_TLS) {
3577 		return zsock_setsockopt(ctx->sock, level, optname,
3578 					optval, optlen);
3579 	}
3580 
3581 	switch (optname) {
3582 	case TLS_SEC_TAG_LIST:
3583 		err =  tls_opt_sec_tag_list_set(ctx, optval, optlen);
3584 		break;
3585 
3586 	case TLS_HOSTNAME:
3587 		err = tls_opt_hostname_set(ctx, optval, optlen);
3588 		break;
3589 
3590 	case TLS_CIPHERSUITE_LIST:
3591 		err = tls_opt_ciphersuite_list_set(ctx, optval, optlen);
3592 		break;
3593 
3594 	case TLS_PEER_VERIFY:
3595 		err = tls_opt_peer_verify_set(ctx, optval, optlen);
3596 		break;
3597 
3598 	case TLS_CERT_NOCOPY:
3599 		err = tls_opt_cert_nocopy_set(ctx, optval, optlen);
3600 		break;
3601 
3602 	case TLS_DTLS_ROLE:
3603 		err = tls_opt_dtls_role_set(ctx, optval, optlen);
3604 		break;
3605 
3606 	case TLS_ALPN_LIST:
3607 		err = tls_opt_alpn_list_set(ctx, optval, optlen);
3608 		break;
3609 
3610 	case TLS_SESSION_CACHE:
3611 		err = tls_opt_session_cache_set(ctx, optval, optlen);
3612 		break;
3613 
3614 	case TLS_SESSION_CACHE_PURGE:
3615 		err = tls_opt_session_cache_purge_set(ctx, optval, optlen);
3616 		break;
3617 
3618 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3619 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3620 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3621 							 optlen, false);
3622 		break;
3623 
3624 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3625 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3626 							 optlen, true);
3627 		break;
3628 
3629 	case TLS_DTLS_CID:
3630 		err = tls_opt_dtls_connection_id_set(ctx, optval, optlen);
3631 		break;
3632 
3633 	case TLS_DTLS_CID_VALUE:
3634 		err = tls_opt_dtls_connection_id_value_set(ctx, optval, optlen);
3635 		break;
3636 
3637 	case TLS_DTLS_HANDSHAKE_ON_CONNECT:
3638 		err = tls_opt_dtls_handshake_on_connect_set(ctx, optval, optlen);
3639 		break;
3640 
3641 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3642 
3643 	case TLS_NATIVE:
3644 		/* Option handled at the socket dispatcher level. */
3645 		err = 0;
3646 		break;
3647 
3648 	default:
3649 		/* Unknown or read-only option. */
3650 		err = -ENOPROTOOPT;
3651 		break;
3652 	}
3653 
3654 out:
3655 	if (err < 0) {
3656 		errno = -err;
3657 		return -1;
3658 	}
3659 
3660 	return 0;
3661 }
3662 
#if defined(CONFIG_NET_TEST)
/* Test-only accessor: look up the TLS socket object behind fd and return
 * a pointer to its mbedtls SSL context, or NULL if the lookup fails.
 */
mbedtls_ssl_context *ztls_get_mbedtls_ssl_context(int fd)
{
	struct tls_context *tls = zvfs_get_fd_obj(
		fd, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable, EBADF);

	return (tls != NULL) ? &tls->ssl : NULL;
}
#endif /* CONFIG_NET_TEST */
3677 
/* read() entry of the fd vtable: receive on the TLS socket with no flags and
 * without reporting the source address.
 */
static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count)
{
	/* Both the address and its length are pointer arguments, so pass
	 * NULL for each rather than an integer zero.
	 */
	return ztls_recvfrom_ctx(obj, buffer, count, 0, NULL, NULL);
}
3682 
/* write() entry of the fd vtable: send on the TLS socket with no flags and
 * no destination address.
 */
static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer,
				    size_t count)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buffer, count, 0, NULL, 0);
}
3688 
tls_sock_ioctl_vmeth(void * obj,unsigned int request,va_list args)3689 static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
3690 {
3691 	struct tls_context *ctx = obj;
3692 
3693 	switch (request) {
3694 	/* fcntl() commands */
3695 	case F_GETFL:
3696 	case F_SETFL: {
3697 		const struct fd_op_vtable *vtable;
3698 		struct k_mutex *lock;
3699 		void *fd_obj;
3700 		int ret;
3701 
3702 		fd_obj = zvfs_get_fd_obj_and_vtable(ctx->sock,
3703 				(const struct fd_op_vtable **)&vtable, &lock);
3704 		if (fd_obj == NULL) {
3705 			errno = EBADF;
3706 			return -1;
3707 		}
3708 
3709 		(void)k_mutex_lock(lock, K_FOREVER);
3710 
3711 		/* Pass the call to the core socket implementation. */
3712 		ret = vtable->ioctl(fd_obj, request, args);
3713 
3714 		k_mutex_unlock(lock);
3715 
3716 		return ret;
3717 	}
3718 
3719 	case ZFD_IOCTL_SET_LOCK: {
3720 		struct k_mutex *lock;
3721 
3722 		lock = va_arg(args, struct k_mutex *);
3723 
3724 		ctx_set_lock(obj, lock);
3725 
3726 		return 0;
3727 	}
3728 
3729 	case ZFD_IOCTL_POLL_PREPARE: {
3730 		struct zsock_pollfd *pfd;
3731 		struct k_poll_event **pev;
3732 		struct k_poll_event *pev_end;
3733 
3734 		pfd = va_arg(args, struct zsock_pollfd *);
3735 		pev = va_arg(args, struct k_poll_event **);
3736 		pev_end = va_arg(args, struct k_poll_event *);
3737 
3738 		return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end);
3739 	}
3740 
3741 	case ZFD_IOCTL_POLL_UPDATE: {
3742 		struct zsock_pollfd *pfd;
3743 		struct k_poll_event **pev;
3744 
3745 		pfd = va_arg(args, struct zsock_pollfd *);
3746 		pev = va_arg(args, struct k_poll_event **);
3747 
3748 		return ztls_poll_update_ctx(obj, pfd, pev);
3749 	}
3750 
3751 	case ZFD_IOCTL_POLL_OFFLOAD: {
3752 		struct zsock_pollfd *fds;
3753 		int nfds;
3754 		int timeout;
3755 
3756 		fds = va_arg(args, struct zsock_pollfd *);
3757 		nfds = va_arg(args, int);
3758 		timeout = va_arg(args, int);
3759 
3760 		return ztls_poll_offload(fds, nfds, timeout);
3761 	}
3762 
3763 	default:
3764 		errno = EOPNOTSUPP;
3765 		return -1;
3766 	}
3767 }
3768 
tls_sock_shutdown_vmeth(void * obj,int how)3769 static int tls_sock_shutdown_vmeth(void *obj, int how)
3770 {
3771 	struct tls_context *ctx = obj;
3772 
3773 	return zsock_shutdown(ctx->sock, how);
3774 }
3775 
tls_sock_bind_vmeth(void * obj,const struct sockaddr * addr,socklen_t addrlen)3776 static int tls_sock_bind_vmeth(void *obj, const struct sockaddr *addr,
3777 			       socklen_t addrlen)
3778 {
3779 	struct tls_context *ctx = obj;
3780 
3781 	return zsock_bind(ctx->sock, addr, addrlen);
3782 }
3783 
/* connect() entry: handled by the TLS level connect implementation. */
static int tls_sock_connect_vmeth(void *obj, const struct sockaddr *addr,
				  socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_connect_ctx(tls, addr, addrlen);
}
3789 
tls_sock_listen_vmeth(void * obj,int backlog)3790 static int tls_sock_listen_vmeth(void *obj, int backlog)
3791 {
3792 	struct tls_context *ctx = obj;
3793 
3794 	ctx->is_listening = true;
3795 
3796 	return zsock_listen(ctx->sock, backlog);
3797 }
3798 
/* accept() entry: handled by the TLS level accept implementation. */
static int tls_sock_accept_vmeth(void *obj, struct sockaddr *addr,
				 socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_accept_ctx(tls, addr, addrlen);
}
3804 
/* sendto() entry: handled by the TLS level send implementation. */
static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len,
				     int flags,
				     const struct sockaddr *dest_addr,
				     socklen_t addrlen)
{
	struct tls_context *tls = obj;

	return ztls_sendto_ctx(tls, buf, len, flags, dest_addr, addrlen);
}
3812 
/* sendmsg() entry: handled by the TLS level sendmsg implementation. */
static ssize_t tls_sock_sendmsg_vmeth(void *obj, const struct msghdr *msg,
				      int flags)
{
	struct tls_context *tls = obj;

	return ztls_sendmsg_ctx(tls, msg, flags);
}
3818 
/* recvfrom() entry: handled by the TLS level receive implementation. */
static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len,
				       int flags, struct sockaddr *src_addr,
				       socklen_t *addrlen)
{
	struct tls_context *tls = obj;

	return ztls_recvfrom_ctx(tls, buf, max_len, flags, src_addr, addrlen);
}
3826 
/* getsockopt() entry: handled by ztls_getsockopt_ctx(). */
static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname,
				     void *optval, socklen_t *optlen)
{
	struct tls_context *tls = obj;

	return ztls_getsockopt_ctx(tls, level, optname, optval, optlen);
}
3832 
/* setsockopt() entry: handled by ztls_setsockopt_ctx(). */
static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname,
				     const void *optval, socklen_t optlen)
{
	struct tls_context *tls = obj;

	return ztls_setsockopt_ctx(tls, level, optname, optval, optlen);
}
3838 
/* close2() entry: handled by the TLS level close implementation. */
static int tls_sock_close2_vmeth(void *obj, int sock)
{
	struct tls_context *tls = obj;

	return ztls_close_ctx(tls, sock);
}
3843 
tls_sock_getpeername_vmeth(void * obj,struct sockaddr * addr,socklen_t * addrlen)3844 static int tls_sock_getpeername_vmeth(void *obj, struct sockaddr *addr,
3845 				      socklen_t *addrlen)
3846 {
3847 	struct tls_context *ctx = obj;
3848 
3849 	return zsock_getpeername(ctx->sock, addr, addrlen);
3850 }
3851 
tls_sock_getsockname_vmeth(void * obj,struct sockaddr * addr,socklen_t * addrlen)3852 static int tls_sock_getsockname_vmeth(void *obj, struct sockaddr *addr,
3853 				      socklen_t *addrlen)
3854 {
3855 	struct tls_context *ctx = obj;
3856 
3857 	return zsock_getsockname(ctx->sock, addr, addrlen);
3858 }
3859 
3860 static const struct socket_op_vtable tls_sock_fd_op_vtable = {
3861 	.fd_vtable = {
3862 		.read = tls_sock_read_vmeth,
3863 		.write = tls_sock_write_vmeth,
3864 		.close2 = tls_sock_close2_vmeth,
3865 		.ioctl = tls_sock_ioctl_vmeth,
3866 	},
3867 	.shutdown = tls_sock_shutdown_vmeth,
3868 	.bind = tls_sock_bind_vmeth,
3869 	.connect = tls_sock_connect_vmeth,
3870 	.listen = tls_sock_listen_vmeth,
3871 	.accept = tls_sock_accept_vmeth,
3872 	.sendto = tls_sock_sendto_vmeth,
3873 	.sendmsg = tls_sock_sendmsg_vmeth,
3874 	.recvfrom = tls_sock_recvfrom_vmeth,
3875 	.getsockopt = tls_sock_getsockopt_vmeth,
3876 	.setsockopt = tls_sock_setsockopt_vmeth,
3877 	.getpeername = tls_sock_getpeername_vmeth,
3878 	.getsockname = tls_sock_getsockname_vmeth,
3879 };
3880 
/* Report whether the family/type/proto combination is accepted by
 * protocol_check(), i.e. whether it returns 0 for it.
 */
static bool tls_is_supported(int family, int type, int proto)
{
	return protocol_check(family, type, &proto) == 0;
}
3889 
/* TLS sockets and regular sockets fall under the same address families, so
 * the TLS implementation has to be processed first in order to capture the
 * socket calls that create sockets for secure protocols. Every other call for
 * AF_INET/AF_INET6 will be forwarded to the regular socket implementation.
 */
BUILD_ASSERT(CONFIG_NET_SOCKETS_TLS_PRIORITY < CONFIG_NET_SOCKETS_PRIORITY_DEFAULT,
	     "CONFIG_NET_SOCKETS_TLS_PRIORITY have to be smaller than CONFIG_NET_SOCKETS_PRIORITY_DEFAULT");

NET_SOCKET_REGISTER(tls, CONFIG_NET_SOCKETS_TLS_PRIORITY, AF_UNSPEC,
		    tls_is_supported, ztls_socket);
3900