1 // Copyright 2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <inttypes.h>
#include <stdlib.h>
#include <string.h>

#include <esp_err.h>
#include <esp_log.h>
#include <sys/queue.h>

#include <protocomm.h>
#include <protocomm_security.h>

#include "protocomm_priv.h"
26
27 static const char *TAG = "protocomm";
28
protocomm_new(void)29 protocomm_t *protocomm_new(void)
30 {
31 protocomm_t *pc;
32
33 pc = (protocomm_t *) calloc(1, sizeof(protocomm_t));
34 if (!pc) {
35 ESP_LOGE(TAG, "Error allocating protocomm");
36 return NULL;
37 }
38 SLIST_INIT(&pc->endpoints);
39
40 return pc;
41 }
42
/*
 * Destroy a protocomm instance. Frees, in order: every endpoint record, the
 * version string, the security instance (through the scheme's cleanup hook,
 * when one is provided), the Proof-of-Possession copy, and finally the
 * instance itself. Passing NULL is a no-op.
 *
 * Note: endpoint records are only freed here; the transport's
 * remove_endpoint hook is not invoked.
 */
void protocomm_delete(protocomm_t *pc)
{
    if (!pc) {
        return;
    }

    /* The _SAFE iterator caches the successor, so freeing while walking is fine */
    protocomm_ep_t *cur, *next_ep;
    SLIST_FOREACH_SAFE(cur, &pc->endpoints, next, next_ep) {
        free(cur);
    }

    /* free(NULL) is a no-op, so the version string needs no guard */
    free((void *)pc->ver);

    if (pc->sec != NULL && pc->sec->cleanup != NULL) {
        pc->sec->cleanup(pc->sec_inst);
    }

    free(pc->pop);
    free(pc);
}
70
search_endpoint(protocomm_t * pc,const char * ep_name)71 static protocomm_ep_t *search_endpoint(protocomm_t *pc, const char *ep_name)
72 {
73 protocomm_ep_t *it;
74 SLIST_FOREACH(it, &pc->endpoints, next) {
75 if (strcmp(it->ep_name, ep_name) == 0) {
76 return it;
77 }
78 }
79 return NULL;
80 }
81
protocomm_add_endpoint_internal(protocomm_t * pc,const char * ep_name,protocomm_req_handler_t h,void * priv_data,uint32_t flag)82 static esp_err_t protocomm_add_endpoint_internal(protocomm_t *pc, const char *ep_name,
83 protocomm_req_handler_t h, void *priv_data,
84 uint32_t flag)
85 {
86 if ((pc == NULL) || (ep_name == NULL) || (h == NULL)) {
87 return ESP_ERR_INVALID_ARG;
88 }
89
90 protocomm_ep_t *ep;
91 esp_err_t ret;
92
93 ep = search_endpoint(pc, ep_name);
94 if (ep) {
95 ESP_LOGE(TAG, "Endpoint with this name already exists");
96 return ESP_FAIL;
97 }
98
99 if (pc->add_endpoint) {
100 ret = pc->add_endpoint(ep_name, h, priv_data);
101 if (ret != ESP_OK) {
102 ESP_LOGE(TAG, "Error adding endpoint");
103 return ret;
104 }
105 }
106
107 ep = (protocomm_ep_t *) calloc(1, sizeof(protocomm_ep_t));
108 if (!ep) {
109 ESP_LOGE(TAG, "Error allocating endpoint resource");
110 return ESP_ERR_NO_MEM;
111 }
112
113 /* Initialize ep handler */
114 ep->ep_name = ep_name;
115 ep->req_handler = h;
116 ep->priv_data = priv_data;
117 ep->flag = flag;
118
119 /* Add endpoint to the head of the singly linked list */
120 SLIST_INSERT_HEAD(&pc->endpoints, ep, next);
121
122 return ESP_OK;
123 }
124
protocomm_add_endpoint(protocomm_t * pc,const char * ep_name,protocomm_req_handler_t h,void * priv_data)125 esp_err_t protocomm_add_endpoint(protocomm_t *pc, const char *ep_name,
126 protocomm_req_handler_t h, void *priv_data)
127 {
128 return protocomm_add_endpoint_internal(pc, ep_name, h, priv_data, REQ_EP);
129 }
130
protocomm_remove_endpoint(protocomm_t * pc,const char * ep_name)131 esp_err_t protocomm_remove_endpoint(protocomm_t *pc, const char *ep_name)
132 {
133 if ((pc == NULL) || (ep_name == NULL)) {
134 return ESP_ERR_INVALID_ARG;
135 }
136
137 if (pc->remove_endpoint) {
138 pc->remove_endpoint(ep_name);
139 }
140
141 protocomm_ep_t *it, *tmp;
142 SLIST_FOREACH_SAFE(it, &pc->endpoints, next, tmp) {
143 if (strcmp(ep_name, it->ep_name) == 0) {
144 SLIST_REMOVE(&pc->endpoints, it, protocomm_ep, next);
145 free(it);
146 return ESP_OK;
147 }
148 }
149 return ESP_ERR_NOT_FOUND;
150 }
151
protocomm_open_session(protocomm_t * pc,uint32_t session_id)152 esp_err_t protocomm_open_session(protocomm_t *pc, uint32_t session_id)
153 {
154 if (!pc) {
155 return ESP_ERR_INVALID_ARG;
156 }
157
158 if (pc->sec && pc->sec->new_transport_session) {
159 esp_err_t ret = pc->sec->new_transport_session(pc->sec_inst, session_id);
160 if (ret != ESP_OK) {
161 ESP_LOGE(TAG, "Failed to launch new session with ID: %d", session_id);
162 return ret;
163 }
164 }
165 return ESP_OK;
166 }
167
protocomm_close_session(protocomm_t * pc,uint32_t session_id)168 esp_err_t protocomm_close_session(protocomm_t *pc, uint32_t session_id)
169 {
170 if (!pc) {
171 return ESP_ERR_INVALID_ARG;
172 }
173
174 if (pc->sec && pc->sec->close_transport_session) {
175 esp_err_t ret = pc->sec->close_transport_session(pc->sec_inst, session_id);
176 if (ret != ESP_OK) {
177 ESP_LOGE(TAG, "Error while closing session with ID: %d", session_id);
178 return ret;
179 }
180 }
181 return ESP_OK;
182 }
183
protocomm_req_handle(protocomm_t * pc,const char * ep_name,uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen)184 esp_err_t protocomm_req_handle(protocomm_t *pc, const char *ep_name, uint32_t session_id,
185 const uint8_t *inbuf, ssize_t inlen,
186 uint8_t **outbuf, ssize_t *outlen)
187 {
188 if (!pc || !ep_name || !outbuf || !outlen) {
189 ESP_LOGE(TAG, "Invalid params %p %p", pc, ep_name);
190 return ESP_ERR_INVALID_ARG;
191 }
192
193 *outbuf = NULL;
194 *outlen = 0;
195
196 protocomm_ep_t *ep = search_endpoint(pc, ep_name);
197 if (!ep) {
198 ESP_LOGE(TAG, "No registered endpoint for %s", ep_name);
199 return ESP_ERR_NOT_FOUND;
200 }
201
202 esp_err_t ret = ESP_FAIL;
203 if (ep->flag & SEC_EP) {
204 /* Call the registered endpoint handler for establishing secure session */
205 ret = ep->req_handler(session_id, inbuf, inlen, outbuf, outlen, ep->priv_data);
206 ESP_LOGD(TAG, "SEC_EP Req handler returned %d", ret);
207 } else if (ep->flag & REQ_EP) {
208 if (pc->sec && pc->sec->decrypt) {
209 /* Decrypt the data first */
210 uint8_t *dec_inbuf = (uint8_t *) malloc(inlen);
211 if (!dec_inbuf) {
212 ESP_LOGE(TAG, "Failed to allocate decrypt buf len %d", inlen);
213 return ESP_ERR_NO_MEM;
214 }
215
216 ssize_t dec_inbuf_len = inlen;
217 ret = pc->sec->decrypt(pc->sec_inst, session_id, inbuf, inlen, dec_inbuf, &dec_inbuf_len);
218 if (ret != ESP_OK) {
219 ESP_LOGE(TAG, "Decryption of response failed for endpoint %s", ep_name);
220 free(dec_inbuf);
221 return ret;
222 }
223
224 /* Invoke the request handler */
225 uint8_t *plaintext_resp = NULL;
226 ssize_t plaintext_resp_len = 0;
227 ret = ep->req_handler(session_id,
228 dec_inbuf, dec_inbuf_len,
229 &plaintext_resp, &plaintext_resp_len,
230 ep->priv_data);
231 if (ret != ESP_OK) {
232 ESP_LOGE(TAG, "Request handler for %s failed", ep_name);
233 free(plaintext_resp);
234 free(dec_inbuf);
235 return ret;
236 }
237 /* We don't need decrypted data anymore */
238 free(dec_inbuf);
239
240 /* Encrypt response to be sent back */
241 uint8_t *enc_resp = (uint8_t *) malloc(plaintext_resp_len);
242 if (!enc_resp) {
243 ESP_LOGE(TAG, "Failed to allocate decrypt buf len %d", inlen);
244 free(plaintext_resp);
245 return ESP_ERR_NO_MEM;
246 }
247
248 ssize_t enc_resp_len = plaintext_resp_len;
249 ret = pc->sec->encrypt(pc->sec_inst, session_id, plaintext_resp, plaintext_resp_len,
250 enc_resp, &enc_resp_len);
251
252 if (ret != ESP_OK) {
253 ESP_LOGE(TAG, "Encryption of response failed for endpoint %s", ep_name);
254 free(enc_resp);
255 free(plaintext_resp);
256 return ret;
257 }
258
259 /* We no more need plaintext response */
260 free(plaintext_resp);
261
262 /* Set outbuf and outlen appropriately */
263 *outbuf = enc_resp;
264 *outlen = enc_resp_len;
265 } else {
266 /* No encryption */
267 ret = ep->req_handler(session_id,
268 inbuf, inlen,
269 outbuf, outlen,
270 ep->priv_data);
271 ESP_LOGD(TAG, "No encrypt ret %d", ret);
272 }
273 } else if (ep->flag & VER_EP) {
274 ret = ep->req_handler(session_id, inbuf, inlen, outbuf, outlen, ep->priv_data);
275 ESP_LOGD(TAG, "VER_EP Req handler returned %d", ret);
276 }
277 return ret;
278 }
279
protocomm_common_security_handler(uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen,void * priv_data)280 static int protocomm_common_security_handler(uint32_t session_id,
281 const uint8_t *inbuf, ssize_t inlen,
282 uint8_t **outbuf, ssize_t *outlen,
283 void *priv_data)
284 {
285 protocomm_t *pc = (protocomm_t *) priv_data;
286
287 if (pc->sec && pc->sec->security_req_handler) {
288 return pc->sec->security_req_handler(pc->sec_inst,
289 pc->pop, session_id,
290 inbuf, inlen,
291 outbuf, outlen,
292 priv_data);
293 }
294
295 return ESP_OK;
296 }
297
protocomm_set_security(protocomm_t * pc,const char * ep_name,const protocomm_security_t * sec,const protocomm_security_pop_t * pop)298 esp_err_t protocomm_set_security(protocomm_t *pc, const char *ep_name,
299 const protocomm_security_t *sec,
300 const protocomm_security_pop_t *pop)
301 {
302 if ((pc == NULL) || (ep_name == NULL) || (sec == NULL)) {
303 return ESP_ERR_INVALID_ARG;
304 }
305
306 if (pc->sec) {
307 return ESP_ERR_INVALID_STATE;
308 }
309
310 esp_err_t ret = protocomm_add_endpoint_internal(pc, ep_name,
311 protocomm_common_security_handler,
312 (void *) pc, SEC_EP);
313 if (ret != ESP_OK) {
314 ESP_LOGE(TAG, "Error adding security endpoint");
315 return ret;
316 }
317
318 if (sec->init) {
319 ret = sec->init(&pc->sec_inst);
320 if (ret != ESP_OK) {
321 ESP_LOGE(TAG, "Error initializing security");
322 protocomm_remove_endpoint(pc, ep_name);
323 return ret;
324 }
325 }
326 pc->sec = sec;
327
328 if (pop) {
329 pc->pop = malloc(sizeof(protocomm_security_pop_t));
330 if (pc->pop == NULL) {
331 ESP_LOGE(TAG, "Error allocating Proof of Possession");
332 if (pc->sec && pc->sec->cleanup) {
333 pc->sec->cleanup(pc->sec_inst);
334 pc->sec_inst = NULL;
335 pc->sec = NULL;
336 }
337
338 protocomm_remove_endpoint(pc, ep_name);
339 return ESP_ERR_NO_MEM;
340 }
341 memcpy((void *)pc->pop, pop, sizeof(protocomm_security_pop_t));
342 }
343 return ESP_OK;
344 }
345
protocomm_unset_security(protocomm_t * pc,const char * ep_name)346 esp_err_t protocomm_unset_security(protocomm_t *pc, const char *ep_name)
347 {
348 if ((pc == NULL) || (ep_name == NULL)) {
349 return ESP_FAIL;
350 }
351
352 if (pc->sec && pc->sec->cleanup) {
353 pc->sec->cleanup(pc->sec_inst);
354 pc->sec_inst = NULL;
355 pc->sec = NULL;
356 }
357
358 if (pc->pop) {
359 free(pc->pop);
360 pc->pop = NULL;
361 }
362
363 return protocomm_remove_endpoint(pc, ep_name);
364 }
365
protocomm_version_handler(uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen,void * priv_data)366 static int protocomm_version_handler(uint32_t session_id,
367 const uint8_t *inbuf, ssize_t inlen,
368 uint8_t **outbuf, ssize_t *outlen,
369 void *priv_data)
370 {
371 protocomm_t *pc = (protocomm_t *) priv_data;
372 if (!pc->ver) {
373 *outlen = 0;
374 *outbuf = NULL;
375 return ESP_OK;
376 }
377
378 /* Output is a non null terminated string with length specified */
379 *outlen = strlen(pc->ver);
380 *outbuf = malloc(*outlen);
381 if (*outbuf == NULL) {
382 ESP_LOGE(TAG, "Failed to allocate memory for version response");
383 return ESP_ERR_NO_MEM;
384 }
385
386 memcpy(*outbuf, pc->ver, *outlen);
387 return ESP_OK;
388 }
389
protocomm_set_version(protocomm_t * pc,const char * ep_name,const char * version)390 esp_err_t protocomm_set_version(protocomm_t *pc, const char *ep_name, const char *version)
391 {
392 if ((pc == NULL) || (ep_name == NULL) || (version == NULL)) {
393 return ESP_ERR_INVALID_ARG;
394 }
395
396 if (pc->ver) {
397 return ESP_ERR_INVALID_STATE;
398 }
399
400 pc->ver = strdup(version);
401 if (pc->ver == NULL) {
402 ESP_LOGE(TAG, "Error allocating version string");
403 return ESP_ERR_NO_MEM;
404 }
405
406 esp_err_t ret = protocomm_add_endpoint_internal(pc, ep_name,
407 protocomm_version_handler,
408 (void *) pc, VER_EP);
409 if (ret != ESP_OK) {
410 ESP_LOGE(TAG, "Error adding version endpoint");
411 return ret;
412 }
413 return ESP_OK;
414 }
415
protocomm_unset_version(protocomm_t * pc,const char * ep_name)416 esp_err_t protocomm_unset_version(protocomm_t *pc, const char *ep_name)
417 {
418 if ((pc == NULL) || (ep_name == NULL)) {
419 return ESP_ERR_INVALID_ARG;
420 }
421
422 if (pc->ver) {
423 free((char *)pc->ver);
424 pc->ver = NULL;
425 }
426
427 return protocomm_remove_endpoint(pc, ep_name);
428 }
429