1 /*
2  * SPDX-FileCopyrightText: 2015-2022 Espressif Systems (Shanghai) CO LTD
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <string.h>
8 #include <stdlib.h>
9 
10 #include "esp_tls.h"
11 #include "esp_log.h"
12 
13 #include "esp_transport.h"
14 #include "esp_transport_ssl.h"
15 #include "esp_transport_utils.h"
16 #include "esp_transport_internal.h"
17 
/* Descriptor value stored in transport_esp_tls_t::sockfd when no socket is open. */
#define INVALID_SOCKET (-1)

/*
 * Declares a local `ssl` holding this transport's esp-tls context (see
 * ssl_get_context_data()) and returns from the enclosing function when no
 * context is available. The expansion is a declaration followed by an `if`,
 * so it cannot be wrapped in do { } while (0), and the bare `return;` means it
 * is only usable inside functions returning void (the public setters below).
 */
#define GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t)         \
    transport_esp_tls_t *ssl = ssl_get_context_data(t);  \
    if (!ssl) { return; }

/* Log tag shared by every function in this file. */
static const char *TAG = "TRANSPORT_BASE";
25 
/* Progress of the non-blocking connect state machine in esp_tls_connect_async(). */
typedef enum {
    TRANS_SSL_INIT = 0,      /* No async connect in progress; the next async call creates a fresh esp-tls object */
    TRANS_SSL_CONNECTING,    /* esp-tls object exists and esp_tls_conn_new_async() is being polled */
} transport_ssl_conn_state_t;

/**
 *  esp-tls backed connection state, used by both the SSL and the plain-TCP
 *  transports in this file. A transport either owns a private instance (t->data)
 *  or borrows the list-wide one from t->base (see ssl_get_context_data()).
 */
typedef struct transport_esp_tls {
    esp_tls_t                *tls;              /* esp-tls session; set by the SSL/async connects, NULL after base_close() */
    esp_tls_cfg_t            cfg;               /* config handed to esp-tls connects; filled by the esp_transport_ssl_set_* setters */
    bool                     ssl_initialized;   /* set when a connect created (or tried to create) `tls`; base_close() then destroys it instead of close(sockfd) */
    transport_ssl_conn_state_t conn_state;      /* async connect progress */
    int                      sockfd;            /* connected socket, or INVALID_SOCKET */
} transport_esp_tls_t;
41 
/**
 * Resolve the esp-tls context for transport @p t, binding one on first use.
 *
 * Lookup order:
 *  1. a context already bound to this transport (t->data);
 *  2. the context owned by the transport list (t->base->transport_esp_tls),
 *     which is then cached in t->data so later lookups take step 1;
 *  3. a newly allocated private context, also stored in t->data.
 *
 * @return the context, or NULL when @p t is NULL or allocation fails.
 */
static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t)
{
    if (t == NULL) {
        return NULL;
    }
    if (t->data != NULL) {
        return (transport_esp_tls_t *)t->data;
    }

    transport_esp_tls_t *ctx;
    if (t->base != NULL && t->base->transport_esp_tls != NULL) {
        ctx = t->base->transport_esp_tls;
    } else {
        // Neither a private nor a list-wide context exists yet: create a private one
        ctx = esp_transport_esp_tls_create();
        ESP_TRANSPORT_MEM_CHECK(TAG, ctx, return NULL)
    }
    t->data = ctx;
    return ctx;
}
60 
/**
 * Advance a non-blocking connect (TLS, or plain TCP when @p is_plain_tcp) by one step.
 *
 * On the first call (state TRANS_SSL_INIT) a fresh esp-tls object is created and the
 * context config is switched to non-blocking mode. Every call in state
 * TRANS_SSL_CONNECTING polls esp_tls_conn_new_async() once.
 *
 * @return the esp_tls_conn_new_async() result (per the esp-tls API: -1 failure,
 *         0 in progress, 1 established); -1 when no context or esp-tls object could
 *         be obtained; 0 if the state machine is in neither known state.
 */
static int esp_tls_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL) {
        // t is NULL or the context could not be allocated
        return -1;
    }
    if (ssl->conn_state == TRANS_SSL_INIT) {
        ssl->cfg.timeout_ms = timeout_ms;
        ssl->cfg.is_plain_tcp = is_plain_tcp;
        ssl->cfg.non_block = true;
        ssl->ssl_initialized = true;
        ssl->tls = esp_tls_init();
        if (!ssl->tls) {
            ESP_LOGE(TAG, "Failed to initialize new connection object");
            return -1;
        }
        ssl->conn_state = TRANS_SSL_CONNECTING;
        ssl->sockfd = INVALID_SOCKET;
    }
    if (ssl->conn_state == TRANS_SSL_CONNECTING) {
        int progress = esp_tls_conn_new_async(host, strlen(host), port, &ssl->cfg, ssl->tls);
        if (progress >= 0) {
            ssl->sockfd = ssl->tls->sockfd;
        } else {
            // Propagate esp-tls error details, as the blocking ssl_connect() does
            esp_transport_set_errors(t, ssl->tls->error_handle);
        }
        return progress;
    }
    return 0;
}
86 
/* Async connect hook of the SSL transport: one step of a TLS connect. */
static inline int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    const bool plain_tcp = false;
    return esp_tls_connect_async(t, host, port, timeout_ms, plain_tcp);
}
91 
/* Async connect hook of the TCP transport: one step of a plain-TCP connect via esp-tls. */
static inline int tcp_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    const bool plain_tcp = true;
    return esp_tls_connect_async(t, host, port, timeout_ms, plain_tcp);
}
96 
/**
 * Blocking TLS connect to @p host:@p port using the context's esp-tls config.
 *
 * On success the new socket is stored in the context. On failure the esp-tls
 * object is destroyed and the error details are copied to the transport.
 *
 * @return 0 on success, -1 on failure.
 */
static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL) {
        // t is NULL or the context could not be allocated
        return -1;
    }

    ssl->cfg.timeout_ms = timeout_ms;
    ssl->cfg.is_plain_tcp = false;
    // cfg may be the list-wide context, or left over from esp_tls_connect_async(),
    // which sets non_block = true; a blocking connect must not inherit that.
    ssl->cfg.non_block = false;

    ssl->ssl_initialized = true;
    ssl->tls = esp_tls_init();
    if (ssl->tls == NULL) {
        ESP_LOGE(TAG, "Failed to initialize new connection object");
        capture_tcp_transport_error(t, ERR_TCP_TRANSPORT_NO_MEM);
        return -1;
    }
    if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) {
        ESP_LOGE(TAG, "Failed to open a new connection");
        esp_transport_set_errors(t, ssl->tls->error_handle);
        esp_tls_conn_destroy(ssl->tls);
        ssl->tls = NULL;
        ssl->sockfd = INVALID_SOCKET;
        return -1;
    }
    ssl->sockfd = ssl->tls->sockfd;
    return 0;
}
122 
/**
 * Blocking plain-TCP connect to @p host:@p port through esp_tls_plain_tcp_connect().
 *
 * The resulting socket is written directly into the context. On failure the esp_err_t
 * is recorded in the transport's error handle, when the transport has one.
 *
 * NOTE(review): esp_transport_get_error_handle() yields NULL for a transport that is
 * not attached to a list, and esp-tls presumably rejects a NULL handle, so such a
 * transport may be unable to connect at all — confirm against the esp-tls version.
 *
 * @return 0 on success, -1 on failure.
 */
static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL) {
        // t is NULL or the context could not be allocated
        return -1;
    }
    esp_tls_last_error_t *err_handle = esp_transport_get_error_handle(t);

    ssl->cfg.timeout_ms = timeout_ms;
    // A stale non_block = true (set by an async connect on this or the shared list
    // context) would make the plain TCP connect return before the connection is
    // established; this path is blocking.
    ssl->cfg.non_block = false;
    esp_err_t err = esp_tls_plain_tcp_connect(host, strlen(host), port, &ssl->cfg, err_handle, &ssl->sockfd);
    if (err != ESP_OK) {
        ESP_LOGE(TAG, "Failed to open a new connection: %d", err);
        if (err_handle != NULL) {
            err_handle->last_error = err;
        }
        ssl->sockfd = INVALID_SOCKET;
        return -1;
    }
    return 0;
}
138 
/**
 * Wait up to @p timeout_ms for data to read.
 *
 * Data already buffered inside the esp-tls session counts as readable without
 * touching the socket.
 *
 * @return the number of buffered esp-tls bytes when > 0. Otherwise select()'s result
 *         (> 0 readable, 0 timeout). Returns -1 when there is no context or open
 *         socket, or when the socket reported an error condition (its SO_ERROR is
 *         recorded via esp_transport_capture_errno()).
 */
static int base_poll_read(esp_transport_handle_t t, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL || ssl->sockfd < 0) {
        // Nothing to wait on; FD_SET() with a negative descriptor is undefined
        return -1;
    }

    if (ssl->tls) {
        int remain = esp_tls_get_bytes_avail(ssl->tls);
        if (remain > 0) {
            ESP_LOGD(TAG, "remain data in cache, need to read again");
            return remain;
        }
    }

    struct timeval timeout;
    fd_set readset;
    fd_set errset;
    FD_ZERO(&readset);
    FD_ZERO(&errset);
    FD_SET(ssl->sockfd, &readset);
    FD_SET(ssl->sockfd, &errset);

    int ret = select(ssl->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
    if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) {
        int sock_errno = 0;
        socklen_t optlen = sizeof(sock_errno);
        getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
        esp_transport_capture_errno(t, sock_errno);
        ESP_LOGE(TAG, "poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->sockfd);
        ret = -1;
    }
    return ret;
}
167 
/**
 * Wait up to @p timeout_ms for the socket to become writable.
 *
 * @return select()'s result (> 0 writable, 0 timeout), or -1 when there is no
 *         context or open socket, or when the socket reported an error condition
 *         (its SO_ERROR is recorded via esp_transport_capture_errno()).
 */
static int base_poll_write(esp_transport_handle_t t, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL || ssl->sockfd < 0) {
        // Nothing to wait on; FD_SET() with a negative descriptor is undefined
        return -1;
    }

    struct timeval timeout;
    fd_set writeset;
    fd_set errset;
    FD_ZERO(&writeset);
    FD_ZERO(&errset);
    FD_SET(ssl->sockfd, &writeset);
    FD_SET(ssl->sockfd, &errset);

    int ret = select(ssl->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
    if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) {
        int sock_errno = 0;
        socklen_t optlen = sizeof(sock_errno);
        getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
        esp_transport_capture_errno(t, sock_errno);
        ESP_LOGE(TAG, "poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->sockfd);
        ret = -1;
    }
    return ret;
}
190 
/**
 * Write @p len bytes through the TLS session once the socket is writable.
 *
 * @return the esp_tls_conn_write() result (negative on error, in which case the
 *         esp-tls error details are copied to the transport), or the non-positive
 *         poll result on timeout/poll error.
 */
static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);

    const int poll_ret = esp_transport_poll_write(t, timeout_ms);
    if (poll_ret <= 0) {
        ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ctx->sockfd, timeout_ms);
        return poll_ret;
    }

    const int written = esp_tls_conn_write(ctx->tls, (const unsigned char *)buffer, len);
    if (written < 0) {
        ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
        esp_transport_set_errors(t, ctx->tls->error_handle);
    }
    return written;
}
207 
/**
 * Send up to @p len bytes over the plain TCP socket once it is writable.
 *
 * @return the send() result (negative on error, with errno captured on the
 *         transport), the non-positive poll result on timeout/poll error, or -1
 *         when no context is available.
 */
static int tcp_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL) {
        // t is NULL or the context could not be allocated
        return -1;
    }

    int poll = esp_transport_poll_write(t, timeout_ms);
    if (poll <= 0) {
        ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->sockfd, timeout_ms);
        return poll;
    }
    int ret = send(ssl->sockfd, buffer, len, 0);
    if (ret < 0) {
        // Snapshot errno before logging: the log call may overwrite it
        const int send_errno = errno;
        ESP_LOGE(TAG, "tcp_write error, errno=%s", strerror(send_errno));
        esp_transport_capture_errno(t, send_errno);
    }
    return ret;
}
224 
/**
 * Read up to @p len bytes from the TLS session once data is available.
 *
 * @return bytes read (> 0); the non-positive poll result on timeout/poll error;
 *         a negative esp_tls_conn_read() result on error (details copied to the
 *         transport); or -1 when the session reads 0 bytes after polling readable,
 *         which is recorded as ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN.
 */
static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);

    const int poll_ret = esp_transport_poll_read(t, timeout_ms);
    if (poll_ret <= 0) {
        return poll_ret;
    }

    const int nread = esp_tls_conn_read(ctx->tls, (unsigned char *)buffer, len);
    if (nread == 0) {
        // Poll reported readable (poll_ret > 0 here) yet nothing came back:
        // the peer closed the connection cleanly
        capture_tcp_transport_error(t, ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN);
        return -1;
    }
    if (nread < 0) {
        ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
        esp_transport_set_errors(t, ctx->tls->error_handle);
    }
    return nread;
}
247 
/**
 * Receive up to @p len bytes from the plain TCP socket once data is available.
 *
 * @return bytes received (> 0); the non-positive poll result on timeout/poll error;
 *         a negative recv() result on error (errno captured on the transport); or -1
 *         when recv() returns 0 after a readable poll, which is recorded as
 *         ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN. Also -1 when no context is available.
 */
static int tcp_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl == NULL) {
        // t is NULL or the context could not be allocated
        return -1;
    }

    int poll = esp_transport_poll_read(t, timeout_ms);
    if (poll <= 0) {
        return poll;
    }
    int ret = recv(ssl->sockfd, buffer, len, 0);
    if (ret < 0) {
        // Snapshot errno before logging: the log call may overwrite it
        const int recv_errno = errno;
        ESP_LOGE(TAG, "tcp_read error, errno=%s", strerror(recv_errno));
        esp_transport_capture_errno(t, recv_errno);
    }
    if (ret == 0) {
        // poll > 0 is guaranteed here: socket was readable yet read 0 bytes,
        // so the connection has been closed cleanly
        capture_tcp_transport_error(t, ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN);
        ret = -1;
    }
    return ret;
}
270 
/**
 * Close hook shared by the SSL and TCP transports.
 *
 * If a connect created an esp-tls object (ssl_initialized), that object is destroyed
 * and the async state machine is reset. No separate close() is issued in that case,
 * presumably because esp-tls owns the socket. Otherwise a bare socket, if any, is
 * closed directly.
 *
 * @return the esp_tls_conn_destroy()/close() result, or -1 when there was nothing to close.
 */
static int base_close(esp_transport_handle_t t)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return -1;
    }

    int rc = -1;
    if (ctx->ssl_initialized) {
        rc = esp_tls_conn_destroy(ctx->tls);
        ctx->tls = NULL;
        ctx->conn_state = TRANS_SSL_INIT;
        ctx->ssl_initialized = false;
        ctx->sockfd = INVALID_SOCKET;
    } else if (ctx->sockfd >= 0) {
        rc = close(ctx->sockfd);
        ctx->sockfd = INVALID_SOCKET;
    }
    return rc;
}
287 
/**
 * Destroy hook: close the connection and free this transport's private context.
 *
 * A context borrowed from the transport list (t->base->transport_esp_tls) is only
 * detached, never freed here. Note that if no context exists yet,
 * ssl_get_context_data() allocates one, which is then released again.
 *
 * @return always 0.
 */
static int base_destroy(esp_transport_handle_t t)
{
    transport_esp_tls_t *ssl = ssl_get_context_data(t);
    if (ssl) {
        esp_transport_close(t);
        if (t->base && t->base->transport_esp_tls &&
            t->data == t->base->transport_esp_tls) {
            // if internal ssl the same as the foundation transport,
            // just zero out, it will be freed on list destroy
            t->data = NULL;
        }
        esp_transport_esp_tls_destroy(t->data); // okay to pass NULL
        // Do not leave a dangling pointer to the freed context on the handle
        t->data = NULL;
    }
    return 0;
}
303 
/* Set `use_global_ca_store` in the esp-tls config used by subsequent connects. */
void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.use_global_ca_store = true;
}
309 
#ifdef CONFIG_ESP_TLS_PSK_VERIFICATION
/* Store the PSK hint/key pointer in the esp-tls config (pointer only, not copied). */
void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.psk_hint_key = psk_hint_key;
    }
}
#endif
317 
/*
 * Set the PEM CA certificate buffer (pointer only, not copied). The stored size is
 * len + 1, i.e. `len` presumably excludes a trailing NUL that `data` must carry.
 */
void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.cacert_pem_buf = (void *)data;
    ctx->cfg.cacert_pem_bytes = len + 1;
}
324 
/* Set the DER CA certificate buffer and its exact length (pointer only, not copied). */
void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.cacert_buf = (void *)data;
        ctx->cfg.cacert_bytes = len;
    }
}
331 
/*
 * Set the PEM client certificate buffer (pointer only, not copied). The stored size
 * is len + 1, i.e. `len` presumably excludes a trailing NUL that `data` must carry.
 */
void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.clientcert_pem_buf = (void *)data;
    ctx->cfg.clientcert_pem_bytes = len + 1;
}
338 
/* Set the DER client certificate buffer and its exact length (pointer only, not copied). */
void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.clientcert_buf = (void *)data;
        ctx->cfg.clientcert_bytes = len;
    }
}
345 
/*
 * Set the PEM client private key buffer (pointer only, not copied). The stored size
 * is len + 1, i.e. `len` presumably excludes a trailing NUL that `data` must carry.
 */
void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.clientkey_pem_buf = (void *)data;
    ctx->cfg.clientkey_pem_bytes = len + 1;
}
352 
/*
 * Set the client key password and its length. Only the pointer is stored, so the
 * buffer presumably has to stay valid until the connection is made.
 */
void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.clientkey_password = (void *)password;
        ctx->cfg.clientkey_password_len = password_len;
    }
}
359 
/* Set the DER client private key buffer and its exact length (pointer only, not copied). */
void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.clientkey_buf = (void *)data;
    ctx->cfg.clientkey_bytes = len;
}
366 
#if defined(CONFIG_MBEDTLS_SSL_ALPN) || defined(CONFIG_WOLFSSL_HAVE_ALPN)
/* Store the ALPN protocol list pointer in the esp-tls config (the array is not copied). */
void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.alpn_protos = alpn_protos;
    }
}
#endif
374 
/* Set `skip_common_name` in the esp-tls config used by subsequent connects. */
void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.skip_common_name = true;
}
380 
#ifdef CONFIG_ESP_TLS_USE_SECURE_ELEMENT
/* Set `use_secure_element` in the esp-tls config used by subsequent connects. */
void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.use_secure_element = true;
    }
}
#endif
388 
#ifdef CONFIG_MBEDTLS_CERTIFICATE_BUNDLE
/* Register the certificate-bundle attach callback in the esp-tls config. */
void esp_transport_ssl_crt_bundle_attach(esp_transport_handle_t t, esp_err_t ((*crt_bundle_attach)(void *conf)))
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.crt_bundle_attach = crt_bundle_attach;
}
#endif
396 
/* _get_socket hook: the context's socket descriptor, or INVALID_SOCKET without a context. */
static int base_get_socket(esp_transport_handle_t t)
{
    const transport_esp_tls_t *ctx = ssl_get_context_data(t);
    return (ctx != NULL) ? ctx->sockfd : INVALID_SOCKET;
}
405 
#ifdef CONFIG_ESP_TLS_USE_DS_PERIPHERAL
/* Store the DS-peripheral context pointer in the esp-tls config (pointer only). */
void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.ds_data = ds_data;
    }
}
#endif
413 
/*
 * Store the TCP keep-alive settings pointer in the esp-tls config (not copied).
 * NOTE(review): the transport struct is reinterpreted as tls_keep_alive_cfg_t, which
 * relies on both types having the same layout — confirm when either changes.
 */
void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx == NULL) {
        return;
    }
    ctx->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *)keep_alive_cfg;
}
419 
/*
 * Store the network-interface request pointer in the esp-tls config. Only the pointer
 * is kept, so the struct presumably must outlive subsequent connects.
 */
void esp_transport_ssl_set_interface_name(esp_transport_handle_t t, struct ifreq *if_name)
{
    transport_esp_tls_t *ctx = ssl_get_context_data(t);
    if (ctx != NULL) {
        ctx->cfg.if_name = if_name;
    }
}
425 
esp_transport_ssl_init(void)426 esp_transport_handle_t esp_transport_ssl_init(void)
427 {
428     esp_transport_handle_t t = esp_transport_init();
429     if (t == NULL) {
430         return NULL;
431     }
432     esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, base_close, base_poll_read, base_poll_write, base_destroy);
433     esp_transport_set_async_connect_func(t, ssl_connect_async);
434     t->_get_socket = base_get_socket;
435     return t;
436 }
437 
esp_transport_esp_tls_create(void)438 struct transport_esp_tls* esp_transport_esp_tls_create(void)
439 {
440     transport_esp_tls_t *transport_esp_tls = calloc(1, sizeof(transport_esp_tls_t));
441     if (transport_esp_tls == NULL) {
442         return NULL;
443     }
444     transport_esp_tls->sockfd = INVALID_SOCKET;
445     return transport_esp_tls;
446 }
447 
esp_transport_esp_tls_destroy(struct transport_esp_tls * transport_esp_tls)448 void esp_transport_esp_tls_destroy(struct transport_esp_tls* transport_esp_tls)
449 {
450     free(transport_esp_tls);
451 }
452 
esp_transport_tcp_init(void)453 esp_transport_handle_t esp_transport_tcp_init(void)
454 {
455     esp_transport_handle_t t = esp_transport_init();
456     if (t == NULL) {
457         return NULL;
458     }
459     esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, base_close, base_poll_read, base_poll_write, base_destroy);
460     esp_transport_set_async_connect_func(t, tcp_connect_async);
461     t->_get_socket = base_get_socket;
462     return t;
463 }
464 
/*
 * TCP keep-alive setter. The TCP transport shares the esp-tls context and config,
 * so this forwards to the SSL setter.
 */
void esp_transport_tcp_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg)
{
    // A void function must not `return` an expression (C11 6.8.6.4p1), even a void one
    esp_transport_ssl_set_keep_alive(t, keep_alive_cfg);
}
469 
/*
 * TCP interface-binding setter. The TCP transport shares the esp-tls context and
 * config, so this forwards to the SSL setter.
 */
void esp_transport_tcp_set_interface_name(esp_transport_handle_t t, struct ifreq *if_name)
{
    // A void function must not `return` an expression (C11 6.8.6.4p1), even a void one
    esp_transport_ssl_set_interface_name(t, if_name);
}
474