1 /*
2  * Copyright (c) 2018 Intel Corporation
3  * Copyright (c) 2018 Nordic Semiconductor ASA
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <stdbool.h>
9 #ifdef CONFIG_ARCH_POSIX
10 #include <fcntl.h>
11 #else
12 #include <zephyr/posix/fcntl.h>
13 #endif
14 
15 #include <zephyr/logging/log.h>
16 LOG_MODULE_REGISTER(net_sock_tls, CONFIG_NET_SOCKETS_LOG_LEVEL);
17 
18 #include <zephyr/init.h>
19 #include <zephyr/sys/util.h>
20 #include <zephyr/net/socket.h>
21 #include <zephyr/random/rand32.h>
22 #include <zephyr/syscall_handler.h>
23 #include <zephyr/sys/fdtable.h>
24 
/* TODO: Remove all direct access to private fields.
 * According to the Mbed TLS migration guide:
 *
 * Direct access to fields of structures
 * (`struct` types) declared in public headers is no longer
 * supported. In Mbed TLS 3, the layout of structures is not
 * considered part of the stable API, and minor versions (3.1, 3.2,
 * etc.) may add, remove, rename, reorder or change the type of
 * structure fields.
 *
 * Until then, opt in to private field access. This must be defined before
 * any mbedTLS header is included below.
 */
#if !defined(MBEDTLS_ALLOW_PRIVATE_ACCESS)
#define MBEDTLS_ALLOW_PRIVATE_ACCESS
#endif
38 
39 #if defined(CONFIG_MBEDTLS)
40 #if !defined(CONFIG_MBEDTLS_CFG_FILE)
41 #include "mbedtls/config.h"
42 #else
43 #include CONFIG_MBEDTLS_CFG_FILE
44 #endif /* CONFIG_MBEDTLS_CFG_FILE */
45 
46 #include <mbedtls/net_sockets.h>
47 #include <mbedtls/x509.h>
48 #include <mbedtls/x509_crt.h>
49 #include <mbedtls/ssl.h>
50 #include <mbedtls/ssl_cookie.h>
51 #include <mbedtls/error.h>
52 #include <mbedtls/platform.h>
53 #include <mbedtls/ssl_cache.h>
54 #endif /* CONFIG_MBEDTLS */
55 
56 #include "sockets_internal.h"
57 #include "tls_internal.h"
58 
59 #if defined(CONFIG_MBEDTLS_DEBUG)
60 #include <zephyr_mbedtls_priv.h>
61 #endif
62 
/* Capacity of tls_context.options.alpn_list: the configured number of
 * application protocols plus one slot for the terminating NULL entry.
 * Evaluates to 0 when the ALPN protocol count is not configured.
 */
#if defined(CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS)
#define ALPN_MAX_PROTOCOLS (CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS + 1)
#else
#define ALPN_MAX_PROTOCOLS 0
#endif /* CONFIG_NET_SOCKETS_TLS_MAX_APP_PROTOCOLS */

/* Tentative definition of the TLS socket operation table, so it can be
 * referenced before its initialized definition.
 * NOTE(review): the initialized definition is presumably further down in
 * this file — confirm.
 */
static const struct socket_op_vtable tls_sock_fd_op_vtable;

/* Fallback for mbedTLS versions that do not provide this error code.
 * dtls_rx() returns it when a DTLS client receives data before any peer
 * address has been set up.
 */
#ifndef MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED
#define MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE
#endif
74 
/** A list of secure tags that TLS context should use. */
struct sec_tag_list {
	/** An array of secure tags referencing TLS credentials. Only the
	 *  first sec_tag_count entries are valid.
	 */
	sec_tag_t sec_tags[CONFIG_NET_SOCKETS_TLS_MAX_CREDENTIALS];

	/** Number of configured secure tags (valid entries in sec_tags). */
	int sec_tag_count;
};
83 
/** Timer context for DTLS, driven by the mbedTLS set/get timer callbacks. */
struct dtls_timing_context {
	/** Uptime in milliseconds (k_uptime_get_32()) captured when the timer
	 *  is armed with a non-zero final delay.
	 */
	uint32_t snapshot;

	/** Intermediate delay value. For details, refer to mbedTLS API
	 *  documentation (mbedtls_ssl_set_timer_t).
	 */
	uint32_t int_ms;

	/** Final delay value; 0 means the timer is cancelled. For details,
	 *  refer to mbedTLS API documentation (mbedtls_ssl_set_timer_t).
	 */
	uint32_t fin_ms;
};
99 
/** TLS peer address/session ID mapping (one client session cache slot). */
struct tls_session_cache {
	/** Creation time (k_uptime_get()), used to pick an entry to evict. */
	int64_t timestamp;

	/** Peer address the cached session belongs to. */
	struct sockaddr peer_addr;

	/** Serialized session buffer allocated with mbedtls_calloc();
	 *  NULL marks a free slot.
	 */
	uint8_t *session;

	/** Session length in bytes (valid size of the session buffer). */
	size_t session_len;
};
114 
/** TLS context information.
 *
 * One entry of the tls_contexts[] pool. It wraps an underlying TCP/UDP socket
 * together with the mbedTLS state and the socket-level TLS options.
 */
__net_socket struct tls_context {
	/** Information whether TLS context is used (pool slot claimed). */
	bool is_used;

	/** Underlying TCP/UDP socket; tls_alloc() sets it to -1. */
	int sock;

	/** Socket type. */
	enum net_sock_type type;

	/** Secure protocol version running on TLS context. */
	enum net_ip_protocol_secure tls_version;

	/** Socket flags passed to a socket call. */
	int flags;

	/** Information whether TLS context was initialized. */
	bool is_initialized;

	/** Information whether underlying socket is listening. */
	bool is_listening;

	/** Information whether TLS handshake is currently in progress. */
	bool handshake_in_progress;

	/** Information whether TLS handshake is complete or not: a non-zero
	 *  semaphore count means complete (see is_handshake_complete()).
	 */
	struct k_sem tls_established;

	/** TLS socket mutex lock, set through ctx_set_lock(). */
	struct k_mutex *lock;

	/** TLS specific option values. */
	struct {
		/** Select which credentials to use with TLS. */
		struct sec_tag_list sec_tag_list;

		/** 0-terminated list of allowed ciphersuites (mbedTLS format).
		 *  The extra element leaves room for the terminator.
		 */
		int ciphersuites[CONFIG_NET_SOCKETS_TLS_MAX_CIPHERSUITES + 1];

		/** Information if hostname was explicitly set on a socket. */
		bool is_hostname_set;

		/** Peer verification level. tls_alloc() initializes it to -1.
		 *  NOTE(review): -1 presumably means "not set explicitly" — confirm.
		 */
		int8_t verify_level;

		/** Indicating on whether DER certificates should not be copied
		 * to the heap (compared against TLS_CERT_NOCOPY_NONE).
		 */
		int8_t cert_nocopy;

		/** DTLS role, client by default. */
		int8_t role;

		/** NULL-terminated list of allowed application layer
		 * protocols (sized ALPN_MAX_PROTOCOLS, including the terminator).
		 */
		const char *alpn_list[ALPN_MAX_PROTOCOLS];

		/** Session cache enabled on a socket. */
		bool cache_enabled;

		/** Socket TX timeout (K_FOREVER after tls_alloc()). */
		k_timeout_t timeout_tx;

		/** Socket RX timeout (K_FOREVER after tls_alloc()). */
		k_timeout_t timeout_rx;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/** DTLS handshake timeout bounds; tls_alloc() initializes them
		 *  to MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN/MAX.
		 */
		uint32_t dtls_handshake_timeout_min;
		uint32_t dtls_handshake_timeout_max;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	} options;

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
	/** Context information for DTLS timing. */
	struct dtls_timing_context dtls_timing;

	/** mbedTLS cookie context for DTLS */
	mbedtls_ssl_cookie_ctx cookie;

	/** DTLS peer address: destination for dtls_tx() and the only source
	 *  accepted by dtls_rx() once set.
	 */
	struct sockaddr dtls_peer_addr;

	/** DTLS peer address length; 0 means no peer recorded yet. */
	socklen_t dtls_peer_addrlen;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

#if defined(CONFIG_MBEDTLS)
	/** mbedTLS context. */
	mbedtls_ssl_context ssl;

	/** mbedTLS configuration. */
	mbedtls_ssl_config config;

#if defined(MBEDTLS_X509_CRT_PARSE_C)
	/** mbedTLS structure for CA chain. */
	mbedtls_x509_crt ca_chain;

	/** mbedTLS structure for own certificate. */
	mbedtls_x509_crt own_cert;

	/** mbedTLS structure for own private key. */
	mbedtls_pk_context priv_key;
#endif /* MBEDTLS_X509_CRT_PARSE_C */

#endif /* CONFIG_MBEDTLS */
};
225 
226 
/* A global pool of TLS contexts. tls_alloc() claims slots under
 * context_lock; tls_release() hands them back.
 */
static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS];

/* Client-side session cache: one serialized session per peer address.
 * NOTE(review): the accessors shown here take no lock — confirm that
 * callers serialize access.
 */
static struct tls_session_cache client_cache[CONFIG_NET_SOCKETS_TLS_MAX_CLIENT_SESSION_COUNT];

#if defined(MBEDTLS_SSL_CACHE_C)
/* Server-side session cache managed by mbedTLS. */
static mbedtls_ssl_cache_context server_cache;
#endif

/* A mutex for protecting TLS context allocation. */
static struct k_mutex context_lock;

/* Arbitrary delay value to wait if mbedTLS reports it cannot proceed for
 * reasons other than TX/RX block (used by wait_for_reason()).
 */
#define TLS_WAIT_MS 100
243 
tls_session_cache_reset(void)244 static void tls_session_cache_reset(void)
245 {
246 	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
247 		if (client_cache[i].session != NULL) {
248 			mbedtls_free(client_cache[i].session);
249 		}
250 	}
251 
252 	(void)memset(client_cache, 0, sizeof(client_cache));
253 }
254 
net_socket_is_tls(void * obj)255 bool net_socket_is_tls(void *obj)
256 {
257 	return PART_OF_ARRAY(tls_contexts, (struct tls_context *)obj);
258 }
259 
/* Random number callback for mbedTLS (f_rng signature).
 *
 * Uses the CSPRNG when an entropy driver is available and forwards its
 * result. Otherwise it falls back to the non-cryptographic generator and
 * always reports success. @p ctx is unused.
 */
static int tls_ctr_drbg_random(void *ctx, unsigned char *buf, size_t len)
{
	int result = 0;

	ARG_UNUSED(ctx);

#if defined(CONFIG_ENTROPY_HAS_DRIVER)
	result = sys_csrand_get(buf, len);
#else
	sys_rand_get(buf, len);
#endif

	return result;
}
272 
273 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
274 /* mbedTLS-defined function for setting timer. */
dtls_timing_set_delay(void * data,uint32_t int_ms,uint32_t fin_ms)275 static void dtls_timing_set_delay(void *data, uint32_t int_ms, uint32_t fin_ms)
276 {
277 	struct dtls_timing_context *ctx = data;
278 
279 	ctx->int_ms = int_ms;
280 	ctx->fin_ms = fin_ms;
281 
282 	if (fin_ms != 0U) {
283 		ctx->snapshot = k_uptime_get_32();
284 	}
285 }
286 
287 /* mbedTLS-defined function for getting timer status.
288  * The return values are specified by mbedTLS. The callback must return:
289  *   -1 if cancelled (fin_ms == 0),
290  *    0 if none of the delays have passed,
291  *    1 if only the intermediate delay has passed,
292  *    2 if the final delay has passed.
293  */
dtls_timing_get_delay(void * data)294 static int dtls_timing_get_delay(void *data)
295 {
296 	struct dtls_timing_context *timing = data;
297 	unsigned long elapsed_ms;
298 
299 	NET_ASSERT(timing);
300 
301 	if (timing->fin_ms == 0U) {
302 		return -1;
303 	}
304 
305 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
306 
307 	if (elapsed_ms >= timing->fin_ms) {
308 		return 2;
309 	}
310 
311 	if (elapsed_ms >= timing->int_ms) {
312 		return 1;
313 	}
314 
315 	return 0;
316 }
317 
dtls_get_remaining_timeout(struct tls_context * ctx)318 static int dtls_get_remaining_timeout(struct tls_context *ctx)
319 {
320 	struct dtls_timing_context *timing = &ctx->dtls_timing;
321 	uint32_t elapsed_ms;
322 
323 	elapsed_ms = k_uptime_get_32() - timing->snapshot;
324 
325 	if (timing->fin_ms == 0U) {
326 		return SYS_FOREVER_MS;
327 	}
328 
329 	if (elapsed_ms >= timing->fin_ms) {
330 		return 0;
331 	}
332 
333 	return timing->fin_ms - elapsed_ms;
334 }
335 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
336 
337 /* Initialize TLS internals. */
tls_init(void)338 static int tls_init(void)
339 {
340 
341 #if !defined(CONFIG_ENTROPY_HAS_DRIVER)
342 	NET_WARN("No entropy device on the system, "
343 		 "TLS communication is insecure!");
344 #endif
345 
346 	(void)memset(tls_contexts, 0, sizeof(tls_contexts));
347 	(void)memset(client_cache, 0, sizeof(client_cache));
348 
349 	k_mutex_init(&context_lock);
350 
351 #if defined(MBEDTLS_SSL_CACHE_C)
352 	mbedtls_ssl_cache_init(&server_cache);
353 #endif
354 
355 	return 0;
356 }
357 
358 SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT);
359 
is_handshake_complete(struct tls_context * ctx)360 static inline bool is_handshake_complete(struct tls_context *ctx)
361 {
362 	return k_sem_count_get(&ctx->tls_established) != 0;
363 }
364 
365 /*
366  * Copied from include/mbedtls/ssl_internal.h
367  *
368  * Maximum length we can advertise as our max content length for
369  * RFC 6066 max_fragment_length extension negotiation purposes
370  * (the lesser of both sizes, if they are unequal.)
371  */
372 #define MBEDTLS_TLS_EXT_ADV_CONTENT_LEN (                            \
373 	(MBEDTLS_SSL_IN_CONTENT_LEN > MBEDTLS_SSL_OUT_CONTENT_LEN)   \
374 	? (MBEDTLS_SSL_OUT_CONTENT_LEN)				     \
375 	: (MBEDTLS_SSL_IN_CONTENT_LEN)				     \
376 	)
377 
378 #if defined(CONFIG_NET_SOCKETS_TLS_SET_MAX_FRAGMENT_LENGTH) &&	\
379 	defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) &&		\
380 	(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN < 16384)
381 
382 BUILD_ASSERT(MBEDTLS_TLS_EXT_ADV_CONTENT_LEN >= 512,
383 	     "Too small content length!");
384 
tls_mfl_code_from_content_len(void)385 static inline unsigned char tls_mfl_code_from_content_len(void)
386 {
387 	size_t len = MBEDTLS_TLS_EXT_ADV_CONTENT_LEN;
388 
389 	if (len >= 4096) {
390 		return MBEDTLS_SSL_MAX_FRAG_LEN_4096;
391 	} else if (len >= 2048) {
392 		return MBEDTLS_SSL_MAX_FRAG_LEN_2048;
393 	} else if (len >= 1024) {
394 		return MBEDTLS_SSL_MAX_FRAG_LEN_1024;
395 	} else if (len >= 512) {
396 		return MBEDTLS_SSL_MAX_FRAG_LEN_512;
397 	} else {
398 		return MBEDTLS_SSL_MAX_FRAG_LEN_INVALID;
399 	}
400 }
401 
tls_set_max_frag_len(mbedtls_ssl_config * config)402 static inline void tls_set_max_frag_len(mbedtls_ssl_config *config)
403 {
404 	unsigned char mfl_code = tls_mfl_code_from_content_len();
405 
406 	mbedtls_ssl_conf_max_frag_len(config, mfl_code);
407 }
408 #else
tls_set_max_frag_len(mbedtls_ssl_config * config)409 static inline void tls_set_max_frag_len(mbedtls_ssl_config *config) {}
410 #endif
411 
412 /* Allocate TLS context. */
tls_alloc(void)413 static struct tls_context *tls_alloc(void)
414 {
415 	int i;
416 	struct tls_context *tls = NULL;
417 
418 	k_mutex_lock(&context_lock, K_FOREVER);
419 
420 	for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) {
421 		if (!tls_contexts[i].is_used) {
422 			tls = &tls_contexts[i];
423 			(void)memset(tls, 0, sizeof(*tls));
424 			tls->is_used = true;
425 			tls->options.verify_level = -1;
426 			tls->options.timeout_tx = K_FOREVER;
427 			tls->options.timeout_rx = K_FOREVER;
428 			tls->sock = -1;
429 
430 			NET_DBG("Allocated TLS context, %p", tls);
431 			break;
432 		}
433 	}
434 
435 	k_mutex_unlock(&context_lock);
436 
437 	if (tls) {
438 		k_sem_init(&tls->tls_established, 0, 1);
439 
440 		mbedtls_ssl_init(&tls->ssl);
441 		mbedtls_ssl_config_init(&tls->config);
442 		tls_set_max_frag_len(&tls->config);
443 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
444 		mbedtls_ssl_cookie_init(&tls->cookie);
445 		tls->options.dtls_handshake_timeout_min =
446 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN;
447 		tls->options.dtls_handshake_timeout_max =
448 			MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX;
449 #endif
450 #if defined(MBEDTLS_X509_CRT_PARSE_C)
451 		mbedtls_x509_crt_init(&tls->ca_chain);
452 		mbedtls_x509_crt_init(&tls->own_cert);
453 		mbedtls_pk_init(&tls->priv_key);
454 #endif
455 
456 #if defined(CONFIG_MBEDTLS_DEBUG)
457 		mbedtls_ssl_conf_dbg(&tls->config, zephyr_mbedtls_debug, NULL);
458 #endif
459 	} else {
460 		NET_WARN("Failed to allocate TLS context");
461 	}
462 
463 	return tls;
464 }
465 
466 /* Allocate new TLS context and copy the content from the source context. */
tls_clone(struct tls_context * source_tls)467 static struct tls_context *tls_clone(struct tls_context *source_tls)
468 {
469 	struct tls_context *target_tls;
470 
471 	target_tls = tls_alloc();
472 	if (!target_tls) {
473 		return NULL;
474 	}
475 
476 	target_tls->tls_version = source_tls->tls_version;
477 	target_tls->type = source_tls->type;
478 
479 	memcpy(&target_tls->options, &source_tls->options,
480 	       sizeof(target_tls->options));
481 
482 #if defined(MBEDTLS_X509_CRT_PARSE_C)
483 	if (target_tls->options.is_hostname_set) {
484 		mbedtls_ssl_set_hostname(&target_tls->ssl,
485 					 source_tls->ssl.hostname);
486 	}
487 #endif
488 
489 	return target_tls;
490 }
491 
492 /* Release TLS context. */
tls_release(struct tls_context * tls)493 static int tls_release(struct tls_context *tls)
494 {
495 	if (!PART_OF_ARRAY(tls_contexts, tls)) {
496 		NET_ERR("Invalid TLS context");
497 		return -EBADF;
498 	}
499 
500 	if (!tls->is_used) {
501 		NET_ERR("Deallocating unused TLS context");
502 		return -EBADF;
503 	}
504 
505 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
506 	mbedtls_ssl_cookie_free(&tls->cookie);
507 #endif
508 	mbedtls_ssl_config_free(&tls->config);
509 	mbedtls_ssl_free(&tls->ssl);
510 #if defined(MBEDTLS_X509_CRT_PARSE_C)
511 	mbedtls_x509_crt_free(&tls->ca_chain);
512 	mbedtls_x509_crt_free(&tls->own_cert);
513 	mbedtls_pk_free(&tls->priv_key);
514 #endif
515 
516 	tls->is_used = false;
517 
518 	return 0;
519 }
520 
peer_addr_cmp(const struct sockaddr * addr,const struct sockaddr * peer_addr)521 static bool peer_addr_cmp(const struct sockaddr *addr,
522 			  const struct sockaddr *peer_addr)
523 {
524 	if (addr->sa_family != peer_addr->sa_family) {
525 		return false;
526 	}
527 
528 	if (IS_ENABLED(CONFIG_NET_IPV6) && peer_addr->sa_family == AF_INET6) {
529 		struct sockaddr_in6 *addr1 = net_sin6(peer_addr);
530 		struct sockaddr_in6 *addr2 = net_sin6(addr);
531 
532 		return (addr1->sin6_port == addr2->sin6_port) &&
533 			net_ipv6_addr_cmp(&addr1->sin6_addr, &addr2->sin6_addr);
534 	} else if (IS_ENABLED(CONFIG_NET_IPV4) && peer_addr->sa_family == AF_INET) {
535 		struct sockaddr_in *addr1 = net_sin(peer_addr);
536 		struct sockaddr_in *addr2 = net_sin(addr);
537 
538 		return (addr1->sin_port == addr2->sin_port) &&
539 			net_ipv4_addr_cmp(&addr1->sin_addr, &addr2->sin_addr);
540 	}
541 
542 	return false;
543 }
544 
/* Serialize @p session into the client session cache, keyed by @p peer_addr.
 *
 * Slot selection, in order of preference:
 *   1. the slot already holding a session for @p peer_addr,
 *   2. the first free slot,
 *   3. the occupied slot with the oldest timestamp (evicted).
 *
 * The serialized size is queried before the chosen slot is touched, so an
 * unserializable session does not destroy the existing entry.
 *
 * @return 0 on success, -EINVAL if mbedTLS cannot report the serialized
 *         size, -ENOMEM on allocation/serialization failure or when the
 *         cache has no slots at all.
 */
static int tls_session_save(const struct sockaddr *peer_addr,
			    mbedtls_ssl_session *session)
{
	struct tls_session_cache *entry = NULL;
	size_t session_len = 0;
	int ret;

	for (int i = 0; i < ARRAY_SIZE(client_cache); i++) {
		if (client_cache[i].session == NULL) {
			/* Free slot: prefer it over evicting an occupied one. */
			if (entry == NULL || entry->session != NULL) {
				entry = &client_cache[i];
			}
		} else {
			if (peer_addr_cmp(&client_cache[i].peer_addr, peer_addr)) {
				/* Reuse old entry for given address. */
				entry = &client_cache[i];
				break;
			}

			/* Remember the oldest entry and reuse if needed.
			 * A candidate is replaced only by an OLDER one.
			 */
			if (entry == NULL ||
			    (entry->session != NULL &&
			     client_cache[i].timestamp < entry->timestamp)) {
				entry = &client_cache[i];
			}
		}
	}

	if (entry == NULL) {
		/* Only possible with a zero-sized cache. */
		return -ENOMEM;
	}

	/* Query the serialized size; with an empty buffer mbedTLS reports
	 * MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL and the required length. Any other
	 * error may leave session_len unset.
	 */
	ret = mbedtls_ssl_session_save(session, NULL, 0, &session_len);
	if ((ret != 0 && ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) ||
	    session_len == 0) {
		NET_ERR("Failed to get session size, err: 0x%x.", -ret);
		return -EINVAL;
	}

	/* Allocate session and save */

	if (entry->session != NULL) {
		mbedtls_free(entry->session);
		entry->session = NULL;
	}

	entry->session = mbedtls_calloc(1, session_len);
	if (entry->session == NULL) {
		NET_ERR("Failed to allocate session buffer.");
		return -ENOMEM;
	}

	ret = mbedtls_ssl_session_save(session, entry->session, session_len,
				       &session_len);
	if (ret < 0) {
		NET_ERR("Failed to serialize session, err: 0x%x.", -ret);
		mbedtls_free(entry->session);
		entry->session = NULL;
		return -ENOMEM;
	}

	entry->session_len = session_len;
	entry->timestamp = k_uptime_get();
	memcpy(&entry->peer_addr, peer_addr, sizeof(*peer_addr));

	return 0;
}
604 
/* Load the cached client session for @p peer_addr into @p session.
 *
 * @return 0 on success, -ENOENT if nothing is cached for the peer, or -EIO
 *         if the cached data could not be deserialized (the corrupt entry
 *         is discarded).
 */
static int tls_session_get(const struct sockaddr *peer_addr,
			   mbedtls_ssl_session *session)
{
	struct tls_session_cache *found = NULL;
	int ret;

	for (size_t idx = 0;
	     found == NULL && idx < ARRAY_SIZE(client_cache);
	     idx++) {
		struct tls_session_cache *candidate = &client_cache[idx];

		if (candidate->session != NULL &&
		    peer_addr_cmp(&candidate->peer_addr, peer_addr)) {
			found = candidate;
		}
	}

	if (found == NULL) {
		return -ENOENT;
	}

	ret = mbedtls_ssl_session_load(session, found->session,
				       found->session_len);
	if (ret >= 0) {
		return 0;
	}

	/* Corrupted session data: drop it so it is not retried. */
	mbedtls_free(found->session);
	found->session = NULL;

	return -EIO;
}
634 
/* Save the session negotiated on @p context into the client session cache.
 *
 * Does nothing unless session caching is enabled on the socket. Failures
 * are logged, not propagated.
 *
 * @param context TLS context whose session is saved.
 * @param addr    Peer address used as the cache key.
 * @param addrlen Length of @p addr. At most sizeof(struct sockaddr) bytes
 *                are copied, so an oversized length cannot overflow the
 *                local key buffer.
 */
static void tls_session_store(struct tls_context *context,
			      const struct sockaddr *addr,
			      socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	size_t copy_len;
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Clamp to the key buffer: addrlen comes from the caller and may
	 * exceed sizeof(struct sockaddr).
	 */
	copy_len = (addrlen < sizeof(peer_addr)) ? addrlen : sizeof(peer_addr);
	memcpy(&peer_addr, addr, copy_len);
	mbedtls_ssl_session_init(&session);

	ret = mbedtls_ssl_get_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to obtain session for %p", context);
		goto exit;
	}

	ret = tls_session_save(&peer_addr, &session);
	if (ret < 0) {
		NET_ERR("Failed to save session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
664 
/* Apply a cached client session for the peer at @p addr to @p context.
 *
 * Does nothing unless session caching is enabled on the socket. A missing
 * entry is logged at debug level. A failure to apply the session is
 * logged as an error and otherwise ignored.
 *
 * @param context TLS context to resume the session on.
 * @param addr    Peer address used as the cache key.
 * @param addrlen Length of @p addr. At most sizeof(struct sockaddr) bytes
 *                are copied, so an oversized length cannot overflow the
 *                local key buffer.
 */
static void tls_session_restore(struct tls_context *context,
				const struct sockaddr *addr,
				socklen_t addrlen)
{
	mbedtls_ssl_session session;
	struct sockaddr peer_addr = { 0 };
	size_t copy_len;
	int ret;

	if (!context->options.cache_enabled) {
		return;
	}

	/* Clamp to the key buffer: addrlen comes from the caller and may
	 * exceed sizeof(struct sockaddr).
	 */
	copy_len = (addrlen < sizeof(peer_addr)) ? addrlen : sizeof(peer_addr);
	memcpy(&peer_addr, addr, copy_len);
	mbedtls_ssl_session_init(&session);

	ret = tls_session_get(&peer_addr, &session);
	if (ret < 0) {
		NET_DBG("Session not found for %p", context);
		goto exit;
	}

	ret = mbedtls_ssl_set_session(&context->ssl, &session);
	if (ret < 0) {
		NET_ERR("Failed to set session for %p", context);
	}

exit:
	mbedtls_ssl_session_free(&session);
}
694 
/* Forget all cached TLS sessions: the client cache always, and the mbedTLS
 * server cache when it is built in (freed, then re-initialized empty).
 */
static void tls_session_purge(void)
{
	tls_session_cache_reset();

#if defined(MBEDTLS_SSL_CACHE_C)
	mbedtls_ssl_cache_free(&server_cache);
	mbedtls_ssl_cache_init(&server_cache);
#endif
}
704 
/* Milliseconds remaining of @p timeout measured from uptime @p start.
 * The difference is computed in 32-bit unsigned arithmetic and converted
 * to int, so an expired timeout typically yields a negative value.
 */
static inline int time_left(uint32_t start, uint32_t timeout)
{
	uint32_t now = k_uptime_get_32();

	return timeout - (now - start);
}
711 
wait(int sock,int timeout,int event)712 static int wait(int sock, int timeout, int event)
713 {
714 	struct zsock_pollfd fds = {
715 		.fd = sock,
716 		.events = event,
717 	};
718 	int ret;
719 
720 	ret = zsock_poll(&fds, 1, timeout);
721 	if (ret < 0) {
722 		return ret;
723 	}
724 
725 	if (ret == 1) {
726 		if (fds.revents & ZSOCK_POLLNVAL) {
727 			return -EBADF;
728 		}
729 
730 		if (fds.revents & ZSOCK_POLLERR) {
731 			return -EIO;
732 		}
733 	}
734 
735 	return 0;
736 }
737 
wait_for_reason(int sock,int timeout,int reason)738 static int wait_for_reason(int sock, int timeout, int reason)
739 {
740 	if (reason == MBEDTLS_ERR_SSL_WANT_READ) {
741 		return wait(sock, timeout, ZSOCK_POLLIN);
742 	}
743 
744 	if (reason == MBEDTLS_ERR_SSL_WANT_WRITE) {
745 		return wait(sock, timeout, ZSOCK_POLLOUT);
746 	}
747 
748 	/* Any other reason - no way to monitor, just wait for some time. */
749 	k_msleep(TLS_WAIT_MS);
750 
751 	return 0;
752 }
753 
is_blocking(int sock,int flags)754 static bool is_blocking(int sock, int flags)
755 {
756 	int sock_flags = zsock_fcntl(sock, F_GETFL, 0);
757 
758 	if (sock_flags == -1) {
759 		return false;
760 	}
761 
762 	return !((flags & ZSOCK_MSG_DONTWAIT) || (sock_flags & O_NONBLOCK));
763 }
764 
/* Shrink a finite @p timeout to the ticks remaining until absolute tick
 * @p end. K_NO_WAIT and K_FOREVER are left as they are. An already-passed
 * deadline becomes K_NO_WAIT.
 */
static void timeout_recalc(uint64_t end, k_timeout_t *timeout)
{
	int64_t remaining;

	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT) ||
	    K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return;
	}

	remaining = end - sys_clock_tick_get();

	if (remaining > 0) {
		*timeout = Z_TIMEOUT_TICKS(remaining);
	} else {
		*timeout = K_NO_WAIT;
	}
}
778 
/* Convert a kernel timeout to poll()-style milliseconds:
 * K_NO_WAIT -> 0, K_FOREVER -> SYS_FOREVER_MS, otherwise ticks rounded down.
 */
static int timeout_to_ms(k_timeout_t *timeout)
{
	if (K_TIMEOUT_EQ(*timeout, K_NO_WAIT)) {
		return 0;
	}

	if (K_TIMEOUT_EQ(*timeout, K_FOREVER)) {
		return SYS_FOREVER_MS;
	}

	return k_ticks_to_ms_floor32(timeout->ticks);
}
789 
ctx_set_lock(struct tls_context * ctx,struct k_mutex * lock)790 static void ctx_set_lock(struct tls_context *ctx, struct k_mutex *lock)
791 {
792 	ctx->lock = lock;
793 }
794 
795 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
dtls_is_peer_addr_valid(struct tls_context * context,const struct sockaddr * peer_addr,socklen_t addrlen)796 static bool dtls_is_peer_addr_valid(struct tls_context *context,
797 				    const struct sockaddr *peer_addr,
798 				    socklen_t addrlen)
799 {
800 	if (context->dtls_peer_addrlen != addrlen) {
801 		return false;
802 	}
803 
804 	return peer_addr_cmp(&context->dtls_peer_addr, peer_addr);
805 }
806 
/* Remember @p peer_addr as the DTLS peer. Addresses that do not fit into
 * the stored struct sockaddr are silently ignored.
 */
static void dtls_peer_address_set(struct tls_context *context,
				  const struct sockaddr *peer_addr,
				  socklen_t addrlen)
{
	if (addrlen > sizeof(context->dtls_peer_addr)) {
		return;
	}

	memcpy(&context->dtls_peer_addr, peer_addr, addrlen);
	context->dtls_peer_addrlen = addrlen;
}
816 
/* Copy the stored DTLS peer address into @p peer_addr.
 *
 * On entry *@p addrlen is the caller's buffer size. On return it is the
 * number of bytes actually copied (the smaller of that size and the stored
 * address length).
 */
static void dtls_peer_address_get(struct tls_context *context,
				  struct sockaddr *peer_addr,
				  socklen_t *addrlen)
{
	socklen_t copy_len = *addrlen;

	if (context->dtls_peer_addrlen < copy_len) {
		copy_len = context->dtls_peer_addrlen;
	}

	memcpy(peer_addr, &context->dtls_peer_addr, copy_len);
	*addrlen = copy_len;
}
826 
dtls_tx(void * ctx,const unsigned char * buf,size_t len)827 static int dtls_tx(void *ctx, const unsigned char *buf, size_t len)
828 {
829 	struct tls_context *tls_ctx = ctx;
830 	ssize_t sent;
831 
832 	sent = zsock_sendto(tls_ctx->sock, buf, len, ZSOCK_MSG_DONTWAIT,
833 			    &tls_ctx->dtls_peer_addr,
834 			    tls_ctx->dtls_peer_addrlen);
835 	if (sent < 0) {
836 		if (errno == EAGAIN) {
837 			return MBEDTLS_ERR_SSL_WANT_WRITE;
838 		}
839 
840 		return MBEDTLS_ERR_NET_SEND_FAILED;
841 	}
842 
843 	return sent;
844 }
845 
dtls_rx(void * ctx,unsigned char * buf,size_t len)846 static int dtls_rx(void *ctx, unsigned char *buf, size_t len)
847 {
848 	struct tls_context *tls_ctx = ctx;
849 	socklen_t addrlen = sizeof(struct sockaddr);
850 	struct sockaddr addr;
851 	int err;
852 	ssize_t received;
853 
854 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
855 				  ZSOCK_MSG_DONTWAIT, &addr, &addrlen);
856 	if (received < 0) {
857 		if (errno == EAGAIN) {
858 			return MBEDTLS_ERR_SSL_WANT_READ;
859 		}
860 
861 		return MBEDTLS_ERR_NET_RECV_FAILED;
862 	}
863 
864 	if (tls_ctx->dtls_peer_addrlen == 0) {
865 		/* Only allow to store peer address for DTLS servers. */
866 		if (tls_ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
867 			dtls_peer_address_set(tls_ctx, &addr, addrlen);
868 
869 			err = mbedtls_ssl_set_client_transport_id(
870 				&tls_ctx->ssl,
871 				(const unsigned char *)&addr, addrlen);
872 			if (err < 0) {
873 				return err;
874 			}
875 		} else {
876 			/* For clients it's incorrect to receive when
877 			 * no peer has been set up.
878 			 */
879 			return MBEDTLS_ERR_SSL_PEER_VERIFY_FAILED;
880 		}
881 	} else if (!dtls_is_peer_addr_valid(tls_ctx, &addr, addrlen)) {
882 		return MBEDTLS_ERR_SSL_WANT_READ;
883 	}
884 
885 	return received;
886 }
887 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
888 
tls_tx(void * ctx,const unsigned char * buf,size_t len)889 static int tls_tx(void *ctx, const unsigned char *buf, size_t len)
890 {
891 	struct tls_context *tls_ctx = ctx;
892 	ssize_t sent;
893 
894 	sent = zsock_sendto(tls_ctx->sock, buf, len,
895 			    ZSOCK_MSG_DONTWAIT, NULL, 0);
896 	if (sent < 0) {
897 		if (errno == EAGAIN) {
898 			return MBEDTLS_ERR_SSL_WANT_WRITE;
899 		}
900 
901 		return MBEDTLS_ERR_NET_SEND_FAILED;
902 	}
903 
904 	return sent;
905 }
906 
tls_rx(void * ctx,unsigned char * buf,size_t len)907 static int tls_rx(void *ctx, unsigned char *buf, size_t len)
908 {
909 	struct tls_context *tls_ctx = ctx;
910 	ssize_t received;
911 
912 	received = zsock_recvfrom(tls_ctx->sock, buf, len,
913 				  ZSOCK_MSG_DONTWAIT, NULL, 0);
914 	if (received < 0) {
915 		if (errno == EAGAIN) {
916 			return MBEDTLS_ERR_SSL_WANT_READ;
917 		}
918 
919 		return MBEDTLS_ERR_NET_RECV_FAILED;
920 	}
921 
922 	return received;
923 }
924 
925 #if defined(MBEDTLS_X509_CRT_PARSE_C)
/* Heuristically detect a PEM certificate.
 *
 * The buffer must be non-empty and NUL-terminated (its last byte is '\0'),
 * which also makes the string search below safe. It must contain the PEM
 * certificate header.
 */
static bool crt_is_pem(const unsigned char *buf, size_t buflen)
{
	static const char pem_header[] = "-----BEGIN CERTIFICATE-----";

	if (buflen == 0 || buf[buflen - 1] != '\0') {
		return false;
	}

	return strstr((const char *)buf, pem_header) != NULL;
}
931 #endif
932 
/* Append @p ca_cert to the context's CA chain.
 *
 * PEM input, or any input when no-copy is disabled, is parsed with copying.
 * DER input in no-copy mode is referenced in place.
 *
 * @return 0 on success, -EINVAL on a parse error, -ENOTSUP without X.509
 *         support.
 */
static int tls_add_ca_certificate(struct tls_context *tls,
				  struct tls_credential *ca_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool copy_data = (tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE) ||
			 crt_is_pem(ca_cert->buf, ca_cert->len);
	int err;

	if (copy_data) {
		err = mbedtls_x509_crt_parse(&tls->ca_chain, ca_cert->buf,
					     ca_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->ca_chain,
							ca_cert->buf,
							ca_cert->len);
	}

	return (err == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
958 
/* Install the parsed CA chain (no CRL) and the default certificate
 * verification profile on the context's mbedTLS configuration. No-op
 * without X.509 support.
 */
static void tls_set_ca_chain(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	mbedtls_ssl_config *conf = &tls->config;

	mbedtls_ssl_conf_ca_chain(conf, &tls->ca_chain, NULL);
	mbedtls_ssl_conf_cert_profile(conf, &mbedtls_x509_crt_profile_default);
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
967 
/* Append @p own_cert to the context's own certificate chain.
 *
 * PEM input, or any input when no-copy is disabled, is parsed with copying.
 * DER input in no-copy mode is referenced in place.
 *
 * @return 0 on success, -EINVAL on a parse error, -ENOTSUP without X.509
 *         support.
 */
static int tls_add_own_cert(struct tls_context *tls,
			    struct tls_credential *own_cert)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	bool copy_data = (tls->options.cert_nocopy == TLS_CERT_NOCOPY_NONE) ||
			 crt_is_pem(own_cert->buf, own_cert->len);
	int err;

	if (copy_data) {
		err = mbedtls_x509_crt_parse(&tls->own_cert,
					     own_cert->buf, own_cert->len);
	} else {
		err = mbedtls_x509_crt_parse_der_nocopy(&tls->own_cert,
							own_cert->buf,
							own_cert->len);
	}

	return (err == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
993 
/* Register the own certificate chain together with its private key on
 * the mbedTLS configuration.
 *
 * @return 0 on success, -ENOMEM if mbedTLS rejects the pair, -ENOTSUP
 *         without X.509 support.
 */
static int tls_set_own_cert(struct tls_context *tls)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	if (mbedtls_ssl_conf_own_cert(&tls->config, &tls->own_cert,
				      &tls->priv_key) != 0) {
		return -ENOMEM;
	}

	return 0;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1008 
/* Parse @p priv_key (no password) into the context's private key slot.
 * tls_ctr_drbg_random() is the RNG passed to mbedtls_pk_parse_key().
 *
 * @return 0 on success, -EINVAL on a parse error, -ENOTSUP without X.509
 *         support.
 */
static int tls_set_private_key(struct tls_context *tls,
			       struct tls_credential *priv_key)
{
#if defined(MBEDTLS_X509_CRT_PARSE_C)
	int err = mbedtls_pk_parse_key(&tls->priv_key, priv_key->buf,
				       priv_key->len, NULL, 0,
				       tls_ctr_drbg_random, NULL);

	return (err == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif /* MBEDTLS_X509_CRT_PARSE_C */
}
1027 
/* Configure the pre-shared key and its identity on the mbedTLS config.
 *
 * @return 0 on success, -EINVAL if mbedTLS rejects the PSK, -ENOTSUP when
 *         no PSK key exchange is enabled.
 */
static int tls_set_psk(struct tls_context *tls,
		       struct tls_credential *psk,
		       struct tls_credential *psk_id)
{
#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
	const unsigned char *identity = (const unsigned char *)psk_id->buf;
	int err = mbedtls_ssl_conf_psk(&tls->config,
				       psk->buf, psk->len,
				       identity, psk_id->len);

	return (err == 0) ? 0 : -EINVAL;
#else
	return -ENOTSUP;
#endif
}
1046 
tls_set_credential(struct tls_context * tls,struct tls_credential * cred)1047 static int tls_set_credential(struct tls_context *tls,
1048 			      struct tls_credential *cred)
1049 {
1050 	switch (cred->type) {
1051 	case TLS_CREDENTIAL_CA_CERTIFICATE:
1052 		return tls_add_ca_certificate(tls, cred);
1053 
1054 	case TLS_CREDENTIAL_SERVER_CERTIFICATE:
1055 		return tls_add_own_cert(tls, cred);
1056 
1057 	case TLS_CREDENTIAL_PRIVATE_KEY:
1058 		return tls_set_private_key(tls, cred);
1059 	break;
1060 
1061 	case TLS_CREDENTIAL_PSK:
1062 	{
1063 		struct tls_credential *psk_id =
1064 			credential_get(cred->tag, TLS_CREDENTIAL_PSK_ID);
1065 		if (!psk_id) {
1066 			return -ENOENT;
1067 		}
1068 
1069 		return tls_set_psk(tls, cred, psk_id);
1070 	}
1071 
1072 	case TLS_CREDENTIAL_PSK_ID:
1073 		/* Ignore PSK ID - it will be used together
1074 		 * with PSK
1075 		 */
1076 		break;
1077 
1078 	default:
1079 		return -EINVAL;
1080 	}
1081 
1082 	return 0;
1083 }
1084 
tls_mbedtls_set_credentials(struct tls_context * tls)1085 static int tls_mbedtls_set_credentials(struct tls_context *tls)
1086 {
1087 	struct tls_credential *cred;
1088 	sec_tag_t tag;
1089 	int i, err = 0;
1090 	bool tag_found, ca_cert_present = false, own_cert_present = false;
1091 
1092 	credentials_lock();
1093 
1094 	for (i = 0; i < tls->options.sec_tag_list.sec_tag_count; i++) {
1095 		tag = tls->options.sec_tag_list.sec_tags[i];
1096 		cred = NULL;
1097 		tag_found = false;
1098 
1099 		while ((cred = credential_next_get(tag, cred)) != NULL) {
1100 			tag_found = true;
1101 
1102 			err = tls_set_credential(tls, cred);
1103 			if (err != 0) {
1104 				goto exit;
1105 			}
1106 
1107 			if (cred->type == TLS_CREDENTIAL_CA_CERTIFICATE) {
1108 				ca_cert_present = true;
1109 			} else if (cred->type == TLS_CREDENTIAL_SERVER_CERTIFICATE) {
1110 				own_cert_present = true;
1111 			}
1112 		}
1113 
1114 		if (!tag_found) {
1115 			err = -ENOENT;
1116 			goto exit;
1117 		}
1118 	}
1119 
1120 exit:
1121 	credentials_unlock();
1122 
1123 	if (err == 0) {
1124 		if (ca_cert_present) {
1125 			tls_set_ca_chain(tls);
1126 		}
1127 		if (own_cert_present) {
1128 			err = tls_set_own_cert(tls);
1129 		}
1130 	}
1131 
1132 	return err;
1133 }
1134 
tls_mbedtls_reset(struct tls_context * context)1135 static int tls_mbedtls_reset(struct tls_context *context)
1136 {
1137 	int ret;
1138 
1139 	ret = mbedtls_ssl_session_reset(&context->ssl);
1140 	if (ret != 0) {
1141 		return ret;
1142 	}
1143 
1144 	k_sem_reset(&context->tls_established);
1145 
1146 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1147 	/* Server role: reset the address so that a new
1148 	 *              client can connect w/o a need to reopen a socket
1149 	 * Client role: keep peer addr so socket can continue to be used
1150 	 *              even on handshake timeout
1151 	 */
1152 	if (context->options.role == MBEDTLS_SSL_IS_SERVER) {
1153 		(void)memset(&context->dtls_peer_addr, 0,
1154 			     sizeof(context->dtls_peer_addr));
1155 		context->dtls_peer_addrlen = 0;
1156 	}
1157 #endif
1158 
1159 	return 0;
1160 }
1161 
/* Drive mbedtls_ssl_handshake() until it completes, blocking on the
 * underlying socket between steps for at most @p timeout in total.
 *
 * handshake_in_progress is set for the duration of the call. On success
 * the tls_established semaphore is given. On fatal handshake errors the
 * session is reset with tls_mbedtls_reset() before returning.
 *
 * @param context TLS context with an mbedTLS session already set up.
 * @param timeout Overall time budget for the handshake.
 *
 * @return 0 on success;
 *         -EAGAIN if the time budget ran out while mbedTLS wanted I/O, or a
 *         HELLO_VERIFY_REQUIRED reset happened with a K_NO_WAIT budget;
 *         -ETIMEDOUT if mbedTLS reported MBEDTLS_ERR_SSL_TIMEOUT;
 *         -ECONNABORTED on any other handshake error or if a session reset
 *         failed;
 *         otherwise the non-zero value returned by wait_for_reason().
 */
static int tls_mbedtls_handshake(struct tls_context *context,
				 k_timeout_t timeout)
{
	uint64_t end;
	int ret;

	context->handshake_in_progress = true;

	/* Absolute deadline; 'timeout' is recomputed against it each round. */
	end = sys_clock_timeout_end_calc(timeout);

	while ((ret = mbedtls_ssl_handshake(&context->ssl)) != 0) {
		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
		    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
			/* Non-fatal: mbedTLS needs more I/O or async work. */
			int timeout_ms;

			/* Blocking timeout. */
			timeout_recalc(end, &timeout);
			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
				ret = -EAGAIN;
				break;
			}

			/* Block. */
			timeout_ms = timeout_to_ms(&timeout);

#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
			/* For DTLS, wake up no later than the DTLS retransmission
			 * timer so mbedTLS can handle it. SYS_FOREVER_MS from
			 * dtls_get_remaining_timeout() is treated as "no DTLS
			 * deadline".
			 */
			if (context->type == SOCK_DGRAM) {
				int timeout_dtls =
					dtls_get_remaining_timeout(context);

				if (timeout_dtls != SYS_FOREVER_MS) {
					if (timeout_ms == SYS_FOREVER_MS) {
						timeout_ms = timeout_dtls;
					} else {
						timeout_ms = MIN(timeout_dtls,
								 timeout_ms);
					}
				}
			}
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */

			/* Wait for the socket condition mbedTLS asked for; a
			 * non-zero result aborts the handshake and is returned
			 * as-is.
			 */
			ret = wait_for_reason(context->sock, timeout_ms, ret);
			if (ret != 0) {
				break;
			}

			continue;
		} else if (ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED) {
			/* DTLS cookie exchange: mbedTLS requires a session reset
			 * before continuing. For the server role the reset also
			 * forgets the stored peer address. Keep going unless the
			 * remaining budget is K_NO_WAIT.
			 */
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				if (!K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
					continue;
				}

				ret = -EAGAIN;
				break;
			}
		} else if (ret == MBEDTLS_ERR_SSL_TIMEOUT) {
			/* MbedTLS API documentation requires session to
			 * be reset in this case
			 */
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				NET_ERR("TLS handshake timeout");
				ret = -ETIMEDOUT;
				break;
			}
		} else {
			/* MbedTLS API documentation requires session to
			 * be reset in other error cases
			 */
			NET_ERR("TLS handshake error: -%x", -ret);
			ret = tls_mbedtls_reset(context);
			if (ret == 0) {
				ret = -ECONNABORTED;
				break;
			}
		}

		/* Reached only when a tls_mbedtls_reset() call above failed
		 * (ret holds its mbedTLS error code).
		 * Avoid constant loop if tls_mbedtls_reset fails
		 */
		NET_ERR("TLS reset error: -%x", -ret);
		ret = -ECONNABORTED;
		break;
	}

	if (ret == 0) {
		k_sem_give(&context->tls_established);
	}

	context->handshake_in_progress = false;

	return ret;
}
1257 
tls_mbedtls_init(struct tls_context * context,bool is_server)1258 static int tls_mbedtls_init(struct tls_context *context, bool is_server)
1259 {
1260 	int role, type, ret;
1261 
1262 	role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
1263 
1264 	type = (context->type == SOCK_STREAM) ?
1265 		MBEDTLS_SSL_TRANSPORT_STREAM :
1266 		MBEDTLS_SSL_TRANSPORT_DATAGRAM;
1267 
1268 	if (type == MBEDTLS_SSL_TRANSPORT_STREAM) {
1269 		mbedtls_ssl_set_bio(&context->ssl, context,
1270 				    tls_tx, tls_rx, NULL);
1271 	} else {
1272 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1273 		mbedtls_ssl_set_bio(&context->ssl, context,
1274 				    dtls_tx, dtls_rx, NULL);
1275 #else
1276 		return -ENOTSUP;
1277 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1278 	}
1279 
1280 	ret = mbedtls_ssl_config_defaults(&context->config, role, type,
1281 					  MBEDTLS_SSL_PRESET_DEFAULT);
1282 	if (ret != 0) {
1283 		/* According to mbedTLS API documentation,
1284 		 * mbedtls_ssl_config_defaults can fail due to memory
1285 		 * allocation failure
1286 		 */
1287 		return -ENOMEM;
1288 	}
1289 
1290 #if defined(MBEDTLS_SSL_RENEGOTIATION)
1291 	mbedtls_ssl_conf_legacy_renegotiation(&context->config,
1292 					   MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE);
1293 	mbedtls_ssl_conf_renegotiation(&context->config,
1294 				       MBEDTLS_SSL_RENEGOTIATION_ENABLED);
1295 #endif
1296 
1297 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
1298 	if (type == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
1299 		/* DTLS requires timer callbacks to operate */
1300 		mbedtls_ssl_set_timer_cb(&context->ssl,
1301 					 &context->dtls_timing,
1302 					 dtls_timing_set_delay,
1303 					 dtls_timing_get_delay);
1304 		mbedtls_ssl_conf_handshake_timeout(&context->config,
1305 				context->options.dtls_handshake_timeout_min,
1306 				context->options.dtls_handshake_timeout_max);
1307 
1308 		/* Configure cookie for DTLS server */
1309 		if (role == MBEDTLS_SSL_IS_SERVER) {
1310 			ret = mbedtls_ssl_cookie_setup(&context->cookie,
1311 						       tls_ctr_drbg_random,
1312 						       NULL);
1313 			if (ret != 0) {
1314 				return -ENOMEM;
1315 			}
1316 
1317 			mbedtls_ssl_conf_dtls_cookies(&context->config,
1318 						      mbedtls_ssl_cookie_write,
1319 						      mbedtls_ssl_cookie_check,
1320 						      &context->cookie);
1321 
1322 			mbedtls_ssl_conf_read_timeout(
1323 					&context->config,
1324 					CONFIG_NET_SOCKETS_DTLS_TIMEOUT);
1325 		}
1326 	}
1327 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1328 
1329 #if defined(MBEDTLS_X509_CRT_PARSE_C)
1330 	/* For TLS clients, set hostname to empty string to enforce it's
1331 	 * verification - only if hostname option was not set. Otherwise
1332 	 * depend on user configuration.
1333 	 */
1334 	if (!is_server && !context->options.is_hostname_set) {
1335 		mbedtls_ssl_set_hostname(&context->ssl, "");
1336 	}
1337 #endif
1338 
1339 	/* If verification level was specified explicitly, set it. Otherwise,
1340 	 * use mbedTLS default values (required for client, none for server)
1341 	 */
1342 	if (context->options.verify_level != -1) {
1343 		mbedtls_ssl_conf_authmode(&context->config,
1344 					  context->options.verify_level);
1345 	}
1346 
1347 	mbedtls_ssl_conf_rng(&context->config,
1348 			     tls_ctr_drbg_random,
1349 			     NULL);
1350 
1351 	ret = tls_mbedtls_set_credentials(context);
1352 	if (ret != 0) {
1353 		return ret;
1354 	}
1355 
1356 	if (context->options.ciphersuites[0] != 0) {
1357 		/* Specific ciphersuites configured, so use them */
1358 		NET_DBG("Using user-specified ciphersuites");
1359 		mbedtls_ssl_conf_ciphersuites(&context->config,
1360 					      context->options.ciphersuites);
1361 	}
1362 
1363 #if defined(CONFIG_MBEDTLS_SSL_ALPN)
1364 	if (ALPN_MAX_PROTOCOLS && context->options.alpn_list[0] != NULL) {
1365 		ret = mbedtls_ssl_conf_alpn_protocols(&context->config,
1366 				context->options.alpn_list);
1367 		if (ret != 0) {
1368 			return -EINVAL;
1369 		}
1370 	}
1371 #endif /* CONFIG_MBEDTLS_SSL_ALPN */
1372 
1373 #if defined(MBEDTLS_SSL_CACHE_C)
1374 	if (is_server && context->options.cache_enabled) {
1375 		mbedtls_ssl_conf_session_cache(&context->config, &server_cache,
1376 					       mbedtls_ssl_cache_get,
1377 					       mbedtls_ssl_cache_set);
1378 	}
1379 #endif
1380 
1381 	ret = mbedtls_ssl_setup(&context->ssl,
1382 				&context->config);
1383 	if (ret != 0) {
1384 		/* According to mbedTLS API documentation,
1385 		 * mbedtls_ssl_setup can fail due to memory allocation failure
1386 		 */
1387 		return -ENOMEM;
1388 	}
1389 
1390 	context->is_initialized = true;
1391 
1392 	return 0;
1393 }
1394 
/* TLS_SEC_TAG_LIST setter: replace the socket's security tag list with the
 * sec_tag_t array in @p optval.
 *
 * @return 0 on success, -EINVAL if @p optval is NULL, @p optlen is not a
 *         multiple of sizeof(sec_tag_t), or the list does not fit.
 */
static int tls_opt_sec_tag_list_set(struct tls_context *context,
				    const void *optval, socklen_t optlen)
{
	int tag_count;

	if (optval == NULL || (optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	tag_count = optlen / sizeof(sec_tag_t);
	if (tag_count > ARRAY_SIZE(context->options.sec_tag_list.sec_tags)) {
		return -EINVAL;
	}

	memcpy(context->options.sec_tag_list.sec_tags, optval, optlen);
	context->options.sec_tag_list.sec_tag_count = tag_count;

	return 0;
}
1419 
/* SO_PROTOCOL getter: report the secure protocol the socket was created
 * with as an int.
 *
 * @return 0 on success, -EINVAL if *optlen is not sizeof(int).
 */
static int sock_opt_protocol_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	*(int *)optval = (int)context->tls_version;

	return 0;
}
1433 
/* TLS_SEC_TAG_LIST getter: copy as many configured security tags as fit in
 * @p optval and store the number of bytes copied in *optlen.
 *
 * @return 0 on success, -EINVAL if *optlen is zero or not a multiple of
 *         sizeof(sec_tag_t).
 */
static int tls_opt_sec_tag_list_get(struct tls_context *context,
				    void *optval, socklen_t *optlen)
{
	size_t stored_bytes;
	int copy_bytes;

	if (*optlen == 0 || (*optlen % sizeof(sec_tag_t)) != 0) {
		return -EINVAL;
	}

	stored_bytes = context->options.sec_tag_list.sec_tag_count *
		       sizeof(sec_tag_t);
	copy_bytes = MIN(stored_bytes, *optlen);

	memcpy(optval, context->options.sec_tag_list.sec_tags, copy_bytes);
	*optlen = copy_bytes;

	return 0;
}
1451 
/* TLS_HOSTNAME setter: pass @p optval to mbedtls_ssl_set_hostname() and
 * record that the user chose a hostname. tls_mbedtls_init() then skips
 * its empty-string default.
 *
 * @p optval is handed to mbedTLS as a C string; @p optlen is not used.
 *
 * @return 0 on success, -EINVAL if mbedTLS rejects the hostname, or
 *         -ENOPROTOOPT when MBEDTLS_X509_CRT_PARSE_C is not enabled.
 */
static int tls_opt_hostname_set(struct tls_context *context,
				const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);

#if !defined(MBEDTLS_X509_CRT_PARSE_C)
	ARG_UNUSED(context);
	ARG_UNUSED(optval);

	return -ENOPROTOOPT;
#else
	if (mbedtls_ssl_set_hostname(&context->ssl, optval) != 0) {
		return -EINVAL;
	}

	context->options.is_hostname_set = true;

	return 0;
#endif
}
1469 
/* TLS_CIPHERSUITE_LIST setter: store a list of ciphersuite IDs (an int
 * array without terminator) and apply it to the mbedTLS config.
 *
 * An empty list (first entry 0) means "no specific ciphersuites". This is
 * how tls_mbedtls_init() and tls_opt_ciphersuite_list_get() interpret it.
 * In that case the config is pointed at the full list supported by
 * mbedTLS instead of an empty list, which would leave no usable suite.
 *
 * @return 0 on success, -EINVAL if @p optval is NULL, @p optlen is not a
 *         multiple of sizeof(int), or the list (plus terminator) does not
 *         fit.
 */
static int tls_opt_ciphersuite_list_set(struct tls_context *context,
					const void *optval, socklen_t optlen)
{
	int cipher_cnt;

	if (!optval) {
		return -EINVAL;
	}

	if (optlen % sizeof(int) != 0) {
		return -EINVAL;
	}

	cipher_cnt = optlen / sizeof(int);

	/* + 1 for 0-termination. */
	if (cipher_cnt + 1 > ARRAY_SIZE(context->options.ciphersuites)) {
		return -EINVAL;
	}

	memcpy(context->options.ciphersuites, optval, optlen);
	context->options.ciphersuites[cipher_cnt] = 0;

	if (context->options.ciphersuites[0] == 0) {
		/* No specific ciphersuites: fall back to all supported ones. */
		mbedtls_ssl_conf_ciphersuites(&context->config,
					      mbedtls_ssl_list_ciphersuites());
	} else {
		mbedtls_ssl_conf_ciphersuites(&context->config,
					      context->options.ciphersuites);
	}

	return 0;
}
1497 
/* TLS_CIPHERSUITE_LIST getter: copy as many ciphersuite IDs as fit in
 * @p optval and store the number of bytes written in *optlen.
 *
 * When no specific list was configured (first entry 0), the full list
 * reported by mbedtls_ssl_list_ciphersuites() is returned.
 *
 * @return 0 on success, -EINVAL if *optlen is zero or not a multiple of
 *         sizeof(int).
 */
static int tls_opt_ciphersuite_list_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	int *out = optval;
	const int *src;
	int capacity;
	int n;

	if (*optlen == 0 || (*optlen % sizeof(int)) != 0) {
		return -EINVAL;
	}

	src = (context->options.ciphersuites[0] != 0) ?
	      context->options.ciphersuites :
	      mbedtls_ssl_list_ciphersuites();

	capacity = *optlen / sizeof(int);
	for (n = 0; n < capacity && src[n] != 0; n++) {
		out[n] = src[n];
	}

	*optlen = n * sizeof(int);

	return 0;
}
1529 
/* TLS_CIPHERSUITE_USED getter: report the ID of the ciphersuite currently
 * in use by the session.
 *
 * @return 0 on success, -EINVAL if *optlen is not sizeof(int), or
 *         -ENOTCONN when mbedTLS reports no ciphersuite.
 */
static int tls_opt_ciphersuite_used_get(struct tls_context *context,
					void *optval, socklen_t *optlen)
{
	const char *suite_name;

	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	suite_name = mbedtls_ssl_get_ciphersuite(&context->ssl);
	if (suite_name == NULL) {
		/* Treated as "no ciphersuite negotiated (yet)". */
		return -ENOTCONN;
	}

	*(int *)optval = mbedtls_ssl_get_ciphersuite_id(suite_name);
	return 0;
}
1548 
/* TLS_ALPN_LIST setter: store an array of protocol name pointers (no
 * terminator) as the socket's ALPN list and NULL-terminate it.
 *
 * Only the pointers are copied, not the strings they point to.
 *
 * @return 0 on success, -EINVAL if ALPN support is compiled out, @p optval
 *         is NULL, @p optlen is not a multiple of the pointer size, or the
 *         list (plus terminator) does not fit.
 */
static int tls_opt_alpn_list_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	size_t count;

	if (!ALPN_MAX_PROTOCOLS || optval == NULL ||
	    (optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	count = optlen / sizeof(const char *);
	/* Reserve one slot for the NULL terminator. */
	if (count >= ARRAY_SIZE(context->options.alpn_list)) {
		return -EINVAL;
	}

	memcpy(context->options.alpn_list, optval, optlen);
	context->options.alpn_list[count] = NULL;

	return 0;
}
1577 
1578 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
/* TLS_DTLS_HANDSHAKE_TIMEOUT_MIN/MAX getter: report the stored DTLS
 * handshake timeout bound selected by @p is_max as a uint32_t.
 *
 * @return 0 on success, -EINVAL if *optlen is not sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_get(struct tls_context *context,
					      void *optval, socklen_t *optlen,
					      bool is_max)
{
	if (*optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	*(uint32_t *)optval = is_max ?
		context->options.dtls_handshake_timeout_max :
		context->options.dtls_handshake_timeout_min;

	return 0;
}
1597 
/* TLS_DTLS_HANDSHAKE_TIMEOUT_MIN/MAX setter: store the uint32_t in
 * @p optval as the minimum or maximum DTLS handshake timeout.
 *
 * The stored values are what tls_mbedtls_init() applies. They are also
 * pushed into the mbedTLS config right away so that the change takes
 * effect on an already initialized context.
 *
 * @return 0 on success, -EINVAL if @p optval is NULL or @p optlen is not
 *         sizeof(uint32_t).
 */
static int tls_opt_dtls_handshake_timeout_set(struct tls_context *context,
					      const void *optval,
					      socklen_t optlen, bool is_max)
{
	uint32_t new_timeout;

	if (optval == NULL || optlen != sizeof(uint32_t)) {
		return -EINVAL;
	}

	new_timeout = *(const uint32_t *)optval;

	if (is_max) {
		context->options.dtls_handshake_timeout_max = new_timeout;
	} else {
		context->options.dtls_handshake_timeout_min = new_timeout;
	}

	mbedtls_ssl_conf_handshake_timeout(&context->config,
			context->options.dtls_handshake_timeout_min,
			context->options.dtls_handshake_timeout_max);

	return 0;
}
1630 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
1631 
/* TLS_ALPN_LIST getter: copy as many stored protocol name pointers as fit
 * in @p optval and store the number of bytes written in *optlen.
 *
 * @return 0 on success, -EINVAL if ALPN support is compiled out or *optlen
 *         is zero or not a multiple of the pointer size.
 */
static int tls_opt_alpn_list_get(struct tls_context *context,
				 void *optval, socklen_t *optlen)
{
	const char **src = context->options.alpn_list;
	const char **out = optval;
	size_t room;
	size_t n = 0;

	if (!ALPN_MAX_PROTOCOLS) {
		return -EINVAL;
	}

	if (*optlen == 0 || (*optlen % sizeof(const char *)) != 0) {
		return -EINVAL;
	}

	room = *optlen / sizeof(const char *);
	while (n < room && src[n] != NULL) {
		out[n] = src[n];
		n++;
	}

	*optlen = n * sizeof(const char *);

	return 0;
}
1660 
/* TLS_SESSION_CACHE setter: enable session caching when the int in
 * @p optval equals TLS_SESSION_CACHE_ENABLED; any other value disables it.
 *
 * @return 0 on success, -EINVAL if @p optval is NULL or @p optlen is not
 *         sizeof(int).
 */
static int tls_opt_session_cache_set(struct tls_context *context,
				     const void *optval, socklen_t optlen)
{
	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	context->options.cache_enabled =
		(*(const int *)optval == TLS_SESSION_CACHE_ENABLED);

	return 0;
}
1678 
/* TLS_SESSION_CACHE getter: report TLS_SESSION_CACHE_ENABLED or
 * TLS_SESSION_CACHE_DISABLED as an int.
 *
 * @return 0 on success, -EINVAL if *optlen is not sizeof(int).
 */
static int tls_opt_session_cache_get(struct tls_context *context,
				     void *optval, socklen_t *optlen)
{
	if (*optlen != sizeof(int)) {
		return -EINVAL;
	}

	if (context->options.cache_enabled) {
		*(int *)optval = TLS_SESSION_CACHE_ENABLED;
	} else {
		*(int *)optval = TLS_SESSION_CACHE_DISABLED;
	}

	return 0;
}
1694 
/* TLS_SESSION_CACHE_PURGE setter: purge cached TLS sessions.
 *
 * The option value is ignored. tls_session_purge() takes no context, so
 * this presumably purges the shared cache rather than only this socket's
 * entries.
 *
 * @return Always 0.
 */
static int tls_opt_session_cache_purge_set(struct tls_context *context,
					   const void *optval, socklen_t optlen)
{
	ARG_UNUSED(optlen);
	ARG_UNUSED(optval);
	ARG_UNUSED(context);

	tls_session_purge();
	return 0;
}
1706 
/* TLS_PEER_VERIFY setter: store the requested mbedTLS verification mode.
 * tls_mbedtls_init() applies it when it is not -1.
 *
 * @return 0 on success, -EINVAL if @p optval is NULL, @p optlen is not
 *         sizeof(int), or the value is not one of MBEDTLS_SSL_VERIFY_NONE,
 *         _OPTIONAL or _REQUIRED.
 */
static int tls_opt_peer_verify_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int level;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	level = *(const int *)optval;

	switch (level) {
	case MBEDTLS_SSL_VERIFY_NONE:
	case MBEDTLS_SSL_VERIFY_OPTIONAL:
	case MBEDTLS_SSL_VERIFY_REQUIRED:
		context->options.verify_level = level;
		return 0;
	default:
		return -EINVAL;
	}
}
1732 
/* TLS_CERT_NOCOPY setter: store whether certificates may be used without
 * copying (TLS_CERT_NOCOPY_NONE or TLS_CERT_NOCOPY_OPTIONAL).
 *
 * @return 0 on success, -EINVAL if @p optval is NULL, @p optlen is not
 *         sizeof(int), or the value is not one of the two accepted
 *         constants.
 */
static int tls_opt_cert_nocopy_set(struct tls_context *context,
				   const void *optval, socklen_t optlen)
{
	int mode;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	mode = *(const int *)optval;

	if (mode == TLS_CERT_NOCOPY_NONE || mode == TLS_CERT_NOCOPY_OPTIONAL) {
		context->options.cert_nocopy = mode;
		return 0;
	}

	return -EINVAL;
}
1757 
/* TLS_DTLS_ROLE setter: store whether the socket acts as DTLS client or
 * server (MBEDTLS_SSL_IS_CLIENT / MBEDTLS_SSL_IS_SERVER).
 *
 * @return 0 on success, -EINVAL if @p optval is NULL, @p optlen is not
 *         sizeof(int), or the value is not a valid role.
 */
static int tls_opt_dtls_role_set(struct tls_context *context,
				 const void *optval, socklen_t optlen)
{
	int requested;

	if (optval == NULL || optlen != sizeof(int)) {
		return -EINVAL;
	}

	requested = *(const int *)optval;

	switch (requested) {
	case MBEDTLS_SSL_IS_CLIENT:
	case MBEDTLS_SSL_IS_SERVER:
		context->options.role = requested;
		return 0;
	default:
		return -EINVAL;
	}
}
1781 
protocol_check(int family,int type,int * proto)1782 static int protocol_check(int family, int type, int *proto)
1783 {
1784 	if (family != AF_INET && family != AF_INET6) {
1785 		return -EAFNOSUPPORT;
1786 	}
1787 
1788 	if (*proto >= IPPROTO_TLS_1_0 && *proto <= IPPROTO_TLS_1_2) {
1789 		if (type != SOCK_STREAM) {
1790 			return -EPROTOTYPE;
1791 		}
1792 
1793 		*proto = IPPROTO_TCP;
1794 	} else if (*proto >= IPPROTO_DTLS_1_0 && *proto <= IPPROTO_DTLS_1_2) {
1795 		if (!IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS)) {
1796 			return -EPROTONOSUPPORT;
1797 		}
1798 
1799 		if (type != SOCK_DGRAM) {
1800 			return -EPROTOTYPE;
1801 		}
1802 
1803 		*proto = IPPROTO_UDP;
1804 	} else {
1805 		return -EPROTONOSUPPORT;
1806 	}
1807 
1808 	return 0;
1809 }
1810 
ztls_socket(int family,int type,int proto)1811 static int ztls_socket(int family, int type, int proto)
1812 {
1813 	enum net_ip_protocol_secure tls_proto = proto;
1814 	int fd = z_reserve_fd();
1815 	int sock = -1;
1816 	int ret;
1817 	struct tls_context *ctx;
1818 
1819 	if (fd < 0) {
1820 		return -1;
1821 	}
1822 
1823 	ret = protocol_check(family, type, &proto);
1824 	if (ret < 0) {
1825 		errno = -ret;
1826 		goto free_fd;
1827 	}
1828 
1829 	ctx = tls_alloc();
1830 	if (ctx == NULL) {
1831 		errno = ENOMEM;
1832 		goto free_fd;
1833 	}
1834 
1835 	sock = zsock_socket(family, type, proto);
1836 	if (sock < 0) {
1837 		goto release_tls;
1838 	}
1839 
1840 	ctx->tls_version = tls_proto;
1841 	ctx->type = (proto == IPPROTO_TCP) ? SOCK_STREAM : SOCK_DGRAM;
1842 	ctx->sock = sock;
1843 
1844 	z_finalize_fd(
1845 		fd, ctx, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable);
1846 
1847 	return fd;
1848 
1849 release_tls:
1850 	(void)tls_release(ctx);
1851 
1852 free_fd:
1853 	z_free_fd(fd);
1854 
1855 	return -1;
1856 }
1857 
ztls_close_ctx(struct tls_context * ctx)1858 int ztls_close_ctx(struct tls_context *ctx)
1859 {
1860 	int ret, err = 0;
1861 
1862 	/* Try to send close notification. */
1863 	ctx->flags = 0;
1864 
1865 	(void)mbedtls_ssl_close_notify(&ctx->ssl);
1866 
1867 	err = tls_release(ctx);
1868 	ret = zsock_close(ctx->sock);
1869 
1870 	/* In case close fails, we propagate errno value set by close.
1871 	 * In case close succeeds, but tls_release fails, set errno
1872 	 * according to tls_release return value.
1873 	 */
1874 	if (ret == 0 && err < 0) {
1875 		errno = -err;
1876 		ret = -1;
1877 	}
1878 
1879 	return ret;
1880 }
1881 
/* Connect a TLS/DTLS socket.
 *
 * For TLS (SOCK_STREAM), the underlying socket is connected in blocking
 * mode and the handshake (client role) is then run to completion. The
 * caller's O_NONBLOCK setting is restored afterwards, on both success and
 * failure. For DTLS, only the peer address is stored; the handshake
 * happens on the first send.
 *
 * @return 0 on success, or -1 with errno set (the errno from zsock_connect()
 *         on connect failure, EIO if the socket flags cannot be read).
 */
int ztls_connect_ctx(struct tls_context *ctx, const struct sockaddr *addr,
		     socklen_t addrlen)
{
	int ret;
	int sock_flags;
	int connect_errno;

	sock_flags = zsock_fcntl(ctx->sock, F_GETFL, 0);
	if (sock_flags < 0) {
		/* Report through errno like every other failure path. */
		ret = -EIO;
		goto error;
	}

	/* Connect in blocking mode. */
	if (sock_flags & O_NONBLOCK) {
		(void)zsock_fcntl(ctx->sock, F_SETFL,
				  sock_flags & ~O_NONBLOCK);
	}

	ret = zsock_connect(ctx->sock, addr, addrlen);
	/* Save errno before the fcntl() call below can overwrite it. */
	connect_errno = errno;

	/* Restore the caller's flags whether or not connect succeeded. */
	if (sock_flags & O_NONBLOCK) {
		(void)zsock_fcntl(ctx->sock, F_SETFL, sock_flags);
	}

	if (ret < 0) {
		errno = connect_errno;
		return ret;
	}

	if (ctx->type == SOCK_STREAM) {
		/* Do the handshake for TLS, not DTLS. */
		ret = tls_mbedtls_init(ctx, false);
		if (ret < 0) {
			goto error;
		}

		/* Do not use any socket flags during the handshake. */
		ctx->flags = 0;

		tls_session_restore(ctx, addr, addrlen);

		/* TODO For simplicity, TLS handshake blocks the socket
		 * even for non-blocking socket.
		 */
		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
		if (ret < 0) {
			goto error;
		}

		tls_session_store(ctx, addr, addrlen);
	} else {
#if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
		/* Just store the address. */
		dtls_peer_address_set(ctx, addr, addrlen);
#else
		ret = -ENOTSUP;
		goto error;
#endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
	}

	return 0;

error:
	errno = -ret;
	return -1;
}
1944 
/* Accept a connection on a listening TLS socket and complete the TLS
 * handshake (server role) on it.
 *
 * parent->lock is expected to be held on entry: it is released while
 * blocking in zsock_accept() and taken again before continuing.
 *
 * @param parent  Listening TLS context; the child context is cloned from it.
 * @param addr    Passed to zsock_accept() to receive the peer address.
 * @param addrlen Passed to zsock_accept() (in/out address length).
 *
 * @return The new file descriptor on success, or -1 with errno set.
 */
int ztls_accept_ctx(struct tls_context *parent, struct sockaddr *addr,
		    socklen_t *addrlen)
{
	struct tls_context *child = NULL;
	/* NOTE(review): err is only consumed by __ASSERT; with assertions
	 * compiled out it may be reported as set-but-unused.
	 */
	int ret, err, fd, sock;

	fd = z_reserve_fd();
	if (fd < 0) {
		return -1;
	}

	/* Drop the parent lock while blocking in accept. */
	k_mutex_unlock(parent->lock);
	sock = zsock_accept(parent->sock, addr, addrlen);
	k_mutex_lock(parent->lock, K_FOREVER);
	if (sock < 0) {
		ret = -errno;
		goto error;
	}

	child = tls_clone(parent);
	if (child == NULL) {
		ret = -ENOMEM;
		goto error;
	}

	/* Bind the fd to the child before the handshake; on a later failure
	 * the error path releases the child and frees the fd again.
	 */
	z_finalize_fd(
		fd, child, (const struct fd_op_vtable *)&tls_sock_fd_op_vtable);

	child->sock = sock;

	ret = tls_mbedtls_init(child, true);
	if (ret < 0) {
		goto error;
	}

	/* Do not use any socket flags during the handshake. */
	child->flags = 0;

	/* TODO For simplicity, TLS handshake blocks the socket even for
	 * non-blocking socket.
	 */
	ret = tls_mbedtls_handshake(child, K_FOREVER);
	if (ret < 0) {
		goto error;
	}

	return fd;

error:
	/* ret holds a negative errno value here. */
	if (child != NULL) {
		err = tls_release(child);
		__ASSERT(err == 0, "TLS context release failed");
	}

	if (sock >= 0) {
		err = zsock_close(sock);
		__ASSERT(err == 0, "Child socket close failed");
	}

	z_free_fd(fd);

	errno = -ret;
	return -1;
}
2010 
send_tls(struct tls_context * ctx,const void * buf,size_t len,int flags)2011 static ssize_t send_tls(struct tls_context *ctx, const void *buf,
2012 			size_t len, int flags)
2013 {
2014 	const bool is_block = is_blocking(ctx->sock, flags);
2015 	k_timeout_t timeout;
2016 	uint64_t end;
2017 	int ret;
2018 
2019 	if (!is_block) {
2020 		timeout = K_NO_WAIT;
2021 	} else {
2022 		timeout = ctx->options.timeout_tx;
2023 	}
2024 
2025 	end = sys_clock_timeout_end_calc(timeout);
2026 
2027 	do {
2028 		ret = mbedtls_ssl_write(&ctx->ssl, buf, len);
2029 		if (ret >= 0) {
2030 			return ret;
2031 		}
2032 
2033 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2034 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2035 		    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2036 		    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2037 			int timeout_ms;
2038 
2039 			if (!is_block) {
2040 				errno = EAGAIN;
2041 				break;
2042 			}
2043 
2044 			/* Blocking timeout. */
2045 			timeout_recalc(end, &timeout);
2046 			if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2047 				errno = EAGAIN;
2048 				break;
2049 			}
2050 
2051 			/* Block. */
2052 			timeout_ms = timeout_to_ms(&timeout);
2053 			ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2054 			if (ret != 0) {
2055 				/* Retry. */
2056 				break;
2057 			}
2058 
2059 		} else {
2060 			(void)tls_mbedtls_reset(ctx);
2061 			errno = EIO;
2062 			break;
2063 		}
2064 	} while (true);
2065 
2066 	return -1;
2067 }
2068 
2069 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
sendto_dtls_client(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2070 static ssize_t sendto_dtls_client(struct tls_context *ctx, const void *buf,
2071 				  size_t len, int flags,
2072 				  const struct sockaddr *dest_addr,
2073 				  socklen_t addrlen)
2074 {
2075 	int ret;
2076 
2077 	if (!dest_addr) {
2078 		/* No address provided, check if we have stored one,
2079 		 * otherwise return error.
2080 		 */
2081 		if (ctx->dtls_peer_addrlen == 0) {
2082 			ret = -EDESTADDRREQ;
2083 			goto error;
2084 		}
2085 	} else if (ctx->dtls_peer_addrlen == 0) {
2086 		/* Address provided and no peer address stored. */
2087 		dtls_peer_address_set(ctx, dest_addr, addrlen);
2088 	} else if (!dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) {
2089 		/* Address provided but it does not match stored one */
2090 		ret = -EISCONN;
2091 		goto error;
2092 	}
2093 
2094 	if (!ctx->is_initialized) {
2095 		ret = tls_mbedtls_init(ctx, false);
2096 		if (ret < 0) {
2097 			goto error;
2098 		}
2099 	}
2100 
2101 	if (!is_handshake_complete(ctx)) {
2102 		tls_session_restore(ctx, &ctx->dtls_peer_addr,
2103 				    ctx->dtls_peer_addrlen);
2104 
2105 		/* TODO For simplicity, TLS handshake blocks the socket even for
2106 		 * non-blocking socket.
2107 		 */
2108 		ret = tls_mbedtls_handshake(ctx, K_FOREVER);
2109 		if (ret < 0) {
2110 			goto error;
2111 		}
2112 
2113 		tls_session_store(ctx, &ctx->dtls_peer_addr,
2114 				  ctx->dtls_peer_addrlen);
2115 	}
2116 
2117 	return send_tls(ctx, buf, len, flags);
2118 
2119 error:
2120 	errno = -ret;
2121 	return -1;
2122 }
2123 
sendto_dtls_server(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2124 static ssize_t sendto_dtls_server(struct tls_context *ctx, const void *buf,
2125 				  size_t len, int flags,
2126 				  const struct sockaddr *dest_addr,
2127 				  socklen_t addrlen)
2128 {
2129 	/* For DTLS server, require to have established DTLS connection
2130 	 * in order to send data.
2131 	 */
2132 	if (!is_handshake_complete(ctx)) {
2133 		errno = ENOTCONN;
2134 		return -1;
2135 	}
2136 
2137 	/* Verify we are sending to a peer that we have connection with. */
2138 	if (dest_addr &&
2139 	    !dtls_is_peer_addr_valid(ctx, dest_addr, addrlen) != 0) {
2140 		errno = EISCONN;
2141 		return -1;
2142 	}
2143 
2144 	return send_tls(ctx, buf, len, flags);
2145 }
2146 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2147 
ztls_sendto_ctx(struct tls_context * ctx,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2148 ssize_t ztls_sendto_ctx(struct tls_context *ctx, const void *buf, size_t len,
2149 			int flags, const struct sockaddr *dest_addr,
2150 			socklen_t addrlen)
2151 {
2152 	ctx->flags = flags;
2153 
2154 	/* TLS */
2155 	if (ctx->type == SOCK_STREAM) {
2156 		return send_tls(ctx, buf, len, flags);
2157 	}
2158 
2159 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2160 	/* DTLS */
2161 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2162 		return sendto_dtls_server(ctx, buf, len, flags,
2163 					  dest_addr, addrlen);
2164 	}
2165 
2166 	return sendto_dtls_client(ctx, buf, len, flags, dest_addr, addrlen);
2167 #else
2168 	errno = ENOTSUP;
2169 	return -1;
2170 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2171 }
2172 
ztls_sendmsg_ctx(struct tls_context * ctx,const struct msghdr * msg,int flags)2173 ssize_t ztls_sendmsg_ctx(struct tls_context *ctx, const struct msghdr *msg,
2174 			 int flags)
2175 {
2176 	ssize_t len;
2177 	ssize_t ret;
2178 	int i;
2179 
2180 	if (IS_ENABLED(CONFIG_NET_SOCKETS_ENABLE_DTLS) &&
2181 	    ctx->type == SOCK_DGRAM) {
2182 		/*
2183 		 * Current mbedTLS API (i.e. mbedtls_ssl_write()) allows only to send a single
2184 		 * contiguous buffer. This means that gather write using sendmsg() can only be
2185 		 * handled correctly if there is a single non-empty buffer in msg->msg_iov.
2186 		 */
2187 		if (msghdr_non_empty_iov_count(msg) > 1) {
2188 			errno = EMSGSIZE;
2189 			return -1;
2190 		}
2191 	}
2192 
2193 	len = 0;
2194 	if (msg) {
2195 		for (i = 0; i < msg->msg_iovlen; i++) {
2196 			struct iovec *vec = msg->msg_iov + i;
2197 			size_t sent = 0;
2198 
2199 			if (vec->iov_len == 0) {
2200 				continue;
2201 			}
2202 
2203 			while (sent < vec->iov_len) {
2204 				uint8_t *ptr = (uint8_t *)vec->iov_base + sent;
2205 
2206 				ret = ztls_sendto_ctx(ctx, ptr,
2207 					    vec->iov_len - sent, flags,
2208 					    msg->msg_name, msg->msg_namelen);
2209 				if (ret < 0) {
2210 					return ret;
2211 				}
2212 				sent += ret;
2213 			}
2214 			len += sent;
2215 		}
2216 	}
2217 
2218 	return len;
2219 }
2220 
recv_tls(struct tls_context * ctx,void * buf,size_t max_len,int flags)2221 static ssize_t recv_tls(struct tls_context *ctx, void *buf,
2222 			size_t max_len, int flags)
2223 {
2224 	size_t recv_len = 0;
2225 	const bool waitall = flags & ZSOCK_MSG_WAITALL;
2226 	const bool is_block = is_blocking(ctx->sock, flags);
2227 	k_timeout_t timeout;
2228 	uint64_t end;
2229 	int ret;
2230 
2231 	if (!is_block) {
2232 		timeout = K_NO_WAIT;
2233 	} else {
2234 		timeout = ctx->options.timeout_rx;
2235 	}
2236 
2237 	end = sys_clock_timeout_end_calc(timeout);
2238 
2239 	do {
2240 		size_t read_len = max_len - recv_len;
2241 
2242 		ret = mbedtls_ssl_read(&ctx->ssl, (uint8_t *)buf + recv_len,
2243 				       read_len);
2244 		if (ret < 0) {
2245 			if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
2246 				/* Peer notified that it's closing the
2247 				 * connection.
2248 				 */
2249 				break;
2250 			}
2251 
2252 			if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) {
2253 				/* Client reconnect on the same socket is not
2254 				 * supported. See mbedtls_ssl_read API
2255 				 * documentation.
2256 				 */
2257 				break;
2258 			}
2259 
2260 			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2261 			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2262 			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2263 			    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2264 				int timeout_ms;
2265 
2266 				if (!is_block) {
2267 					ret = -EAGAIN;
2268 					goto err;
2269 				}
2270 
2271 				/* Blocking timeout. */
2272 				timeout_recalc(end, &timeout);
2273 				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2274 					ret = -EAGAIN;
2275 					goto err;
2276 				}
2277 
2278 				timeout_ms = timeout_to_ms(&timeout);
2279 
2280 				/* Block. */
2281 				k_mutex_unlock(ctx->lock);
2282 				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2283 				k_mutex_lock(ctx->lock, K_FOREVER);
2284 
2285 				if (ret == 0) {
2286 					/* Retry. */
2287 					continue;
2288 				}
2289 			} else {
2290 				ret = -EIO;
2291 			}
2292 
2293 err:
2294 			errno = -ret;
2295 			return -1;
2296 		}
2297 
2298 		if (ret == 0) {
2299 			break;
2300 		}
2301 
2302 		recv_len += ret;
2303 	} while ((recv_len == 0) || (waitall && (recv_len < max_len)));
2304 
2305 	return recv_len;
2306 }
2307 
2308 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
recvfrom_dtls_common(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2309 static ssize_t recvfrom_dtls_common(struct tls_context *ctx, void *buf,
2310 				    size_t max_len, int flags,
2311 				    struct sockaddr *src_addr,
2312 				    socklen_t *addrlen)
2313 {
2314 	int ret;
2315 	bool is_block = is_blocking(ctx->sock, flags);
2316 	k_timeout_t timeout;
2317 	uint64_t end;
2318 
2319 	if (!is_block) {
2320 		timeout = K_NO_WAIT;
2321 	} else {
2322 		timeout = ctx->options.timeout_rx;
2323 	}
2324 
2325 	end = sys_clock_timeout_end_calc(timeout);
2326 
2327 	do {
2328 		size_t remaining;
2329 
2330 		ret = mbedtls_ssl_read(&ctx->ssl, buf, max_len);
2331 		if (ret < 0) {
2332 			if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2333 			    ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
2334 			    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
2335 			    ret ==  MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
2336 				int timeout_dtls, timeout_sock, timeout_ms;
2337 
2338 				if (!is_block) {
2339 					return ret;
2340 				}
2341 
2342 				/* Blocking timeout. */
2343 				timeout_recalc(end, &timeout);
2344 				if (K_TIMEOUT_EQ(timeout, K_NO_WAIT)) {
2345 					return ret;
2346 				}
2347 
2348 				timeout_dtls = dtls_get_remaining_timeout(ctx);
2349 				timeout_sock = timeout_to_ms(&timeout);
2350 				if (timeout_dtls == SYS_FOREVER_MS ||
2351 				    timeout_sock == SYS_FOREVER_MS) {
2352 					timeout_ms = MAX(timeout_dtls, timeout_sock);
2353 				} else {
2354 					timeout_ms = MIN(timeout_dtls, timeout_sock);
2355 				}
2356 
2357 				/* Block. */
2358 				k_mutex_unlock(ctx->lock);
2359 				ret = wait_for_reason(ctx->sock, timeout_ms, ret);
2360 				k_mutex_lock(ctx->lock, K_FOREVER);
2361 
2362 				if (ret == 0) {
2363 					/* Retry. */
2364 					continue;
2365 				} else {
2366 					return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
2367 				}
2368 			} else {
2369 				return ret;
2370 			}
2371 		}
2372 
2373 		if (src_addr && addrlen) {
2374 			dtls_peer_address_get(ctx, src_addr, addrlen);
2375 		}
2376 
2377 		/* mbedtls_ssl_get_bytes_avail() indicate the data length
2378 		 * remaining in the current datagram.
2379 		 */
2380 		remaining = mbedtls_ssl_get_bytes_avail(&ctx->ssl);
2381 
2382 		/* No more data in the datagram, or dummy read. */
2383 		if ((remaining == 0) || (max_len == 0)) {
2384 			return ret;
2385 		}
2386 
2387 		if (flags & ZSOCK_MSG_TRUNC) {
2388 			ret += remaining;
2389 		}
2390 
2391 		for (int i = 0; i < remaining; i++) {
2392 			uint8_t byte;
2393 			int err;
2394 
2395 			err = mbedtls_ssl_read(&ctx->ssl, &byte, sizeof(byte));
2396 			if (err <= 0) {
2397 				NET_ERR("Error while flushing the rest of the"
2398 					" datagram, err %d", err);
2399 				ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
2400 				break;
2401 			}
2402 		}
2403 
2404 		break;
2405 	} while (true);
2406 
2407 
2408 	return ret;
2409 }
2410 
recvfrom_dtls_client(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2411 static ssize_t recvfrom_dtls_client(struct tls_context *ctx, void *buf,
2412 				    size_t max_len, int flags,
2413 				    struct sockaddr *src_addr,
2414 				    socklen_t *addrlen)
2415 {
2416 	int ret;
2417 
2418 	if (!is_handshake_complete(ctx)) {
2419 		ret = -ENOTCONN;
2420 		goto error;
2421 	}
2422 
2423 	ret = recvfrom_dtls_common(ctx, buf, max_len, flags, src_addr, addrlen);
2424 	if (ret >= 0) {
2425 		return ret;
2426 	}
2427 
2428 	switch (ret) {
2429 	case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2430 		/* Peer notified that it's closing the connection. */
2431 		ret = tls_mbedtls_reset(ctx);
2432 		if (ret == 0) {
2433 			ret = -ENOTCONN;
2434 		} else {
2435 			ret = -ENOMEM;
2436 		}
2437 		break;
2438 
2439 	case MBEDTLS_ERR_SSL_TIMEOUT:
2440 		(void)mbedtls_ssl_close_notify(&ctx->ssl);
2441 		ret = -ETIMEDOUT;
2442 		break;
2443 
2444 	case MBEDTLS_ERR_SSL_WANT_READ:
2445 	case MBEDTLS_ERR_SSL_WANT_WRITE:
2446 	case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2447 	case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2448 		ret = -EAGAIN;
2449 		break;
2450 
2451 	default:
2452 		ret = -EIO;
2453 		break;
2454 	}
2455 
2456 error:
2457 	errno = -ret;
2458 	return -1;
2459 }
2460 
recvfrom_dtls_server(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2461 static ssize_t recvfrom_dtls_server(struct tls_context *ctx, void *buf,
2462 				    size_t max_len, int flags,
2463 				    struct sockaddr *src_addr,
2464 				    socklen_t *addrlen)
2465 {
2466 	int ret;
2467 	bool repeat;
2468 	k_timeout_t timeout;
2469 
2470 	if (!ctx->is_initialized) {
2471 		ret = tls_mbedtls_init(ctx, true);
2472 		if (ret < 0) {
2473 			goto error;
2474 		}
2475 	}
2476 
2477 	if (is_blocking(ctx->sock, flags)) {
2478 		timeout = ctx->options.timeout_rx;
2479 	} else {
2480 		timeout = K_NO_WAIT;
2481 	}
2482 
2483 	/* Loop to enable DTLS reconnection for servers without closing
2484 	 * a socket.
2485 	 */
2486 	do {
2487 		repeat = false;
2488 
2489 		if (!is_handshake_complete(ctx)) {
2490 			ret = tls_mbedtls_handshake(ctx, timeout);
2491 			if (ret < 0) {
2492 				/* In case of EAGAIN, just exit. */
2493 				if (ret == -EAGAIN) {
2494 					break;
2495 				}
2496 
2497 				ret = tls_mbedtls_reset(ctx);
2498 				if (ret == 0) {
2499 					repeat = true;
2500 				} else {
2501 					ret = -ENOMEM;
2502 				}
2503 
2504 				continue;
2505 			}
2506 		}
2507 
2508 		ret = recvfrom_dtls_common(ctx, buf, max_len, flags,
2509 					   src_addr, addrlen);
2510 		if (ret >= 0) {
2511 			return ret;
2512 		}
2513 
2514 		switch (ret) {
2515 		case MBEDTLS_ERR_SSL_TIMEOUT:
2516 			(void)mbedtls_ssl_close_notify(&ctx->ssl);
2517 			__fallthrough;
2518 			/* fallthrough */
2519 
2520 		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
2521 		case MBEDTLS_ERR_SSL_CLIENT_RECONNECT:
2522 			ret = tls_mbedtls_reset(ctx);
2523 			if (ret == 0) {
2524 				repeat = true;
2525 			} else {
2526 				ret = -ENOMEM;
2527 			}
2528 			break;
2529 
2530 		case MBEDTLS_ERR_SSL_WANT_READ:
2531 		case MBEDTLS_ERR_SSL_WANT_WRITE:
2532 		case MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS:
2533 		case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
2534 			ret = -EAGAIN;
2535 			break;
2536 
2537 		default:
2538 			ret = -EIO;
2539 			break;
2540 		}
2541 	} while (repeat);
2542 
2543 error:
2544 	errno = -ret;
2545 	return -1;
2546 }
2547 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2548 
ztls_recvfrom_ctx(struct tls_context * ctx,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2549 ssize_t ztls_recvfrom_ctx(struct tls_context *ctx, void *buf, size_t max_len,
2550 			  int flags, struct sockaddr *src_addr,
2551 			  socklen_t *addrlen)
2552 {
2553 	if (flags & ZSOCK_MSG_PEEK) {
2554 		/* TODO mbedTLS does not support 'peeking' This could be
2555 		 * bypassed by having intermediate buffer for peeking
2556 		 */
2557 		errno = ENOTSUP;
2558 		return -1;
2559 	}
2560 
2561 	ctx->flags = flags;
2562 
2563 	/* TLS */
2564 	if (ctx->type == SOCK_STREAM) {
2565 		return recv_tls(ctx, buf, max_len, flags);
2566 	}
2567 
2568 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
2569 	/* DTLS */
2570 	if (ctx->options.role == MBEDTLS_SSL_IS_SERVER) {
2571 		return recvfrom_dtls_server(ctx, buf, max_len, flags,
2572 					    src_addr, addrlen);
2573 	}
2574 
2575 	return recvfrom_dtls_client(ctx, buf, max_len, flags,
2576 				    src_addr, addrlen);
2577 #else
2578 	errno = ENOTSUP;
2579 	return -1;
2580 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
2581 }
2582 
ztls_poll_prepare_pollin(struct tls_context * ctx)2583 static int ztls_poll_prepare_pollin(struct tls_context *ctx)
2584 {
2585 	/* If there already is mbedTLS data to read, there is no
2586 	 * need to set the k_poll_event object. Return EALREADY
2587 	 * so we won't block in the k_poll.
2588 	 */
2589 	if (!ctx->is_listening) {
2590 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
2591 			return -EALREADY;
2592 		}
2593 	}
2594 
2595 	return 0;
2596 }
2597 
ztls_poll_prepare_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev,struct k_poll_event * pev_end)2598 static int ztls_poll_prepare_ctx(struct tls_context *ctx,
2599 				 struct zsock_pollfd *pfd,
2600 				 struct k_poll_event **pev,
2601 				 struct k_poll_event *pev_end)
2602 {
2603 	const struct fd_op_vtable *vtable;
2604 	struct k_mutex *lock;
2605 	void *obj;
2606 	int ret;
2607 	short events = pfd->events;
2608 
2609 	/* DTLS client should wait for the handshake to complete before
2610 	 * it actually starts to poll for data.
2611 	 */
2612 	if ((pfd->events & ZSOCK_POLLIN) && (ctx->type == SOCK_DGRAM) &&
2613 	    (ctx->options.role == MBEDTLS_SSL_IS_CLIENT) &&
2614 	    !is_handshake_complete(ctx)) {
2615 		(*pev)->obj = &ctx->tls_established;
2616 		(*pev)->type = K_POLL_TYPE_SEM_AVAILABLE;
2617 		(*pev)->mode = K_POLL_MODE_NOTIFY_ONLY;
2618 		(*pev)->state = K_POLL_STATE_NOT_READY;
2619 		(*pev)++;
2620 
2621 		/* Since k_poll_event is configured by the TLS layer in this
2622 		 * case, do not forward ZSOCK_POLLIN to the underlying socket.
2623 		 */
2624 		pfd->events &= ~ZSOCK_POLLIN;
2625 	}
2626 
2627 	obj = z_get_fd_obj_and_vtable(
2628 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
2629 	if (obj == NULL) {
2630 		ret = -EBADF;
2631 		goto exit;
2632 	}
2633 
2634 	(void)k_mutex_lock(lock, K_FOREVER);
2635 
2636 	ret = z_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
2637 				   pfd, pev, pev_end);
2638 	if (ret != 0) {
2639 		goto exit;
2640 	}
2641 
2642 	if (pfd->events & ZSOCK_POLLIN) {
2643 		ret = ztls_poll_prepare_pollin(ctx);
2644 	}
2645 
2646 exit:
2647 	/* Restore original events. */
2648 	pfd->events = events;
2649 
2650 	k_mutex_unlock(lock);
2651 
2652 	return ret;
2653 }
2654 
ztls_socket_data_check(struct tls_context * ctx)2655 static int ztls_socket_data_check(struct tls_context *ctx)
2656 {
2657 	int ret;
2658 
2659 	ctx->flags = ZSOCK_MSG_DONTWAIT;
2660 
2661 	ret = mbedtls_ssl_read(&ctx->ssl, NULL, 0);
2662 	if (ret < 0) {
2663 		if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
2664 			/* Don't reset the context for STREAM socket - the
2665 			 * application needs to reopen the socket anyway, and
2666 			 * resetting the context would result in an error instead
2667 			 * of 0 in a consecutive recv() call.
2668 			 */
2669 			if (ctx->type == SOCK_DGRAM) {
2670 				ret = tls_mbedtls_reset(ctx);
2671 				if (ret != 0) {
2672 					return -ENOMEM;
2673 				}
2674 			}
2675 
2676 			return -ENOTCONN;
2677 		}
2678 
2679 		if (ret == MBEDTLS_ERR_SSL_WANT_READ ||
2680 		    ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
2681 			return 0;
2682 		}
2683 
2684 		/* Treat any other error as fatal. */
2685 		return -EIO;
2686 	}
2687 
2688 	return mbedtls_ssl_get_bytes_avail(&ctx->ssl);
2689 }
2690 
ztls_poll_update_pollin(int fd,struct tls_context * ctx,struct zsock_pollfd * pfd)2691 static int ztls_poll_update_pollin(int fd, struct tls_context *ctx,
2692 				   struct zsock_pollfd *pfd)
2693 {
2694 	int ret;
2695 
2696 	if (!ctx->is_listening) {
2697 		/* Already had TLS data to read on socket. */
2698 		if (mbedtls_ssl_get_bytes_avail(&ctx->ssl) > 0) {
2699 			pfd->revents |= ZSOCK_POLLIN;
2700 			goto next;
2701 		}
2702 	}
2703 
2704 	if (!(pfd->revents & ZSOCK_POLLIN)) {
2705 		/* No new data on a socket. */
2706 		goto next;
2707 	}
2708 
2709 	if (ctx->is_listening) {
2710 		goto next;
2711 	}
2712 
2713 	ret = ztls_socket_data_check(ctx);
2714 	if (ret == -ENOTCONN || (pfd->revents & ZSOCK_POLLHUP)) {
2715 		/* Datagram does not return 0 on consecutive recv, but an error
2716 		 * code, hence clear POLLIN.
2717 		 */
2718 		if (ctx->type == SOCK_DGRAM) {
2719 			pfd->revents &= ~ZSOCK_POLLIN;
2720 		}
2721 		pfd->revents |= ZSOCK_POLLHUP;
2722 		goto next;
2723 	} else if (ret < 0) {
2724 		pfd->revents |= ZSOCK_POLLERR;
2725 		goto next;
2726 	} else if (ret == 0) {
2727 		goto again;
2728 	}
2729 
2730 next:
2731 	return 0;
2732 
2733 again:
2734 	/* Received encrypted data, but still not enough
2735 	 * to decrypt it and return data through socket,
2736 	 * ask for retry if no other events are set.
2737 	 */
2738 	pfd->revents &= ~ZSOCK_POLLIN;
2739 
2740 	return -EAGAIN;
2741 }
2742 
ztls_poll_update_ctx(struct tls_context * ctx,struct zsock_pollfd * pfd,struct k_poll_event ** pev)2743 static int ztls_poll_update_ctx(struct tls_context *ctx,
2744 				struct zsock_pollfd *pfd,
2745 				struct k_poll_event **pev)
2746 {
2747 	const struct fd_op_vtable *vtable;
2748 	struct k_mutex *lock;
2749 	void *obj;
2750 	int ret;
2751 	short events = pfd->events;
2752 
2753 	obj = z_get_fd_obj_and_vtable(
2754 		ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
2755 	if (obj == NULL) {
2756 		return -EBADF;
2757 	}
2758 
2759 	(void)k_mutex_lock(lock, K_FOREVER);
2760 
2761 	/* Check if the socket was waiting for the handshake to complete. */
2762 	if ((pfd->events & ZSOCK_POLLIN) &&
2763 	    ((*pev)->obj == &ctx->tls_established)) {
2764 		/* In case handshake is complete, reconfigure the k_poll_event
2765 		 * to monitor the underlying socket now.
2766 		 */
2767 		if ((*pev)->state != K_POLL_STATE_NOT_READY) {
2768 			ret = z_fdtable_call_ioctl(vtable, obj,
2769 						   ZFD_IOCTL_POLL_PREPARE,
2770 						   pfd, pev, *pev + 1);
2771 			if (ret != 0 && ret != -EALREADY) {
2772 				goto out;
2773 			}
2774 
2775 			/* Return -EAGAIN to signal to poll() that it should
2776 			 * make another iteration with the event reconfigured
2777 			 * above (if needed).
2778 			 */
2779 			ret = -EAGAIN;
2780 			goto out;
2781 		}
2782 
2783 		/* Handshake still not ready - skip ZSOCK_POLLIN verification
2784 		 * for the underlying socket.
2785 		 */
2786 		(*pev)++;
2787 		pfd->events &= ~ZSOCK_POLLIN;
2788 	}
2789 
2790 	ret = z_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_UPDATE,
2791 				   pfd, pev);
2792 	if (ret != 0) {
2793 		goto exit;
2794 	}
2795 
2796 	if (pfd->events & ZSOCK_POLLIN) {
2797 		ret = ztls_poll_update_pollin(pfd->fd, ctx, pfd);
2798 		if (ret == -EAGAIN && pfd->revents != 0) {
2799 			(*pev - 1)->state = K_POLL_STATE_NOT_READY;
2800 			goto exit;
2801 		}
2802 	}
2803 exit:
2804 	/* Restore original events. */
2805 	pfd->events = events;
2806 
2807 out:
2808 	k_mutex_unlock(lock);
2809 
2810 	return ret;
2811 }
2812 
2813 /* Return true if needed to retry rightoff or false otherwise. */
poll_offload_dtls_client_retry(struct tls_context * ctx,struct zsock_pollfd * pfd)2814 static bool poll_offload_dtls_client_retry(struct tls_context *ctx,
2815 					   struct zsock_pollfd *pfd)
2816 {
2817 	/* DTLS client should wait for the handshake to complete before it
2818 	 * reports that data is ready.
2819 	 */
2820 	if ((ctx->type != SOCK_DGRAM) ||
2821 	    (ctx->options.role != MBEDTLS_SSL_IS_CLIENT)) {
2822 		return false;
2823 	}
2824 
2825 	if (ctx->handshake_in_progress) {
2826 		/* Add some sleep to allow lower priority threads to proceed
2827 		 * with handshake.
2828 		 */
2829 		k_msleep(10);
2830 
2831 		pfd->revents &= ~ZSOCK_POLLIN;
2832 		return true;
2833 	} else if (!is_handshake_complete(ctx)) {
2834 		uint8_t byte;
2835 		int ret;
2836 
2837 		/* Handshake didn't start yet - just drop the incoming data -
2838 		 * it's the client who should initiate the handshake.
2839 		 */
2840 		ret = zsock_recv(ctx->sock, &byte, sizeof(byte),
2841 				 ZSOCK_MSG_DONTWAIT);
2842 		if (ret < 0) {
2843 			pfd->revents |= ZSOCK_POLLERR;
2844 		}
2845 
2846 		pfd->revents &= ~ZSOCK_POLLIN;
2847 		return true;
2848 	}
2849 
2850 	/* Handshake complete, just proceed. */
2851 	return false;
2852 }
2853 
ztls_poll_offload(struct zsock_pollfd * fds,int nfds,int timeout)2854 static int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int timeout)
2855 {
2856 	int fd_backup[CONFIG_NET_SOCKETS_POLL_MAX];
2857 	const struct fd_op_vtable *vtable;
2858 	void *ctx;
2859 	int ret = 0;
2860 	int result;
2861 	int i;
2862 	bool retry;
2863 	int remaining;
2864 	uint32_t entry = k_uptime_get_32();
2865 
2866 	/* Overwrite TLS file descriptors with underlying ones. */
2867 	for (i = 0; i < nfds; i++) {
2868 		fd_backup[i] = fds[i].fd;
2869 
2870 		ctx = z_get_fd_obj(fds[i].fd,
2871 				   (const struct fd_op_vtable *)
2872 						     &tls_sock_fd_op_vtable,
2873 				   0);
2874 		if (ctx == NULL) {
2875 			continue;
2876 		}
2877 
2878 		if (fds[i].events & ZSOCK_POLLIN) {
2879 			ret = ztls_poll_prepare_pollin(ctx);
2880 			/* In case data is already available in mbedtls,
2881 			 * do not wait in poll.
2882 			 */
2883 			if (ret == -EALREADY) {
2884 				timeout = 0;
2885 			}
2886 		}
2887 
2888 		fds[i].fd = ((struct tls_context *)ctx)->sock;
2889 	}
2890 
2891 	/* Get offloaded sockets vtable. */
2892 	ctx = z_get_fd_obj_and_vtable(fds[0].fd,
2893 				      (const struct fd_op_vtable **)&vtable,
2894 				      NULL);
2895 	if (ctx == NULL) {
2896 		errno = EINVAL;
2897 		goto exit;
2898 	}
2899 
2900 	remaining = timeout;
2901 
2902 	do {
2903 		for (i = 0; i < nfds; i++) {
2904 			fds[i].revents = 0;
2905 		}
2906 
2907 		ret = z_fdtable_call_ioctl(vtable, ctx, ZFD_IOCTL_POLL_OFFLOAD,
2908 					   fds, nfds, remaining);
2909 		if (ret < 0) {
2910 			goto exit;
2911 		}
2912 
2913 		retry = false;
2914 		ret = 0;
2915 
2916 		for (i = 0; i < nfds; i++) {
2917 			ctx = z_get_fd_obj(fd_backup[i],
2918 					   (const struct fd_op_vtable *)
2919 							&tls_sock_fd_op_vtable,
2920 					   0);
2921 			if (ctx != NULL) {
2922 				if (fds[i].events & ZSOCK_POLLIN) {
2923 					if (poll_offload_dtls_client_retry(
2924 							ctx, &fds[i])) {
2925 						retry = true;
2926 						continue;
2927 					}
2928 
2929 					result = ztls_poll_update_pollin(
2930 						    fd_backup[i], ctx, &fds[i]);
2931 					if (result == -EAGAIN) {
2932 						retry = true;
2933 					}
2934 				}
2935 			}
2936 
2937 			if (fds[i].revents != 0) {
2938 				ret++;
2939 			}
2940 		}
2941 
2942 		if (retry) {
2943 			if (ret > 0 || timeout == 0) {
2944 				goto exit;
2945 			}
2946 
2947 			if (timeout > 0) {
2948 				remaining = time_left(entry, timeout);
2949 				if (remaining <= 0) {
2950 					goto exit;
2951 				}
2952 			}
2953 		}
2954 	} while (retry);
2955 
2956 exit:
2957 	/* Restore original fds. */
2958 	for (i = 0; i < nfds; i++) {
2959 		fds[i].fd = fd_backup[i];
2960 	}
2961 
2962 	return ret;
2963 }
2964 
ztls_getsockopt_ctx(struct tls_context * ctx,int level,int optname,void * optval,socklen_t * optlen)2965 int ztls_getsockopt_ctx(struct tls_context *ctx, int level, int optname,
2966 			void *optval, socklen_t *optlen)
2967 {
2968 	int err;
2969 
2970 	if (!optval || !optlen) {
2971 		errno = EINVAL;
2972 		return -1;
2973 	}
2974 
2975 	if ((level == SOL_SOCKET) && (optname == SO_PROTOCOL)) {
2976 		/* Protocol type is overridden during socket creation. Its
2977 		 * value is restored here to return current value.
2978 		 */
2979 		err = sock_opt_protocol_get(ctx, optval, optlen);
2980 		if (err < 0) {
2981 			errno = -err;
2982 			return -1;
2983 		}
2984 		return err;
2985 	} else if (level != SOL_TLS) {
2986 		return zsock_getsockopt(ctx->sock, level, optname,
2987 					optval, optlen);
2988 	}
2989 
2990 	switch (optname) {
2991 	case TLS_SEC_TAG_LIST:
2992 		err =  tls_opt_sec_tag_list_get(ctx, optval, optlen);
2993 		break;
2994 
2995 	case TLS_CIPHERSUITE_LIST:
2996 		err = tls_opt_ciphersuite_list_get(ctx, optval, optlen);
2997 		break;
2998 
2999 	case TLS_CIPHERSUITE_USED:
3000 		err = tls_opt_ciphersuite_used_get(ctx, optval, optlen);
3001 		break;
3002 
3003 	case TLS_ALPN_LIST:
3004 		err = tls_opt_alpn_list_get(ctx, optval, optlen);
3005 		break;
3006 
3007 	case TLS_SESSION_CACHE:
3008 		err = tls_opt_session_cache_get(ctx, optval, optlen);
3009 		break;
3010 
3011 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3012 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3013 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3014 							 optlen, false);
3015 		break;
3016 
3017 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3018 		err = tls_opt_dtls_handshake_timeout_get(ctx, optval,
3019 							 optlen, true);
3020 		break;
3021 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3022 
3023 	default:
3024 		/* Unknown or write-only option. */
3025 		err = -ENOPROTOOPT;
3026 		break;
3027 	}
3028 
3029 	if (err < 0) {
3030 		errno = -err;
3031 		return -1;
3032 	}
3033 
3034 	return 0;
3035 }
3036 
set_timeout_opt(k_timeout_t * timeout,const void * optval,socklen_t optlen)3037 static int set_timeout_opt(k_timeout_t *timeout, const void *optval,
3038 			   socklen_t optlen)
3039 {
3040 	const struct zsock_timeval *tval = optval;
3041 
3042 	if (optlen != sizeof(struct zsock_timeval)) {
3043 		return -EINVAL;
3044 	}
3045 
3046 	if (tval->tv_sec == 0 && tval->tv_usec == 0) {
3047 		*timeout = K_FOREVER;
3048 	} else {
3049 		*timeout = K_USEC(tval->tv_sec * 1000000ULL + tval->tv_usec);
3050 	}
3051 
3052 	return 0;
3053 }
3054 
ztls_setsockopt_ctx(struct tls_context * ctx,int level,int optname,const void * optval,socklen_t optlen)3055 int ztls_setsockopt_ctx(struct tls_context *ctx, int level, int optname,
3056 			const void *optval, socklen_t optlen)
3057 {
3058 	int err;
3059 
3060 	/* Underlying socket is used in non-blocking mode, hence implement
3061 	 * timeout at the TLS socket level.
3062 	 */
3063 	if ((level == SOL_SOCKET) && (optname == SO_SNDTIMEO)) {
3064 		err = set_timeout_opt(&ctx->options.timeout_tx, optval, optlen);
3065 		goto out;
3066 	}
3067 
3068 	if ((level == SOL_SOCKET) && (optname == SO_RCVTIMEO)) {
3069 		err = set_timeout_opt(&ctx->options.timeout_rx, optval, optlen);
3070 		goto out;
3071 	}
3072 
3073 	if (level != SOL_TLS) {
3074 		return zsock_setsockopt(ctx->sock, level, optname,
3075 					optval, optlen);
3076 	}
3077 
3078 	switch (optname) {
3079 	case TLS_SEC_TAG_LIST:
3080 		err =  tls_opt_sec_tag_list_set(ctx, optval, optlen);
3081 		break;
3082 
3083 	case TLS_HOSTNAME:
3084 		err = tls_opt_hostname_set(ctx, optval, optlen);
3085 		break;
3086 
3087 	case TLS_CIPHERSUITE_LIST:
3088 		err = tls_opt_ciphersuite_list_set(ctx, optval, optlen);
3089 		break;
3090 
3091 	case TLS_PEER_VERIFY:
3092 		err = tls_opt_peer_verify_set(ctx, optval, optlen);
3093 		break;
3094 
3095 	case TLS_CERT_NOCOPY:
3096 		err = tls_opt_cert_nocopy_set(ctx, optval, optlen);
3097 		break;
3098 
3099 	case TLS_DTLS_ROLE:
3100 		err = tls_opt_dtls_role_set(ctx, optval, optlen);
3101 		break;
3102 
3103 	case TLS_ALPN_LIST:
3104 		err = tls_opt_alpn_list_set(ctx, optval, optlen);
3105 		break;
3106 
3107 	case TLS_SESSION_CACHE:
3108 		err = tls_opt_session_cache_set(ctx, optval, optlen);
3109 		break;
3110 
3111 	case TLS_SESSION_CACHE_PURGE:
3112 		err = tls_opt_session_cache_purge_set(ctx, optval, optlen);
3113 		break;
3114 
3115 #if defined(CONFIG_NET_SOCKETS_ENABLE_DTLS)
3116 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MIN:
3117 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3118 							 optlen, false);
3119 		break;
3120 
3121 	case TLS_DTLS_HANDSHAKE_TIMEOUT_MAX:
3122 		err = tls_opt_dtls_handshake_timeout_set(ctx, optval,
3123 							 optlen, true);
3124 		break;
3125 #endif /* CONFIG_NET_SOCKETS_ENABLE_DTLS */
3126 
3127 	case TLS_NATIVE:
3128 		/* Option handled at the socket dispatcher level. */
3129 		err = 0;
3130 		break;
3131 
3132 	default:
3133 		/* Unknown or read-only option. */
3134 		err = -ENOPROTOOPT;
3135 		break;
3136 	}
3137 
3138 out:
3139 	if (err < 0) {
3140 		errno = -err;
3141 		return -1;
3142 	}
3143 
3144 	return 0;
3145 }
3146 
tls_sock_read_vmeth(void * obj,void * buffer,size_t count)3147 static ssize_t tls_sock_read_vmeth(void *obj, void *buffer, size_t count)
3148 {
3149 	return ztls_recvfrom_ctx(obj, buffer, count, 0, NULL, 0);
3150 }
3151 
tls_sock_write_vmeth(void * obj,const void * buffer,size_t count)3152 static ssize_t tls_sock_write_vmeth(void *obj, const void *buffer,
3153 				    size_t count)
3154 {
3155 	return ztls_sendto_ctx(obj, buffer, count, 0, NULL, 0);
3156 }
3157 
tls_sock_ioctl_vmeth(void * obj,unsigned int request,va_list args)3158 static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
3159 {
3160 	struct tls_context *ctx = obj;
3161 
3162 	switch (request) {
3163 	/* fcntl() commands */
3164 	case F_GETFL:
3165 	case F_SETFL: {
3166 		const struct fd_op_vtable *vtable;
3167 		struct k_mutex *lock;
3168 		void *obj;
3169 		int ret;
3170 
3171 		obj = z_get_fd_obj_and_vtable(ctx->sock,
3172 				(const struct fd_op_vtable **)&vtable, &lock);
3173 		if (obj == NULL) {
3174 			errno = EBADF;
3175 			return -1;
3176 		}
3177 
3178 		(void)k_mutex_lock(lock, K_FOREVER);
3179 
3180 		/* Pass the call to the core socket implementation. */
3181 		ret = vtable->ioctl(obj, request, args);
3182 
3183 		k_mutex_unlock(lock);
3184 
3185 		return ret;
3186 	}
3187 
3188 	case ZFD_IOCTL_SET_LOCK: {
3189 		struct k_mutex *lock;
3190 
3191 		lock = va_arg(args, struct k_mutex *);
3192 
3193 		ctx_set_lock(obj, lock);
3194 
3195 		return 0;
3196 	}
3197 
3198 	case ZFD_IOCTL_POLL_PREPARE: {
3199 		struct zsock_pollfd *pfd;
3200 		struct k_poll_event **pev;
3201 		struct k_poll_event *pev_end;
3202 
3203 		pfd = va_arg(args, struct zsock_pollfd *);
3204 		pev = va_arg(args, struct k_poll_event **);
3205 		pev_end = va_arg(args, struct k_poll_event *);
3206 
3207 		return ztls_poll_prepare_ctx(obj, pfd, pev, pev_end);
3208 	}
3209 
3210 	case ZFD_IOCTL_POLL_UPDATE: {
3211 		struct zsock_pollfd *pfd;
3212 		struct k_poll_event **pev;
3213 
3214 		pfd = va_arg(args, struct zsock_pollfd *);
3215 		pev = va_arg(args, struct k_poll_event **);
3216 
3217 		return ztls_poll_update_ctx(obj, pfd, pev);
3218 	}
3219 
3220 	case ZFD_IOCTL_POLL_OFFLOAD: {
3221 		struct zsock_pollfd *fds;
3222 		int nfds;
3223 		int timeout;
3224 
3225 		fds = va_arg(args, struct zsock_pollfd *);
3226 		nfds = va_arg(args, int);
3227 		timeout = va_arg(args, int);
3228 
3229 		return ztls_poll_offload(fds, nfds, timeout);
3230 	}
3231 
3232 	default:
3233 		errno = EOPNOTSUPP;
3234 		return -1;
3235 	}
3236 }
3237 
tls_sock_shutdown_vmeth(void * obj,int how)3238 static int tls_sock_shutdown_vmeth(void *obj, int how)
3239 {
3240 	struct tls_context *ctx = obj;
3241 
3242 	return zsock_shutdown(ctx->sock, how);
3243 }
3244 
tls_sock_bind_vmeth(void * obj,const struct sockaddr * addr,socklen_t addrlen)3245 static int tls_sock_bind_vmeth(void *obj, const struct sockaddr *addr,
3246 			       socklen_t addrlen)
3247 {
3248 	struct tls_context *ctx = obj;
3249 
3250 	return zsock_bind(ctx->sock, addr, addrlen);
3251 }
3252 
tls_sock_connect_vmeth(void * obj,const struct sockaddr * addr,socklen_t addrlen)3253 static int tls_sock_connect_vmeth(void *obj, const struct sockaddr *addr,
3254 				  socklen_t addrlen)
3255 {
3256 	return ztls_connect_ctx(obj, addr, addrlen);
3257 }
3258 
tls_sock_listen_vmeth(void * obj,int backlog)3259 static int tls_sock_listen_vmeth(void *obj, int backlog)
3260 {
3261 	struct tls_context *ctx = obj;
3262 
3263 	ctx->is_listening = true;
3264 
3265 	return zsock_listen(ctx->sock, backlog);
3266 }
3267 
tls_sock_accept_vmeth(void * obj,struct sockaddr * addr,socklen_t * addrlen)3268 static int tls_sock_accept_vmeth(void *obj, struct sockaddr *addr,
3269 				 socklen_t *addrlen)
3270 {
3271 	return ztls_accept_ctx(obj, addr, addrlen);
3272 }
3273 
tls_sock_sendto_vmeth(void * obj,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)3274 static ssize_t tls_sock_sendto_vmeth(void *obj, const void *buf, size_t len,
3275 				     int flags,
3276 				     const struct sockaddr *dest_addr,
3277 				     socklen_t addrlen)
3278 {
3279 	return ztls_sendto_ctx(obj, buf, len, flags, dest_addr, addrlen);
3280 }
3281 
tls_sock_sendmsg_vmeth(void * obj,const struct msghdr * msg,int flags)3282 static ssize_t tls_sock_sendmsg_vmeth(void *obj, const struct msghdr *msg,
3283 				      int flags)
3284 {
3285 	return ztls_sendmsg_ctx(obj, msg, flags);
3286 }
3287 
tls_sock_recvfrom_vmeth(void * obj,void * buf,size_t max_len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)3288 static ssize_t tls_sock_recvfrom_vmeth(void *obj, void *buf, size_t max_len,
3289 				       int flags, struct sockaddr *src_addr,
3290 				       socklen_t *addrlen)
3291 {
3292 	return ztls_recvfrom_ctx(obj, buf, max_len, flags,
3293 				 src_addr, addrlen);
3294 }
3295 
tls_sock_getsockopt_vmeth(void * obj,int level,int optname,void * optval,socklen_t * optlen)3296 static int tls_sock_getsockopt_vmeth(void *obj, int level, int optname,
3297 				     void *optval, socklen_t *optlen)
3298 {
3299 	return ztls_getsockopt_ctx(obj, level, optname, optval, optlen);
3300 }
3301 
tls_sock_setsockopt_vmeth(void * obj,int level,int optname,const void * optval,socklen_t optlen)3302 static int tls_sock_setsockopt_vmeth(void *obj, int level, int optname,
3303 				     const void *optval, socklen_t optlen)
3304 {
3305 	return ztls_setsockopt_ctx(obj, level, optname, optval, optlen);
3306 }
3307 
tls_sock_close_vmeth(void * obj)3308 static int tls_sock_close_vmeth(void *obj)
3309 {
3310 	return ztls_close_ctx(obj);
3311 }
3312 
tls_sock_getpeername_vmeth(void * obj,struct sockaddr * addr,socklen_t * addrlen)3313 static int tls_sock_getpeername_vmeth(void *obj, struct sockaddr *addr,
3314 				      socklen_t *addrlen)
3315 {
3316 	struct tls_context *ctx = obj;
3317 
3318 	return zsock_getpeername(ctx->sock, addr, addrlen);
3319 }
3320 
tls_sock_getsockname_vmeth(void * obj,struct sockaddr * addr,socklen_t * addrlen)3321 static int tls_sock_getsockname_vmeth(void *obj, struct sockaddr *addr,
3322 				      socklen_t *addrlen)
3323 {
3324 	struct tls_context *ctx = obj;
3325 
3326 	return zsock_getsockname(ctx->sock, addr, addrlen);
3327 }
3328 
3329 static const struct socket_op_vtable tls_sock_fd_op_vtable = {
3330 	.fd_vtable = {
3331 		.read = tls_sock_read_vmeth,
3332 		.write = tls_sock_write_vmeth,
3333 		.close = tls_sock_close_vmeth,
3334 		.ioctl = tls_sock_ioctl_vmeth,
3335 	},
3336 	.shutdown = tls_sock_shutdown_vmeth,
3337 	.bind = tls_sock_bind_vmeth,
3338 	.connect = tls_sock_connect_vmeth,
3339 	.listen = tls_sock_listen_vmeth,
3340 	.accept = tls_sock_accept_vmeth,
3341 	.sendto = tls_sock_sendto_vmeth,
3342 	.sendmsg = tls_sock_sendmsg_vmeth,
3343 	.recvfrom = tls_sock_recvfrom_vmeth,
3344 	.getsockopt = tls_sock_getsockopt_vmeth,
3345 	.setsockopt = tls_sock_setsockopt_vmeth,
3346 	.getpeername = tls_sock_getpeername_vmeth,
3347 	.getsockname = tls_sock_getsockname_vmeth,
3348 };
3349 
/* Socket-family filter: accepts the (family, type, proto) triple exactly
 * when protocol_check() returns 0. `proto` is handed over by address
 * (presumably so it can be normalised); only this local copy is affected.
 */
static bool tls_is_supported(int family, int type, int proto)
{
	int rc = protocol_check(family, type, &proto);

	return rc == 0;
}
3358 
3359 /* Since both, TLS sockets and regular ones fall under the same address family,
3360  * it's required to process TLS first in order to capture socket calls which
3361  * create sockets for secure protocols. Every other call for AF_INET/AF_INET6
3362  * will be forwarded to regular socket implementation.
3363  */
3364 BUILD_ASSERT(CONFIG_NET_SOCKETS_TLS_PRIORITY < CONFIG_NET_SOCKETS_PRIORITY_DEFAULT,
3365 	     "CONFIG_NET_SOCKETS_TLS_PRIORITY have to be smaller than CONFIG_NET_SOCKETS_PRIORITY_DEFAULT");
3366 
3367 NET_SOCKET_REGISTER(tls, CONFIG_NET_SOCKETS_TLS_PRIORITY, AF_UNSPEC,
3368 		    tls_is_supported, ztls_socket);
3369