1 /*
2  * SPDX-FileCopyrightText: 2020-2022 Espressif Systems (Shanghai) CO LTD
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 #include <sys/param.h>
7 #include <stdbool.h>
8 #include "esp_mbedtls_dynamic_impl.h"
9 
/*
 * Linker-wrap entry points. The __real_ / __wrap_ prefixes follow the GNU ld
 * `--wrap=<symbol>` convention: references to <symbol> are routed to
 * __wrap_<symbol> (defined at the end of this file), and __real_<symbol>
 * names the original mbedtls implementation that the wrappers delegate to.
 * NOTE(review): the matching `-Wl,--wrap=...` flags are not set in this file;
 * confirm they are present in the component's build configuration.
 */
int __real_mbedtls_ssl_handshake_client_step(mbedtls_ssl_context *ssl);
int __real_mbedtls_ssl_write_client_hello(mbedtls_ssl_context *ssl);

int __wrap_mbedtls_ssl_handshake_client_step(mbedtls_ssl_context *ssl);
int __wrap_mbedtls_ssl_write_client_hello(mbedtls_ssl_context *ssl);

/* Log tag for this file. It is not referenced directly in the code here;
 * presumably it is consumed by the CHECK_OK macro from
 * esp_mbedtls_dynamic_impl.h — confirm before removing it. */
static const char *TAG = "SSL client";
17 
manage_resource(mbedtls_ssl_context * ssl,bool add)18 static int manage_resource(mbedtls_ssl_context *ssl, bool add)
19 {
20     int state = add ? ssl->MBEDTLS_PRIVATE(state) : ssl->MBEDTLS_PRIVATE(state) - 1;
21 
22     if (mbedtls_ssl_is_handshake_over(ssl) || ssl->MBEDTLS_PRIVATE(handshake) == NULL) {
23         return 0;
24     }
25 
26     if (!add) {
27         if (!ssl->MBEDTLS_PRIVATE(out_left)) {
28             CHECK_OK(esp_mbedtls_free_tx_buffer(ssl));
29         }
30     }
31 
32     /* Change state now, so that it is right in mbedtls_ssl_read_record(), used
33      * by DTLS for dropping out-of-sequence ChangeCipherSpec records */
34 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
35     if( ssl->state == MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC &&
36         ssl->handshake->new_session_ticket != 0 )
37     {
38         ssl->state = MBEDTLS_SSL_NEW_SESSION_TICKET;
39     }
40 #endif
41 
42     switch (state) {
43         case MBEDTLS_SSL_HELLO_REQUEST:
44             break;
45         case MBEDTLS_SSL_CLIENT_HELLO:
46             if (add) {
47                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
48 
49                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
50             }
51             break;
52 
53 
54         case MBEDTLS_SSL_SERVER_HELLO:
55             if (add) {
56                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
57             } else {
58                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
59             }
60             break;
61         case MBEDTLS_SSL_SERVER_CERTIFICATE:
62             if (add) {
63                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
64             } else {
65                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
66 
67 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_CA_CERT
68                 esp_mbedtls_free_cacert(ssl);
69 #endif
70             }
71             break;
72         case MBEDTLS_SSL_SERVER_KEY_EXCHANGE:
73             if (add) {
74                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
75             } else {
76                 if (!ssl->MBEDTLS_PRIVATE(keep_current_message)) {
77                     CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
78                 }
79             }
80             break;
81         case MBEDTLS_SSL_CERTIFICATE_REQUEST:
82             if (add) {
83                 if (!ssl->MBEDTLS_PRIVATE(keep_current_message)) {
84                     CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
85                 }
86             } else {
87                 if (!ssl->MBEDTLS_PRIVATE(keep_current_message)) {
88                     CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
89                 }
90             }
91             break;
92         case MBEDTLS_SSL_SERVER_HELLO_DONE:
93             if (add) {
94                 if (!ssl->MBEDTLS_PRIVATE(keep_current_message)) {
95                     CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
96                 }
97             } else {
98                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
99             }
100             break;
101 
102 
103         case MBEDTLS_SSL_CLIENT_CERTIFICATE:
104             if (add) {
105                 size_t buffer_len = 3;
106 
107                 const mbedtls_ssl_config *conf = mbedtls_ssl_context_get_config(ssl);
108                 mbedtls_ssl_key_cert *key_cert = conf->MBEDTLS_PRIVATE(key_cert);
109 
110                 while (key_cert && key_cert->cert) {
111                     size_t num;
112 
113                     buffer_len += esp_mbedtls_get_crt_size(key_cert->cert, &num);
114                     buffer_len += num * 3;
115 
116                     key_cert = key_cert->next;
117                 }
118 
119                 buffer_len = MAX(buffer_len, MBEDTLS_SSL_OUT_BUFFER_LEN);
120 
121                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
122             }
123             break;
124         case MBEDTLS_SSL_CLIENT_KEY_EXCHANGE:
125             if (add) {
126                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
127 
128                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
129             }
130             break;
131         case MBEDTLS_SSL_CERTIFICATE_VERIFY:
132             if (add) {
133                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
134 
135                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
136             } else {
137 #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_CONFIG_DATA
138                 esp_mbedtls_free_dhm(ssl);
139                 esp_mbedtls_free_keycert_key(ssl);
140                 esp_mbedtls_free_keycert(ssl);
141 #endif
142             }
143             break;
144         case MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC:
145             if (add) {
146                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
147 
148                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
149             }
150             break;
151         case MBEDTLS_SSL_CLIENT_FINISHED:
152             if (add) {
153                 size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN;
154 
155                 CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len));
156             }
157             break;
158 
159 
160 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
161         case MBEDTLS_SSL_NEW_SESSION_TICKET:
162             if (add) {
163                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
164             } else {
165                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
166             }
167             break;
168 #endif
169 
170 
171         case MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC:
172             if (add) {
173                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
174             } else {
175                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
176             }
177             break;
178         case MBEDTLS_SSL_SERVER_FINISHED:
179             if (add) {
180                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
181             } else {
182                 CHECK_OK(esp_mbedtls_free_rx_buffer(ssl));
183             }
184             break;
185         case MBEDTLS_SSL_FLUSH_BUFFERS:
186             break;
187         case MBEDTLS_SSL_HANDSHAKE_WRAPUP:
188 #if defined(MBEDTLS_SSL_RENEGOTIATION)
189             if (add && ssl->MBEDTLS_PRIVATE(renego_status)) {
190                 CHECK_OK(esp_mbedtls_add_rx_buffer(ssl));
191             }
192 #endif
193             break;
194         default:
195             break;
196     }
197 
198     return 0;
199 }
200 
__wrap_mbedtls_ssl_handshake_client_step(mbedtls_ssl_context * ssl)201 int __wrap_mbedtls_ssl_handshake_client_step(mbedtls_ssl_context *ssl)
202 {
203     CHECK_OK(manage_resource(ssl, true));
204 
205     CHECK_OK(__real_mbedtls_ssl_handshake_client_step(ssl));
206 
207     CHECK_OK(manage_resource(ssl, false));
208 
209     return 0;
210 }
211 
__wrap_mbedtls_ssl_write_client_hello(mbedtls_ssl_context * ssl)212 int __wrap_mbedtls_ssl_write_client_hello(mbedtls_ssl_context *ssl)
213 {
214     CHECK_OK(manage_resource(ssl, true));
215 
216     CHECK_OK(__real_mbedtls_ssl_write_client_hello(ssl));
217 
218     CHECK_OK(manage_resource(ssl, false));
219 
220     return 0;
221 }
222