1 // Copyright 2015-2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <string.h>
16 #include <stdlib.h>
17 
18 #include "freertos/FreeRTOS.h"
19 #include "freertos/task.h"
20 #include "esp_tls.h"
21 #include "esp_log.h"
22 #include "esp_system.h"
23 
24 #include "esp_transport.h"
25 #include "esp_transport_ssl.h"
26 #include "esp_transport_utils.h"
27 #include "esp_transport_ssl_internal.h"
28 #include "esp_transport_internal.h"
29 
/* Log tag passed to the ESP_LOGx macros used throughout this file. */
static const char *TAG = "TRANS_SSL";
31 
/**
 *  Progress of an asynchronous connection attempt (see ssl_connect_async)
 */
typedef enum {
    TRANS_SSL_INIT       = 0,   /*!< No attempt in progress; the next async connect allocates a TLS handle */
    TRANS_SSL_CONNECTING = 1,   /*!< TLS handle allocated; async connect calls keep driving the connection */
} transport_ssl_conn_state_t;
36 
37 /**
38  *  mbedtls specific transport data
39  */
40 typedef struct {
41     esp_tls_t                *tls;
42     esp_tls_cfg_t            cfg;
43     bool                     ssl_initialized;
44     transport_ssl_conn_state_t conn_state;
45 } transport_ssl_t;
46 
/* Forward declaration of the close callback; the definition appears further down in this file. */
static int ssl_close(esp_transport_handle_t t);
48 
/**
 * @brief Start or continue a non-blocking TLS connection.
 *
 * On the first call (state TRANS_SSL_INIT) the timeout is stored, the
 * configuration is switched to non-blocking mode and a new esp-tls handle is
 * allocated. Every call made while in state TRANS_SSL_CONNECTING forwards to
 * esp_tls_conn_new_async().
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param host        NUL-terminated host name (its length is taken with strlen)
 * @param port        Remote port
 * @param timeout_ms  Stored into cfg.timeout_ms on the first call
 *
 * @return -1 if the esp-tls handle cannot be allocated, otherwise the value
 *         returned by esp_tls_conn_new_async(); 0 if the state is neither
 *         INIT nor CONNECTING
 */
static int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);

    if (ctx->conn_state == TRANS_SSL_INIT) {
        /* First call of a new attempt: prepare the configuration and the handle. */
        ctx->cfg.timeout_ms = timeout_ms;
        ctx->cfg.non_block = true;
        ctx->ssl_initialized = true;
        ctx->tls = esp_tls_init();
        if (ctx->tls == NULL) {
            return -1;
        }
        ctx->conn_state = TRANS_SSL_CONNECTING;
    }

    if (ctx->conn_state != TRANS_SSL_CONNECTING) {
        return 0;
    }

    return esp_tls_conn_new_async(host, strlen(host), port, &ctx->cfg, ctx->tls);
}
67 
/**
 * @brief Open a TLS connection in blocking mode.
 *
 * Allocates a new esp-tls handle and calls esp_tls_conn_new_sync() with the
 * configuration stored in the transport context. If the connection attempt
 * fails, the esp-tls error handle is copied to the transport with
 * esp_transport_set_errors(), the handle is destroyed and the stored pointer
 * is reset to NULL.
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param host        NUL-terminated host name (its length is taken with strlen)
 * @param port        Remote port
 * @param timeout_ms  Stored into cfg.timeout_ms before connecting
 *
 * @return 0 on success, -1 if the esp-tls handle could not be allocated or
 *         the connection could not be established
 */
static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    transport_ssl_t *ssl = esp_transport_get_context_data(t);

    ssl->cfg.timeout_ms = timeout_ms;
    ssl->ssl_initialized = true;
    ssl->tls = esp_tls_init();
    if (ssl->tls == NULL) {
        /* Bail out here: the failure path below reads ssl->tls->error_handle
         * and would dereference the NULL handle. */
        ESP_LOGE(TAG, "Failed to initialize new connection object");
        return -1;
    }
    if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) <= 0) {
        ESP_LOGE(TAG, "Failed to open a new connection");
        esp_transport_set_errors(t, ssl->tls->error_handle);
        esp_tls_conn_destroy(ssl->tls);
        ssl->tls = NULL;
        return -1;
    }

    return 0;
}
85 
/**
 * @brief Wait until the connection has data to read.
 *
 * If esp_tls_get_bytes_avail() reports a positive byte count, that count is
 * returned immediately without calling select(). Otherwise select() is run
 * on the connection socket for readability and error conditions.
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param timeout_ms  Timeout, converted with esp_transport_utils_ms_to_timeval()
 *
 * @return the byte count reported by esp_tls_get_bytes_avail() when positive,
 *         otherwise the select() result; -1 when select() flags the socket in
 *         the error set (the socket error is captured on the transport)
 */
static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    const int fd = ctx->tls->sockfd;
    struct timeval tv;
    fd_set rd_fds;
    fd_set err_fds;

    FD_ZERO(&rd_fds);
    FD_ZERO(&err_fds);
    FD_SET(fd, &rd_fds);
    FD_SET(fd, &err_fds);

    int pending = esp_tls_get_bytes_avail(ctx->tls);
    if (pending > 0) {
        ESP_LOGD(TAG, "remain data in cache, need to read again");
        return pending;
    }

    int ready = select(fd + 1, &rd_fds, NULL, &err_fds, esp_transport_utils_ms_to_timeval(timeout_ms, &tv));
    if (ready <= 0 || !FD_ISSET(fd, &err_fds)) {
        /* Timeout, select() failure, or a ready socket without an error condition. */
        return ready;
    }

    /* The socket is in the error set: fetch the pending error and record it. */
    int sock_errno = 0;
    uint32_t optlen = sizeof(sock_errno);
    getsockopt(fd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
    esp_transport_capture_errno(t, sock_errno);
    ESP_LOGE(TAG, "ssl_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), fd);
    return -1;
}
114 
/**
 * @brief Wait until the connection socket is writable.
 *
 * Runs select() on the connection socket for writability and error
 * conditions.
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param timeout_ms  Timeout, converted with esp_transport_utils_ms_to_timeval()
 *
 * @return the select() result; -1 when select() flags the socket in the
 *         error set (the socket error is captured on the transport)
 */
static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    const int fd = ctx->tls->sockfd;
    struct timeval tv;
    fd_set wr_fds;
    fd_set err_fds;

    FD_ZERO(&wr_fds);
    FD_ZERO(&err_fds);
    FD_SET(fd, &wr_fds);
    FD_SET(fd, &err_fds);

    int ready = select(fd + 1, NULL, &wr_fds, &err_fds, esp_transport_utils_ms_to_timeval(timeout_ms, &tv));
    if (ready <= 0 || !FD_ISSET(fd, &err_fds)) {
        /* Timeout, select() failure, or a writable socket without an error condition. */
        return ready;
    }

    /* The socket is in the error set: fetch the pending error and record it. */
    int sock_errno = 0;
    uint32_t optlen = sizeof(sock_errno);
    getsockopt(fd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
    esp_transport_capture_errno(t, sock_errno);
    ESP_LOGE(TAG, "ssl_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), fd);
    return -1;
}
137 
/**
 * @brief Write a buffer to the TLS connection.
 *
 * Polls the transport for writability first; the data is handed to
 * esp_tls_conn_write() only when the poll result is positive.
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param buffer      Data to send
 * @param len         Number of bytes in buffer
 * @param timeout_ms  Timeout passed to esp_transport_poll_write()
 *
 * @return the poll result when it is <= 0, otherwise the value returned by
 *         esp_tls_conn_write(); a negative write result is also recorded on
 *         the transport with esp_transport_set_errors()
 */
static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);

    int ready = esp_transport_poll_write(t, timeout_ms);
    if (ready <= 0) {
        ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ctx->tls->sockfd, timeout_ms);
        return ready;
    }

    int written = esp_tls_conn_write(ctx->tls, (const unsigned char *) buffer, len);
    if (written < 0) {
        ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
        esp_transport_set_errors(t, ctx->tls->error_handle);
    }
    return written;
}
154 
/**
 * @brief Read from the TLS connection into a buffer.
 *
 * Polls the transport for readable data first; esp_tls_conn_read() is called
 * only when the poll result is positive.
 *
 * @param t           Transport handle whose context data is a transport_ssl_t
 * @param buffer      Destination buffer
 * @param len         Capacity of buffer in bytes
 * @param timeout_ms  Timeout passed to esp_transport_poll_read()
 *
 * @return the poll result when it is <= 0, otherwise the value returned by
 *         esp_tls_conn_read(), except that a read result of 0 is reported
 *         as -1; a negative read result is also recorded on the transport
 *         with esp_transport_set_errors()
 */
static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);

    int ready = esp_transport_poll_read(t, timeout_ms);
    if (ready <= 0) {
        return ready;
    }

    int received = esp_tls_conn_read(ctx->tls, (unsigned char *)buffer, len);
    if (received < 0) {
        ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
        esp_transport_set_errors(t, ctx->tls->error_handle);
    } else if (received == 0) {
        /* A zero-length read after a positive poll is reported to the caller as a failure. */
        received = -1;
    }
    return received;
}
173 
/**
 * @brief Tear down the TLS connection owned by this transport.
 *
 * Does nothing when the context is missing or no connection was initialized.
 * Otherwise the esp-tls handle is destroyed, the stored pointer is cleared
 * and the context returns to its initial state so that a new connect can
 * follow.
 *
 * @param t  Transport handle whose context data is a transport_ssl_t
 *
 * @return the value returned by esp_tls_conn_destroy(), or -1 when there was
 *         nothing to close
 */
static int ssl_close(esp_transport_handle_t t)
{
    int ret = -1;
    transport_ssl_t *ssl = esp_transport_get_context_data(t);
    if (ssl && ssl->ssl_initialized) {
        ret = esp_tls_conn_destroy(ssl->tls);
        /* Drop the stale handle: ssl_get_socket() only checks ssl->tls for
         * NULL and would otherwise read through the destroyed pointer. */
        ssl->tls = NULL;
        ssl->conn_state = TRANS_SSL_INIT;
        ssl->ssl_initialized = false;
    }
    return ret;
}
185 
/**
 * @brief Destroy callback: close the transport and release its SSL context.
 *
 * @param t  Transport handle whose context data is a transport_ssl_t
 *
 * @return always 0
 */
static int ssl_destroy(esp_transport_handle_t t)
{
    /* Fetch the context first, then close, then release the context memory. */
    transport_ssl_t *ctx = esp_transport_get_context_data(t);

    esp_transport_close(t);
    free(ctx);

    return 0;
}
193 
/**
 * @brief Set the use_global_ca_store flag in the transport's esp-tls configuration.
 *
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t  SSL transport handle
 */
void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.use_global_ca_store = true;
}
201 
/**
 * @brief Store the PSK key/hint descriptor in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the pointed-to structure is not copied here.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t             SSL transport handle
 * @param psk_hint_key  Pointer assigned to cfg.psk_hint_key
 */
void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.psk_hint_key = psk_hint_key;
}
209 
/**
 * @brief Set the CA certificate (PEM variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. The configured
 * byte count is len + 1.
 * NOTE(review): presumably the extra byte accounts for the terminating NUL
 * of the PEM string - confirm against the esp-tls requirements.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Certificate buffer assigned to cfg.cacert_pem_buf
 * @param len   Length supplied by the caller; stored as len + 1
 */
void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.cacert_pem_buf = (void *)data;
    ctx->cfg.cacert_pem_bytes = len + 1;
}
218 
/**
 * @brief Set the CA certificate (DER variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. Unlike the PEM
 * setter, the length is stored unchanged.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Certificate buffer assigned to cfg.cacert_buf
 * @param len   Byte count assigned to cfg.cacert_bytes
 */
void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.cacert_buf = (void *)data;
    ctx->cfg.cacert_bytes = len;
}
227 
/**
 * @brief Set the client certificate (PEM variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. The configured
 * byte count is len + 1.
 * NOTE(review): presumably the extra byte accounts for the terminating NUL
 * of the PEM string - confirm against the esp-tls requirements.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Certificate buffer assigned to cfg.clientcert_pem_buf
 * @param len   Length supplied by the caller; stored as len + 1
 */
void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.clientcert_pem_buf = (void *)data;
    ctx->cfg.clientcert_pem_bytes = len + 1;
}
236 
/**
 * @brief Set the client certificate (DER variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. The length is
 * stored unchanged.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Certificate buffer assigned to cfg.clientcert_buf
 * @param len   Byte count assigned to cfg.clientcert_bytes
 */
void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.clientcert_buf = (void *)data;
    ctx->cfg.clientcert_bytes = len;
}
245 
/**
 * @brief Set the client private key (PEM variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. The configured
 * byte count is len + 1.
 * NOTE(review): presumably the extra byte accounts for the terminating NUL
 * of the PEM string - confirm against the esp-tls requirements.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Key buffer assigned to cfg.clientkey_pem_buf
 * @param len   Length supplied by the caller; stored as len + 1
 */
void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.clientkey_pem_buf = (void *)data;
    ctx->cfg.clientkey_pem_bytes = len + 1;
}
254 
/**
 * @brief Set the client key password in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the password is not copied here. The length is
 * stored unchanged.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t             SSL transport handle
 * @param password      Buffer assigned to cfg.clientkey_password
 * @param password_len  Value assigned to cfg.clientkey_password_len
 */
void esp_transport_ssl_set_client_key_password(esp_transport_handle_t t, const char *password, int password_len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.clientkey_password = (void *)password;
    ctx->cfg.clientkey_password_len = password_len;
}
263 
/**
 * @brief Set the client private key (DER variant) in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the data is not copied here. The length is
 * stored unchanged.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t     SSL transport handle
 * @param data  Key buffer assigned to cfg.clientkey_buf
 * @param len   Byte count assigned to cfg.clientkey_bytes
 */
void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.clientkey_buf = (void *)data;
    ctx->cfg.clientkey_bytes = len;
}
272 
/**
 * @brief Store the ALPN protocol list in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; the list is not copied here.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t            SSL transport handle
 * @param alpn_protos  Pointer assigned to cfg.alpn_protos
 */
void esp_transport_ssl_set_alpn_protocol(esp_transport_handle_t t, const char **alpn_protos)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.alpn_protos = alpn_protos;
}
280 
/**
 * @brief Set the skip_common_name flag in the transport's esp-tls configuration.
 *
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t  SSL transport handle
 */
void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.skip_common_name = true;
}
288 
/**
 * @brief Set the use_secure_element flag in the transport's esp-tls configuration.
 *
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t  SSL transport handle
 */
void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.use_secure_element = true;
}
296 
/**
 * @brief Return the socket descriptor of the current TLS connection.
 *
 * @param t  SSL transport handle
 *
 * @return the sockfd stored in the esp-tls handle, or -1 when the transport,
 *         its context data or the esp-tls handle is NULL
 */
static int ssl_get_socket(esp_transport_handle_t t)
{
    if (t == NULL) {
        return -1;
    }

    transport_ssl_t *ctx = t->data;
    if (ctx == NULL || ctx->tls == NULL) {
        return -1;
    }

    return ctx->tls->sockfd;
}
307 
/**
 * @brief Store the caller-supplied ds_data pointer in the transport's esp-tls configuration.
 *
 * Only the pointer is stored; nothing is copied here.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t        SSL transport handle
 * @param ds_data  Pointer assigned to cfg.ds_data
 */
void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.ds_data = ds_data;
}
315 
/**
 * @brief Store the keep-alive settings pointer in the transport's esp-tls configuration.
 *
 * The pointer is cast to tls_keep_alive_cfg_t * and stored; nothing is
 * copied here.
 * NOTE(review): the cast assumes esp_transport_keep_alive_t and
 * tls_keep_alive_cfg_t share the same layout - confirm against both headers.
 * Silently does nothing when the handle or its context data is NULL.
 *
 * @param t               SSL transport handle
 * @param keep_alive_cfg  Pointer assigned (after the cast) to cfg.keep_alive_cfg
 */
void esp_transport_ssl_set_keep_alive(esp_transport_handle_t t, esp_transport_keep_alive_t *keep_alive_cfg)
{
    transport_ssl_t *ctx = esp_transport_get_context_data(t);
    if (!t || !ctx) {
        return;
    }
    ctx->cfg.keep_alive_cfg = (tls_keep_alive_cfg_t *)keep_alive_cfg;
}
323 
esp_transport_ssl_init(void)324 esp_transport_handle_t esp_transport_ssl_init(void)
325 {
326     esp_transport_handle_t t = esp_transport_init();
327     transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
328     ESP_TRANSPORT_MEM_CHECK(TAG, ssl, {
329         esp_transport_destroy(t);
330         return NULL;
331     });
332     esp_transport_set_context_data(t, ssl);
333     esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
334     esp_transport_set_async_connect_func(t, ssl_connect_async);
335     t->_get_socket = ssl_get_socket;
336     return t;
337 }
338