1 // Copyright 2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <inttypes.h>
#include <stdlib.h>
#include <string.h>

#include <esp_err.h>
#include <esp_log.h>
#include <sys/queue.h>

#include <protocomm.h>
#include <protocomm_security.h>

#include "protocomm_priv.h"
26 
27 static const char *TAG = "protocomm";
28 
protocomm_new(void)29 protocomm_t *protocomm_new(void)
30 {
31     protocomm_t *pc;
32 
33     pc = (protocomm_t *) calloc(1, sizeof(protocomm_t));
34     if (!pc) {
35        ESP_LOGE(TAG, "Error allocating protocomm");
36        return NULL;
37     }
38     SLIST_INIT(&pc->endpoints);
39 
40     return pc;
41 }
42 
/*
 * Release a protocomm instance and everything it owns.
 *
 * Frees every endpoint record, the version string, the PoP copy and finally
 * the instance itself. The security instance is handed to the security
 * scheme's cleanup hook when one is provided. Passing NULL is a no-op.
 * The transport's remove_endpoint hook is not invoked here.
 */
void protocomm_delete(protocomm_t *pc)
{
    if (!pc) {
        return;
    }

    /* Drop the endpoint records; the _SAFE variant allows freeing the
     * current node while walking the list */
    protocomm_ep_t *ep, *next_ep;
    SLIST_FOREACH_SAFE(ep, &pc->endpoints, next, next_ep) {
        free(ep);
    }

    /* free(NULL) is a no-op, so the optional members need no guard */
    free((void *) pc->ver);

    if (pc->sec != NULL && pc->sec->cleanup != NULL) {
        pc->sec->cleanup(pc->sec_inst);
    }
    free(pc->pop);

    free(pc);
}
70 
search_endpoint(protocomm_t * pc,const char * ep_name)71 static protocomm_ep_t *search_endpoint(protocomm_t *pc, const char *ep_name)
72 {
73     protocomm_ep_t *it;
74     SLIST_FOREACH(it, &pc->endpoints, next) {
75         if (strcmp(it->ep_name, ep_name) == 0) {
76             return it;
77         }
78     }
79     return NULL;
80 }
81 
protocomm_add_endpoint_internal(protocomm_t * pc,const char * ep_name,protocomm_req_handler_t h,void * priv_data,uint32_t flag)82 static esp_err_t protocomm_add_endpoint_internal(protocomm_t *pc, const char *ep_name,
83                                                  protocomm_req_handler_t h, void *priv_data,
84                                                  uint32_t flag)
85 {
86     if ((pc == NULL) || (ep_name == NULL) || (h == NULL)) {
87         return ESP_ERR_INVALID_ARG;
88     }
89 
90     protocomm_ep_t *ep;
91     esp_err_t ret;
92 
93     ep = search_endpoint(pc, ep_name);
94     if (ep) {
95         ESP_LOGE(TAG, "Endpoint with this name already exists");
96         return ESP_FAIL;
97     }
98 
99     if (pc->add_endpoint) {
100         ret = pc->add_endpoint(ep_name, h, priv_data);
101         if (ret != ESP_OK) {
102             ESP_LOGE(TAG, "Error adding endpoint");
103             return ret;
104         }
105     }
106 
107     ep = (protocomm_ep_t *) calloc(1, sizeof(protocomm_ep_t));
108     if (!ep) {
109         ESP_LOGE(TAG, "Error allocating endpoint resource");
110         return ESP_ERR_NO_MEM;
111     }
112 
113     /* Initialize ep handler */
114     ep->ep_name = ep_name;
115     ep->req_handler = h;
116     ep->priv_data = priv_data;
117     ep->flag = flag;
118 
119     /* Add endpoint to the head of the singly linked list */
120     SLIST_INSERT_HEAD(&pc->endpoints, ep, next);
121 
122     return ESP_OK;
123 }
124 
protocomm_add_endpoint(protocomm_t * pc,const char * ep_name,protocomm_req_handler_t h,void * priv_data)125 esp_err_t protocomm_add_endpoint(protocomm_t *pc, const char *ep_name,
126                                  protocomm_req_handler_t h, void *priv_data)
127 {
128     return protocomm_add_endpoint_internal(pc, ep_name, h, priv_data, REQ_EP);
129 }
130 
protocomm_remove_endpoint(protocomm_t * pc,const char * ep_name)131 esp_err_t protocomm_remove_endpoint(protocomm_t *pc, const char *ep_name)
132 {
133     if ((pc == NULL) || (ep_name == NULL)) {
134         return ESP_ERR_INVALID_ARG;
135     }
136 
137     if (pc->remove_endpoint) {
138         pc->remove_endpoint(ep_name);
139     }
140 
141     protocomm_ep_t *it, *tmp;
142     SLIST_FOREACH_SAFE(it, &pc->endpoints, next, tmp) {
143         if (strcmp(ep_name, it->ep_name) == 0) {
144             SLIST_REMOVE(&pc->endpoints, it, protocomm_ep, next);
145             free(it);
146             return ESP_OK;
147         }
148     }
149     return ESP_ERR_NOT_FOUND;
150 }
151 
protocomm_open_session(protocomm_t * pc,uint32_t session_id)152 esp_err_t protocomm_open_session(protocomm_t *pc, uint32_t session_id)
153 {
154     if (!pc) {
155         return ESP_ERR_INVALID_ARG;
156     }
157 
158     if (pc->sec && pc->sec->new_transport_session) {
159         esp_err_t ret = pc->sec->new_transport_session(pc->sec_inst, session_id);
160         if (ret != ESP_OK) {
161             ESP_LOGE(TAG, "Failed to launch new session with ID: %d", session_id);
162             return ret;
163         }
164     }
165     return ESP_OK;
166 }
167 
protocomm_close_session(protocomm_t * pc,uint32_t session_id)168 esp_err_t protocomm_close_session(protocomm_t *pc, uint32_t session_id)
169 {
170     if (!pc) {
171         return ESP_ERR_INVALID_ARG;
172     }
173 
174     if (pc->sec && pc->sec->close_transport_session) {
175         esp_err_t ret = pc->sec->close_transport_session(pc->sec_inst, session_id);
176         if (ret != ESP_OK) {
177             ESP_LOGE(TAG, "Error while closing session with ID: %d", session_id);
178             return ret;
179         }
180     }
181     return ESP_OK;
182 }
183 
protocomm_req_handle(protocomm_t * pc,const char * ep_name,uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen)184 esp_err_t protocomm_req_handle(protocomm_t *pc, const char *ep_name, uint32_t session_id,
185                                const uint8_t *inbuf, ssize_t inlen,
186                                uint8_t **outbuf, ssize_t *outlen)
187 {
188     if (!pc || !ep_name || !outbuf || !outlen) {
189         ESP_LOGE(TAG, "Invalid params %p %p", pc, ep_name);
190         return ESP_ERR_INVALID_ARG;
191     }
192 
193     *outbuf = NULL;
194     *outlen = 0;
195 
196     protocomm_ep_t *ep = search_endpoint(pc, ep_name);
197     if (!ep) {
198         ESP_LOGE(TAG, "No registered endpoint for %s", ep_name);
199         return ESP_ERR_NOT_FOUND;
200     }
201 
202     esp_err_t ret = ESP_FAIL;
203     if (ep->flag & SEC_EP) {
204         /* Call the registered endpoint handler for establishing secure session */
205         ret = ep->req_handler(session_id, inbuf, inlen, outbuf, outlen, ep->priv_data);
206         ESP_LOGD(TAG, "SEC_EP Req handler returned %d", ret);
207     } else if (ep->flag & REQ_EP) {
208         if (pc->sec && pc->sec->decrypt) {
209             /* Decrypt the data first */
210             uint8_t *dec_inbuf = (uint8_t *) malloc(inlen);
211             if (!dec_inbuf) {
212                 ESP_LOGE(TAG, "Failed to allocate decrypt buf len %d", inlen);
213                 return ESP_ERR_NO_MEM;
214             }
215 
216             ssize_t dec_inbuf_len = inlen;
217             ret = pc->sec->decrypt(pc->sec_inst, session_id, inbuf, inlen, dec_inbuf, &dec_inbuf_len);
218             if (ret != ESP_OK) {
219                 ESP_LOGE(TAG, "Decryption of response failed for endpoint %s", ep_name);
220                 free(dec_inbuf);
221                 return ret;
222             }
223 
224             /* Invoke the request handler */
225             uint8_t *plaintext_resp = NULL;
226             ssize_t plaintext_resp_len = 0;
227             ret = ep->req_handler(session_id,
228                                   dec_inbuf, dec_inbuf_len,
229                                   &plaintext_resp, &plaintext_resp_len,
230                                   ep->priv_data);
231             if (ret != ESP_OK) {
232                 ESP_LOGE(TAG, "Request handler for %s failed", ep_name);
233                 free(plaintext_resp);
234                 free(dec_inbuf);
235                 return ret;
236             }
237             /* We don't need decrypted data anymore */
238             free(dec_inbuf);
239 
240             /* Encrypt response to be sent back */
241             uint8_t *enc_resp = (uint8_t *) malloc(plaintext_resp_len);
242             if (!enc_resp) {
243                 ESP_LOGE(TAG, "Failed to allocate decrypt buf len %d", inlen);
244                 free(plaintext_resp);
245                 return ESP_ERR_NO_MEM;
246             }
247 
248             ssize_t enc_resp_len = plaintext_resp_len;
249             ret = pc->sec->encrypt(pc->sec_inst, session_id, plaintext_resp, plaintext_resp_len,
250                                    enc_resp, &enc_resp_len);
251 
252             if (ret != ESP_OK) {
253                 ESP_LOGE(TAG, "Encryption of response failed for endpoint %s", ep_name);
254                 free(enc_resp);
255                 free(plaintext_resp);
256                 return ret;
257             }
258 
259             /* We no more need plaintext response */
260             free(plaintext_resp);
261 
262             /* Set outbuf and outlen appropriately */
263             *outbuf = enc_resp;
264             *outlen = enc_resp_len;
265         } else {
266             /* No encryption */
267             ret = ep->req_handler(session_id,
268                                   inbuf, inlen,
269                                   outbuf, outlen,
270                                   ep->priv_data);
271             ESP_LOGD(TAG, "No encrypt ret %d", ret);
272         }
273     } else if (ep->flag & VER_EP) {
274         ret = ep->req_handler(session_id, inbuf, inlen, outbuf, outlen, ep->priv_data);
275         ESP_LOGD(TAG, "VER_EP Req handler returned %d", ret);
276     }
277     return ret;
278 }
279 
protocomm_common_security_handler(uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen,void * priv_data)280 static int protocomm_common_security_handler(uint32_t session_id,
281                                              const uint8_t *inbuf, ssize_t inlen,
282                                              uint8_t **outbuf, ssize_t *outlen,
283                                              void *priv_data)
284 {
285     protocomm_t *pc = (protocomm_t *) priv_data;
286 
287     if (pc->sec && pc->sec->security_req_handler) {
288         return pc->sec->security_req_handler(pc->sec_inst,
289                                              pc->pop, session_id,
290                                              inbuf, inlen,
291                                              outbuf, outlen,
292                                              priv_data);
293     }
294 
295     return ESP_OK;
296 }
297 
protocomm_set_security(protocomm_t * pc,const char * ep_name,const protocomm_security_t * sec,const protocomm_security_pop_t * pop)298 esp_err_t protocomm_set_security(protocomm_t *pc, const char *ep_name,
299                                  const protocomm_security_t *sec,
300                                  const protocomm_security_pop_t *pop)
301 {
302     if ((pc == NULL) || (ep_name == NULL) || (sec == NULL)) {
303         return ESP_ERR_INVALID_ARG;
304     }
305 
306     if (pc->sec) {
307         return ESP_ERR_INVALID_STATE;
308     }
309 
310     esp_err_t ret = protocomm_add_endpoint_internal(pc, ep_name,
311                                                     protocomm_common_security_handler,
312                                                     (void *) pc, SEC_EP);
313     if (ret != ESP_OK) {
314         ESP_LOGE(TAG, "Error adding security endpoint");
315         return ret;
316     }
317 
318     if (sec->init) {
319         ret = sec->init(&pc->sec_inst);
320         if (ret != ESP_OK) {
321             ESP_LOGE(TAG, "Error initializing security");
322             protocomm_remove_endpoint(pc, ep_name);
323             return ret;
324         }
325     }
326     pc->sec = sec;
327 
328     if (pop) {
329         pc->pop = malloc(sizeof(protocomm_security_pop_t));
330         if (pc->pop == NULL) {
331             ESP_LOGE(TAG, "Error allocating Proof of Possession");
332             if (pc->sec && pc->sec->cleanup) {
333                 pc->sec->cleanup(pc->sec_inst);
334                 pc->sec_inst = NULL;
335                 pc->sec = NULL;
336             }
337 
338             protocomm_remove_endpoint(pc, ep_name);
339             return ESP_ERR_NO_MEM;
340         }
341         memcpy((void *)pc->pop, pop, sizeof(protocomm_security_pop_t));
342     }
343     return ESP_OK;
344 }
345 
protocomm_unset_security(protocomm_t * pc,const char * ep_name)346 esp_err_t protocomm_unset_security(protocomm_t *pc, const char *ep_name)
347 {
348     if ((pc == NULL) || (ep_name == NULL)) {
349         return ESP_FAIL;
350     }
351 
352     if (pc->sec && pc->sec->cleanup) {
353         pc->sec->cleanup(pc->sec_inst);
354         pc->sec_inst = NULL;
355         pc->sec = NULL;
356     }
357 
358     if (pc->pop) {
359         free(pc->pop);
360         pc->pop = NULL;
361     }
362 
363     return protocomm_remove_endpoint(pc, ep_name);
364 }
365 
protocomm_version_handler(uint32_t session_id,const uint8_t * inbuf,ssize_t inlen,uint8_t ** outbuf,ssize_t * outlen,void * priv_data)366 static int protocomm_version_handler(uint32_t session_id,
367                                      const uint8_t *inbuf, ssize_t inlen,
368                                      uint8_t **outbuf, ssize_t *outlen,
369                                      void *priv_data)
370 {
371     protocomm_t *pc = (protocomm_t *) priv_data;
372     if (!pc->ver) {
373         *outlen = 0;
374         *outbuf = NULL;
375         return ESP_OK;
376     }
377 
378     /* Output is a non null terminated string with length specified */
379     *outlen = strlen(pc->ver);
380     *outbuf = malloc(*outlen);
381     if (*outbuf == NULL) {
382         ESP_LOGE(TAG, "Failed to allocate memory for version response");
383         return ESP_ERR_NO_MEM;
384     }
385 
386     memcpy(*outbuf, pc->ver, *outlen);
387     return ESP_OK;
388 }
389 
protocomm_set_version(protocomm_t * pc,const char * ep_name,const char * version)390 esp_err_t protocomm_set_version(protocomm_t *pc, const char *ep_name, const char *version)
391 {
392     if ((pc == NULL) || (ep_name == NULL) || (version == NULL)) {
393         return ESP_ERR_INVALID_ARG;
394     }
395 
396     if (pc->ver) {
397         return ESP_ERR_INVALID_STATE;
398     }
399 
400     pc->ver = strdup(version);
401     if (pc->ver == NULL) {
402         ESP_LOGE(TAG, "Error allocating version string");
403         return ESP_ERR_NO_MEM;
404     }
405 
406     esp_err_t ret = protocomm_add_endpoint_internal(pc, ep_name,
407                                                     protocomm_version_handler,
408                                                     (void *) pc, VER_EP);
409     if (ret != ESP_OK) {
410         ESP_LOGE(TAG, "Error adding version endpoint");
411         return ret;
412     }
413     return ESP_OK;
414 }
415 
protocomm_unset_version(protocomm_t * pc,const char * ep_name)416 esp_err_t protocomm_unset_version(protocomm_t *pc, const char *ep_name)
417 {
418     if ((pc == NULL) || (ep_name == NULL)) {
419         return ESP_ERR_INVALID_ARG;
420     }
421 
422     if (pc->ver) {
423         free((char *)pc->ver);
424         pc->ver = NULL;
425     }
426 
427     return protocomm_remove_endpoint(pc, ep_name);
428 }
429