1 // Copyright 2015-2021 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <stdlib.h>
16 #include <string.h>
17 #include <ctype.h>
18 #include <sys/random.h>
19 #include <sys/socket.h>
20 #include "esp_log.h"
21 #include "esp_transport.h"
22 #include "esp_transport_tcp.h"
23 #include "esp_transport_ws.h"
24 #include "esp_transport_utils.h"
25 #include "esp_transport_internal.h"
26 #include "errno.h"
27 #include "esp_tls_crypto.h"
28 
static const char *TAG = "TRANSPORT_WS";

/* Scratch buffer size for the HTTP Upgrade handshake, taken from Kconfig */
#define WS_BUFFER_SIZE              CONFIG_WS_BUFFER_SIZE

/* First frame-header byte: FIN flag and opcode values (RFC 6455, section 5.2) */
#define WS_FIN                      0x80
#define WS_OPCODE_CONT              0x00
#define WS_OPCODE_TEXT              0x01
#define WS_OPCODE_BINARY            0x02
#define WS_OPCODE_CLOSE             0x08
#define WS_OPCODE_PING              0x09
#define WS_OPCODE_PONG              0x0a
/* Bit shared by every control opcode (0x8..0xF); used as a mask, not compared for equality */
#define WS_OPCODE_CONTROL_FRAME     0x08

/* Second frame-header byte: MASK flag and the extended-length markers */
#define WS_MASK                     0x80
#define WS_SIZE16                   126
#define WS_SIZE64                   127

/* Limits and protocol constants */
#define MAX_WEBSOCKET_HEADER_SIZE   16
#define WS_RESPONSE_OK              101
#define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125
48 
49 
/* Bookkeeping for the frame that is currently being received */
typedef struct {
    uint8_t opcode;                     /* low nibble of the first header byte */
    char mask_key[4];                   /* masking key of this frame (all zero when unmasked) */
    int payload_len;                    /* payload size announced by the frame header */
    int bytes_remaining;                /* payload bytes not yet consumed by the reader */
    bool header_received;               /* set once a new frame header has been parsed */
} ws_transport_frame_state_t;
57 
58 typedef struct {
59     char *path;
60     char *buffer;
61     char *sub_protocol;
62     char *user_agent;
63     char *headers;
64     bool propagate_control_frames;
65     ws_transport_frame_state_t frame_state;
66     esp_transport_handle_t parent;
67 } transport_ws_t;
68 
/**
 * @brief               Handles control frames
 *
 * This API is used internally to handle control frames at the transport layer.
 * The API could possibly be promoted to a public API if needed by some clients.
 *
 * @param t             Websocket transport handle
 * @param buffer        Buffer holding the already-read payload of the control frame
 * @param len           Length of the buffer; must be at least the frame's payload length
 * @param timeout_ms    The timeout in milliseconds
 * @param client_closed To indicate that the connection has been closed by the client
 *                      (to prevent echoing the CLOSE packet if true, as this is the actual echo from the server)
 *
 * @return
 *      0 - no activity (no new header, not a control frame, unhandled opcode), or successfully responded to PING
 *      -1 - Failure: payload longer than the buffer, sending PONG/CLOSE failed,
 *           or the connection was not closed cleanly after a CLOSE frame
 *      1 - Close handshake success
 *      2 - Got PONG message
 */

static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, bool client_closed);
90 
/* Map the public opcode enum onto the raw byte value written into the frame header */
static inline uint8_t ws_get_bin_opcode(ws_transport_opcodes_t opcode)
{
    const uint8_t raw = (uint8_t)opcode;
    return raw;
}
95 
/*
 * Return the parent transport so callers can read payload bytes directly.
 * Direct reads bypass the WS frame tracking, so any partially consumed payload
 * is forgotten; the next ws_read() then starts by parsing a fresh frame header.
 */
static esp_transport_handle_t ws_get_payload_transport_handle(esp_transport_handle_t t)
{
    transport_ws_t *ctx = esp_transport_get_context_data(t);
    ctx->frame_state.bytes_remaining = 0;
    return ctx->parent;
}
106 
/*
 * Strip leading and trailing whitespace in place.
 * Returns a pointer into `str` at the first non-space character. Trailing
 * whitespace is removed by writing a NUL terminator into the string, so `str`
 * must be writable despite the const qualifier.
 */
static char *trimwhitespace(const char *str)
{
    char *begin = (char *)str;
    while (isspace((unsigned char)*begin)) {
        begin++;
    }
    if (*begin == '\0') {
        return begin;
    }

    char *last = begin + strlen(begin) - 1;
    while (last > begin && isspace((unsigned char)*last)) {
        last--;
    }
    last[1] = '\0';
    return begin;
}
127 
/*
 * Find header `key` (case-insensitive, key includes the trailing ':') in the
 * raw HTTP response `buffer` and return its trimmed value. The line's "\r\n"
 * is overwritten with a NUL, so `buffer` is modified in place. Returns NULL if
 * the key is absent or its line is not terminated by "\r\n".
 */
static char *get_http_header(const char *buffer, const char *key)
{
    char *value = strcasestr(buffer, key);
    if (value == NULL) {
        return NULL;
    }
    value += strlen(key);

    char *line_end = strstr(value, "\r\n");
    if (line_end == NULL) {
        return NULL;
    }
    *line_end = '\0';
    return trimwhitespace(value);
}
142 
/*
 * Open the parent transport and perform the client side of the WebSocket
 * opening handshake:
 *   - send an HTTP/1.1 Upgrade request carrying a random Sec-WebSocket-Key,
 *     plus the optional sub-protocol and user headers;
 *   - read the response headers into ws->buffer;
 *   - compare Sec-WebSocket-Accept with base64(SHA-1(key + RFC 6455 GUID)).
 *
 * Returns 0 on success, -1 on any failure. The HTTP status line itself is not
 * inspected; acceptance relies on the Sec-WebSocket-Accept match.
 * NOTE(review): bytes the server sends after "\r\n\r\n" in the same read (e.g. an
 * early first frame) stay in ws->buffer and are not handed to ws_read().
 */
static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);
    if (esp_transport_connect(ws->parent, host, port, timeout_ms) < 0) {
        ESP_LOGE(TAG, "Error connecting to host %s:%d", host, port);
        return -1;
    }

    unsigned char random_key[16];
    getrandom(random_key, sizeof(random_key), 0);

    // Size of base64 coded string is equal '((input_size * 4) / 3) + (input_size / 96) + 6' including Z-term
    unsigned char client_key[28] = {0};

    const char *user_agent_ptr = (ws->user_agent)?(ws->user_agent):"ESP32 Websocket Client";

    size_t outlen = 0;
    esp_crypto_base64_encode(client_key, sizeof(client_key), &outlen, random_key, sizeof(random_key));
    int len = snprintf(ws->buffer, WS_BUFFER_SIZE,
                         "GET %s HTTP/1.1\r\n"
                         "Connection: Upgrade\r\n"
                         "Host: %s:%d\r\n"
                         "User-Agent: %s\r\n"
                         "Upgrade: websocket\r\n"
                         "Sec-WebSocket-Version: 13\r\n"
                         "Sec-WebSocket-Key: %s\r\n",
                         ws->path,
                         host, port, user_agent_ptr,
                         client_key);
    if (len <= 0 || len >= WS_BUFFER_SIZE) {
        ESP_LOGE(TAG, "Error in request generation, desired request len: %d, buffer size: %d", len, WS_BUFFER_SIZE);
        return -1;
    }
    if (ws->sub_protocol) {
        ESP_LOGD(TAG, "sub_protocol: %s", ws->sub_protocol);
        int r = snprintf(ws->buffer + len, WS_BUFFER_SIZE - len, "Sec-WebSocket-Protocol: %s\r\n", ws->sub_protocol);
        len += r;
        if (r <= 0 || len >= WS_BUFFER_SIZE) {
            ESP_LOGE(TAG, "Error in request generation"
                          "(snprintf of subprotocol returned %d, desired request len: %d, buffer size: %d", r, len, WS_BUFFER_SIZE);
            return -1;
        }
    }
    if (ws->headers) {
        ESP_LOGD(TAG, "headers: %s", ws->headers);
        int r = snprintf(ws->buffer + len, WS_BUFFER_SIZE - len, "%s", ws->headers);
        len += r;
        if (r <= 0 || len >= WS_BUFFER_SIZE) {
            ESP_LOGE(TAG, "Error in request generation"
                          "(strncpy of headers returned %d, desired request len: %d, buffer size: %d", r, len, WS_BUFFER_SIZE);
            return -1;
        }
    }
    int r = snprintf(ws->buffer + len, WS_BUFFER_SIZE - len, "\r\n");
    len += r;
    if (r <= 0 || len >= WS_BUFFER_SIZE) {
        ESP_LOGE(TAG, "Error in request generation"
                       "(snprintf of header terminal returned %d, desired request len: %d, buffer size: %d", r, len, WS_BUFFER_SIZE);
        return -1;
    }
    ESP_LOGD(TAG, "Write upgrade request\r\n%s", ws->buffer);
    if (esp_transport_write(ws->parent, ws->buffer, len, timeout_ms) <= 0) {
        ESP_LOGE(TAG, "Error write Upgrade header %s", ws->buffer);
        return -1;
    }
    int header_len = 0;
    do {
        // Reserve one byte for the NUL terminator written after every chunk:
        // reading up to WS_BUFFER_SIZE bytes would make ws->buffer[header_len]
        // index one past the end of the buffer.
        if ((len = esp_transport_read(ws->parent, ws->buffer + header_len, WS_BUFFER_SIZE - header_len - 1, timeout_ms)) <= 0) {
            ESP_LOGE(TAG, "Error read response for Upgrade header %s", ws->buffer);
            return -1;
        }
        header_len += len;
        ws->buffer[header_len] = '\0';
        ESP_LOGD(TAG, "Read header chunk %d, current header size: %d", len, header_len);
    } while (NULL == strstr(ws->buffer, "\r\n\r\n") && header_len < WS_BUFFER_SIZE - 1);

    char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
    if (server_key == NULL) {
        ESP_LOGE(TAG, "Sec-WebSocket-Accept not found");
        return -1;
    }

    // See esp_crypto_sha1() arg size
    unsigned char expected_server_sha1[20];
    // Size of base64 coded string see above
    unsigned char expected_server_key[33] = {0};
    // If you are interested, see https://tools.ietf.org/html/rfc6455
    const char expected_server_magic[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    unsigned char expected_server_text[sizeof(client_key) + sizeof(expected_server_magic) + 1];
    strcpy((char*)expected_server_text, (char*)client_key);
    strcat((char*)expected_server_text, expected_server_magic);

    size_t key_len = strlen((char*)expected_server_text);
    esp_crypto_sha1(expected_server_text, key_len, expected_server_sha1);
    esp_crypto_base64_encode(expected_server_key, sizeof(expected_server_key),  &outlen, expected_server_sha1, sizeof(expected_server_sha1));
    expected_server_key[ (outlen < sizeof(expected_server_key)) ? outlen : (sizeof(expected_server_key) - 1) ] = 0;
    ESP_LOGD(TAG, "server key=%s, send_key=%s, expected_server_key=%s", (char *)server_key, (char*)client_key, expected_server_key);
    if (strcmp((char*)expected_server_key, (char*)server_key) != 0) {
        ESP_LOGE(TAG, "Invalid websocket key");
        return -1;
    }
    return 0;
}
246 
/*
 * Send one WebSocket frame.
 *
 * `opcode` is the complete first header byte (opcode, plus WS_FIN if desired).
 * `mask_flag` is WS_MASK or 0. When masking, the payload `b` is XOR-ed in place
 * with a random 4-byte key and restored before returning, so `b` must point to
 * writable memory despite the const qualifier.
 *
 * Returns:
 *   - the parent transport's write result for the payload (bytes written, or a value <= 0);
 *   - 0 for an empty payload;
 *   - the poll result (<= 0) when the transport is not writable;
 *   - -1 on invalid arguments or when the header write fails. The caller's data
 *     is left unmodified in that case.
 */
static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const char *b, int len, int timeout_ms)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);
    char *buffer = (char *)b;
    char ws_header[MAX_WEBSOCKET_HEADER_SIZE];
    const char *mask = NULL;
    int header_len = 0;

    if (len < 0 || (len > 0 && buffer == NULL)) {
        ESP_LOGE(TAG, "Invalid payload (len=%d)", len);
        return -1;
    }

    int poll_write;
    if ((poll_write = esp_transport_poll_write(ws->parent, timeout_ms)) <= 0) {
        ESP_LOGE(TAG, "Error transport_poll_write");
        return poll_write;
    }
    ws_header[header_len++] = opcode;

    if (len <= 125) {
        ws_header[header_len++] = (uint8_t)(len | mask_flag);
    } else if (len < 65536) {
        ws_header[header_len++] = WS_SIZE16 | mask_flag;
        ws_header[header_len++] = (uint8_t)(len >> 8);
        ws_header[header_len++] = (uint8_t)(len & 0xFF);
    } else {
        ws_header[header_len++] = WS_SIZE64 | mask_flag;
        /* Support maximum 4 bytes length */
        ws_header[header_len++] = 0;
        ws_header[header_len++] = 0;
        ws_header[header_len++] = 0;
        ws_header[header_len++] = 0;
        ws_header[header_len++] = (uint8_t)((len >> 24) & 0xFF);
        ws_header[header_len++] = (uint8_t)((len >> 16) & 0xFF);
        ws_header[header_len++] = (uint8_t)((len >> 8) & 0xFF);
        ws_header[header_len++] = (uint8_t)((len >> 0) & 0xFF);
    }

    if (mask_flag) {
        // The masking key is the last 4 bytes of the header
        mask = &ws_header[header_len];
        getrandom(ws_header + header_len, 4, 0);
        header_len += 4;
    }

    if (esp_transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) {
        ESP_LOGE(TAG, "Error write header");
        return -1;
    }
    if (len == 0) {
        return 0;
    }

    // Mask only after the header went out, so that a failed header write never
    // leaves the caller's buffer in its masked state.
    if (mask) {
        for (int i = 0; i < len; ++i) {
            buffer[i] = (buffer[i] ^ mask[i % 4]);
        }
    }

    int ret = esp_transport_write(ws->parent, buffer, len, timeout_ms);

    // The ws layer does not keep its own copy of the payload, so undo the
    // masking to hand the caller back its original data.
    if (mask) {
        for (int i = 0; i < len; ++i) {
            buffer[i] = (buffer[i] ^ mask[i % 4]);
        }
    }
    return ret;
}
310 
/*
 * Public API: send one masked frame whose first header byte is `opcode`
 * (callers include WS_FIN themselves via the enum value). Returns
 * ESP_ERR_INVALID_ARG for a NULL handle, otherwise the _ws_write() result.
 */
int esp_transport_ws_send_raw(esp_transport_handle_t t, ws_transport_opcodes_t opcode, const char *b, int len, int timeout_ms)
{
    if (t == NULL) {
        ESP_LOGE(TAG, "Transport must be a valid ws handle");
        return ESP_ERR_INVALID_ARG;
    }
    const uint8_t wire_opcode = ws_get_bin_opcode(opcode);
    ESP_LOGD(TAG, "Sending raw ws message with opcode %d", wire_opcode);
    return _ws_write(t, wire_opcode, WS_MASK, b, len, timeout_ms);
}
321 
/*
 * Transport-level write: a non-empty buffer goes out as one masked, final
 * BINARY frame; a zero-length write sends a masked PING instead.
 */
static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeout_ms)
{
    if (len != 0) {
        return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms);
    }
    // Default transport write of zero length in ws layer sends out a ping message.
    // This behaviour could however be altered in IDF 5.0, since a separate API for sending
    // messages with user defined opcodes has been introduced.
    ESP_LOGD(TAG, "Write PING message");
    return _ws_write(t, WS_OPCODE_PING | WS_FIN, WS_MASK, NULL, 0, timeout_ms);
}
333 
334 
/*
 * Read up to `len` bytes of the current frame's remaining payload into `buffer`
 * and unmask them with the frame's masking key.
 * Returns the number of bytes read (0 when nothing remains to be read), or the
 * parent transport's result (<= 0) if the read fails.
 */
static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);

    int bytes_to_read;
    int rlen = 0;
    // Offset of the next byte within the whole payload. The masking key is applied
    // cyclically over the entire payload (RFC 6455, 5.3), so a frame read in
    // several chunks must continue the key rotation instead of restarting at 0.
    int payload_offset = ws->frame_state.payload_len - ws->frame_state.bytes_remaining;
    if (payload_offset < 0) {
        payload_offset = 0;
    }

    if (ws->frame_state.bytes_remaining > len) {
        ESP_LOGD(TAG, "Actual data to receive (%d) are longer than ws buffer (%d)", ws->frame_state.bytes_remaining, len);
        bytes_to_read = len;
    } else {
        bytes_to_read = ws->frame_state.bytes_remaining;
    }

    // Receive and process payload
    if (bytes_to_read != 0 && (rlen = esp_transport_read(ws->parent, buffer, bytes_to_read, timeout_ms)) <= 0) {
        ESP_LOGE(TAG, "Error read data");
        return rlen;
    }
    ws->frame_state.bytes_remaining -= rlen;

    // Unmask only the bytes actually received; an all-zero key (unmasked frame) leaves them unchanged
    for (int i = 0; i < rlen; i++) {
        buffer[i] = (buffer[i] ^ ws->frame_state.mask_key[(payload_offset + i) % 4]);
    }
    return rlen;
}
364 
365 
366 /* Read and parse the WS header, determine length of payload */
ws_read_header(esp_transport_handle_t t,char * buffer,int len,int timeout_ms)367 static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
368 {
369     transport_ws_t *ws = esp_transport_get_context_data(t);
370     int payload_len;
371 
372     char ws_header[MAX_WEBSOCKET_HEADER_SIZE];
373     char *data_ptr = ws_header, mask;
374     int rlen;
375     int poll_read;
376     ws->frame_state.header_received = false;
377     if ((poll_read = esp_transport_poll_read(ws->parent, timeout_ms)) <= 0) {
378         return poll_read;
379     }
380 
381     // Receive and process header first (based on header size)
382     int header = 2;
383     int mask_len = 4;
384     if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
385         ESP_LOGE(TAG, "Error read data");
386         return rlen;
387     }
388     ws->frame_state.header_received = true;
389     ws->frame_state.opcode = (*data_ptr & 0x0F);
390     data_ptr ++;
391     mask = ((*data_ptr >> 7) & 0x01);
392     payload_len = (*data_ptr & 0x7F);
393     data_ptr++;
394     ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", ws->frame_state.opcode, mask, payload_len);
395     if (payload_len == 126) {
396         // headerLen += 2;
397         if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
398             ESP_LOGE(TAG, "Error read data");
399             return rlen;
400         }
401         payload_len = data_ptr[0] << 8 | data_ptr[1];
402     } else if (payload_len == 127) {
403         // headerLen += 8;
404         header = 8;
405         if ((rlen = esp_transport_read(ws->parent, data_ptr, header, timeout_ms)) <= 0) {
406             ESP_LOGE(TAG, "Error read data");
407             return rlen;
408         }
409 
410         if (data_ptr[0] != 0 || data_ptr[1] != 0 || data_ptr[2] != 0 || data_ptr[3] != 0) {
411             // really too big!
412             payload_len = 0xFFFFFFFF;
413         } else {
414             payload_len = data_ptr[4] << 24 | data_ptr[5] << 16 | data_ptr[6] << 8 | data_ptr[7];
415         }
416     }
417 
418     if (mask) {
419         // Read and store mask
420         if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, mask_len, timeout_ms)) <= 0) {
421             ESP_LOGE(TAG, "Error read data");
422             return rlen;
423         }
424         memcpy(ws->frame_state.mask_key, buffer, mask_len);
425     } else {
426         memset(ws->frame_state.mask_key, 0, mask_len);
427     }
428 
429     ws->frame_state.payload_len = payload_len;
430     ws->frame_state.bytes_remaining = payload_len;
431 
432     return payload_len;
433 }
434 
/*
 * Consume and handle the control frame whose header was just parsed, when
 * control frames are not propagated to the caller.
 * Returns 0 when there is nothing to do or the frame was handled (PONG and close
 * results are mapped to 0 so they are not propagated), -1 on failure.
 */
static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeout_ms)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);
    const int payload_len = ws->frame_state.payload_len;

    // Only act on a freshly received header that announces a control frame
    if (!ws->frame_state.header_received ||
        (ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME) == 0) {
        return 0;
    }

    if (payload_len > WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN) {
        ESP_LOGE(TAG, "Not enough room for reading control frames (need=%d, max_allowed=%d)",
                 ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN);
        return -1;
    }

    // Zero-length frames need no buffer; short payloads get a temporary heap buffer
    char *payload = NULL;
    int payload_buf_len = 0;
    if (payload_len > 0) {
        payload = malloc(payload_len);
        if (payload == NULL) {
            ESP_LOGE(TAG, "Cannot allocate buffer for control frames, need-%d", payload_len);
            return -1;
        }
        payload_buf_len = payload_len;
    }

    int result;
    const int got = ws_read_payload(t, payload, payload_buf_len, timeout_ms);
    if (got == payload_len) {
        result = esp_transport_ws_handle_control_frames(t, payload, payload_buf_len, timeout_ms, false);
    } else {
        ESP_LOGE(TAG, "Control frame (opcode=%d) payload read failed (payload_len=%d, read_len=%d)",
                 ws->frame_state.opcode, payload_len, got);
        result = -1;
    }

    free(payload);
    return (result > 0) ? 0 : result; // We don't propagate control frames, pass 0 to upper layers
}
484 
/*
 * Transport-level read.
 *
 * When no payload is pending, a new frame header is parsed first. Control frames
 * are handled internally (returning 0) unless propagate_control_frames is set.
 * Then up to `len` payload bytes are read into `buffer`; a message larger than
 * `len` is delivered across several calls.
 * Returns the bytes read, 0 on timeout or empty frame, or a negative error.
 */
static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);
    int result = 0;

    const bool need_new_header = ws->frame_state.bytes_remaining <= 0;
    if (need_new_header) {
        result = ws_read_header(t, buffer, len, timeout_ms);
        if (result < 0) {
            // Something went wrong: prepare for reading a new header next time
            ws->frame_state.bytes_remaining = 0;
            return result;
        }

        const bool is_control = ws->frame_state.header_received &&
                                (ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME);
        if (is_control && !ws->propagate_control_frames) {
            // Handled internally; success reads as 0, which callers may take for a timeout
            return ws_handle_control_frame_internal(t, timeout_ms);
        }

        if (result == 0) {
            ws->frame_state.bytes_remaining = 0;
            return 0; // timeout
        }
    }

    if (ws->frame_state.payload_len == 0) {
        return result;
    }

    result = ws_read_payload(t, buffer, len, timeout_ms);
    if (result <= 0) {
        ESP_LOGE(TAG, "Error reading payload data");
        ws->frame_state.bytes_remaining = 0;
    }
    return result;
}
524 
525 
/* Readability poll is delegated to the parent transport */
static int ws_poll_read(esp_transport_handle_t t, int timeout_ms)
{
    const transport_ws_t *ctx = esp_transport_get_context_data(t);
    return esp_transport_poll_read(ctx->parent, timeout_ms);
}
531 
/* Writability poll is delegated to the parent transport */
static int ws_poll_write(esp_transport_handle_t t, int timeout_ms)
{
    const transport_ws_t *ctx = esp_transport_get_context_data(t);
    return esp_transport_poll_write(ctx->parent, timeout_ms);
}
537 
/* Closing the WS transport closes the parent transport and returns its result */
static int ws_close(esp_transport_handle_t t)
{
    const transport_ws_t *ctx = esp_transport_get_context_data(t);
    return esp_transport_close(ctx->parent);
}
543 
ws_destroy(esp_transport_handle_t t)544 static esp_err_t ws_destroy(esp_transport_handle_t t)
545 {
546     transport_ws_t *ws = esp_transport_get_context_data(t);
547     free(ws->buffer);
548     free(ws->path);
549     free(ws->sub_protocol);
550     free(ws->user_agent);
551     free(ws->headers);
552     free(ws);
553     return 0;
554 }
internal_esp_transport_ws_set_path(esp_transport_handle_t t,const char * path)555 static esp_err_t internal_esp_transport_ws_set_path(esp_transport_handle_t t, const char *path)
556 {
557     if (t == NULL) {
558         return ESP_ERR_INVALID_ARG;
559     }
560     transport_ws_t *ws = esp_transport_get_context_data(t);
561     if (ws->path) {
562         free(ws->path);
563     }
564     if (path == NULL) {
565         ws->path = NULL;
566         return ESP_OK;
567     }
568     ws->path = strdup(path);
569     if (ws->path == NULL) {
570         return ESP_ERR_NO_MEM;
571     }
572     return ESP_OK;
573 }
574 
/* Public, void-returning wrapper: failures are only reported through the log */
void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path)
{
    const esp_err_t status = internal_esp_transport_ws_set_path(t, path);
    if (status == ESP_OK) {
        return;
    }
    ESP_LOGE(TAG, "esp_transport_ws_set_path has internally failed with err=%d", status);
}
582 
/* Return the parent transport's socket descriptor, or -1 if any link in the chain is missing */
static int ws_get_socket(esp_transport_handle_t t)
{
    if (t == NULL) {
        return -1;
    }
    transport_ws_t *ctx = t->data;
    if (ctx == NULL || ctx->parent == NULL || ctx->parent->_get_socket == NULL) {
        return -1;
    }
    return ctx->parent->_get_socket(ctx->parent);
}
593 
/*
 * Create a WebSocket transport layered on top of `parent_handle`.
 * The default path is "/". Returns NULL if any allocation fails; everything
 * allocated so far is released in that case.
 */
esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle)
{
    esp_transport_handle_t t = esp_transport_init();
    if (t == NULL) {
        return NULL;
    }
    transport_ws_t *ws = calloc(1, sizeof(transport_ws_t));
    ESP_TRANSPORT_MEM_CHECK(TAG, ws, {
        // Do not leak the freshly created transport handle
        esp_transport_destroy(t);
        return NULL;
    });
    ws->parent = parent_handle;

    ws->path = strdup("/");
    ESP_TRANSPORT_MEM_CHECK(TAG, ws->path, {
        free(ws);
        esp_transport_destroy(t);
        return NULL;
    });
    ws->buffer = malloc(WS_BUFFER_SIZE);
    ESP_TRANSPORT_MEM_CHECK(TAG, ws->buffer, {
        free(ws->path);
        free(ws);
        esp_transport_destroy(t);
        return NULL;
    });

    esp_transport_set_func(t, ws_connect, ws_read, ws_write, ws_close, ws_poll_read, ws_poll_write, ws_destroy);
    // websocket underlying transfer is the payload transfer handle
    esp_transport_set_parent_transport_func(t, ws_get_payload_transport_handle);

    esp_transport_set_context_data(t, ws);
    t->_get_socket = ws_get_socket;
    return t;
}
626 
esp_transport_ws_set_subprotocol(esp_transport_handle_t t,const char * sub_protocol)627 esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol)
628 {
629     if (t == NULL) {
630         return ESP_ERR_INVALID_ARG;
631     }
632     transport_ws_t *ws = esp_transport_get_context_data(t);
633     if (ws->sub_protocol) {
634         free(ws->sub_protocol);
635     }
636     if (sub_protocol == NULL) {
637         ws->sub_protocol = NULL;
638         return ESP_OK;
639     }
640     ws->sub_protocol = strdup(sub_protocol);
641     if (ws->sub_protocol == NULL) {
642         return ESP_ERR_NO_MEM;
643     }
644     return ESP_OK;
645 }
646 
esp_transport_ws_set_user_agent(esp_transport_handle_t t,const char * user_agent)647 esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *user_agent)
648 {
649     if (t == NULL) {
650         return ESP_ERR_INVALID_ARG;
651     }
652     transport_ws_t *ws = esp_transport_get_context_data(t);
653     if (ws->user_agent) {
654         free(ws->user_agent);
655     }
656     if (user_agent == NULL) {
657         ws->user_agent = NULL;
658         return ESP_OK;
659     }
660     ws->user_agent = strdup(user_agent);
661     if (ws->user_agent == NULL) {
662         return ESP_ERR_NO_MEM;
663     }
664     return ESP_OK;
665 }
666 
esp_transport_ws_set_headers(esp_transport_handle_t t,const char * headers)667 esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers)
668 {
669     if (t == NULL) {
670         return ESP_ERR_INVALID_ARG;
671     }
672     transport_ws_t *ws = esp_transport_get_context_data(t);
673     if (ws->headers) {
674         free(ws->headers);
675     }
676     if (headers == NULL) {
677         ws->headers = NULL;
678         return ESP_OK;
679     }
680     ws->headers = strdup(headers);
681     if (ws->headers == NULL) {
682         return ESP_ERR_NO_MEM;
683     }
684     return ESP_OK;
685 }
686 
esp_transport_ws_set_config(esp_transport_handle_t t,const esp_transport_ws_config_t * config)687 esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config)
688 {
689     if (t == NULL) {
690         return ESP_ERR_INVALID_ARG;
691     }
692     esp_err_t err = ESP_OK;
693     transport_ws_t *ws = esp_transport_get_context_data(t);
694     if (config->ws_path) {
695         err = internal_esp_transport_ws_set_path(t, config->ws_path);
696         ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
697     }
698     if (config->sub_protocol) {
699         err = esp_transport_ws_set_subprotocol(t, config->sub_protocol);
700         ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
701     }
702     if (config->user_agent) {
703         err = esp_transport_ws_set_user_agent(t, config->user_agent);
704         ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
705     }
706     if (config->headers) {
707         err = esp_transport_ws_set_headers(t, config->headers);
708         ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
709     }
710     ws->propagate_control_frames = config->propagate_control_frames;
711 
712     return err;
713 }
714 
esp_transport_ws_get_read_opcode(esp_transport_handle_t t)715 ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
716 {
717     transport_ws_t *ws = esp_transport_get_context_data(t);
718     if (ws->frame_state.header_received) {
719         // convert the header byte to enum if correctly received
720         return (ws_transport_opcodes_t)ws->frame_state.opcode;
721     }
722     return WS_TRANSPORT_OPCODES_NONE;
723 }
724 
/*
 * Payload length announced by the most recently parsed frame header.
 * Returns -1 for a NULL handle; 0 is a valid payload length, so it cannot signal the error.
 */
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t)
{
    if (t == NULL) {
        return -1;
    }
    transport_ws_t *ws = esp_transport_get_context_data(t);
    return ws->frame_state.payload_len;
}
730 
/*
 * React to the control frame whose header was just parsed; `buffer` holds its
 * payload. Behaviour per opcode:
 *   - PING: answered with a PONG carrying the same payload.
 *   - CLOSE: echoed (unless `client_closed`), then waits for the peer to close TCP.
 *   - PONG: only reported.
 * Returns 0 (nothing to do / PING answered), 1 (close handshake done),
 * 2 (PONG received) or -1 on failure.
 */
static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, bool client_closed)
{
    transport_ws_t *ws = esp_transport_get_context_data(t);

    // If no new header reception in progress, or not a control frame
    // just pass 0 -> no need to handle control frames
    if (ws->frame_state.header_received == false ||
        !(ws->frame_state.opcode & WS_OPCODE_CONTROL_FRAME)) {
        return 0;
    }
    int actual_len;
    int payload_len = ws->frame_state.payload_len;

    ESP_LOGD(TAG, "Handling control frame with %d bytes payload", payload_len);
    if (payload_len > len) {
        ESP_LOGE(TAG, "Not enough room for processing the payload (need=%d, available=%d)", payload_len, len);
        ws->frame_state.bytes_remaining = payload_len - len;
        return -1;
    }

    if (ws->frame_state.opcode == WS_OPCODE_PING) {
        // handle PING frames internally: just send a PONG with the same payload
        actual_len = _ws_write(t, WS_OPCODE_PONG | WS_FIN, WS_MASK, buffer,
                               payload_len, timeout_ms);
        if (actual_len != payload_len) {
            ESP_LOGE(TAG, "PONG send failed (payload_len=%d, written_len=%d)", payload_len, actual_len);
            return -1;
        }
        ESP_LOGD(TAG, "PONG sent correctly (payload_len=%d)", payload_len);

        // control frame handled correctly, reset the flag indicating new header received
        ws->frame_state.header_received = false;
        return 0;

    } else if (ws->frame_state.opcode == WS_OPCODE_CLOSE) {
        // The status code is the first two payload bytes in network byte order.
        // Require both bytes, and assemble them bytewise: a 1-byte payload must not
        // be read as 2 bytes, and `buffer` has no alignment guarantee for uint16_t.
        if (buffer && payload_len >= 2) {
            unsigned int status_code = ((unsigned int)(uint8_t)buffer[0] << 8) | (uint8_t)buffer[1];
            ESP_LOGI(TAG, "Got CLOSE frame with status code=%u", status_code);
        }

        if (client_closed == false) {
            // Only echo the closing frame if not initiated by the client
            if (_ws_write(t, WS_OPCODE_CLOSE | WS_FIN, WS_MASK, NULL, 0, timeout_ms) < 0) {
                ESP_LOGE(TAG, "Sending CLOSE frame with 0 payload failed");
                return -1;
            }
            ESP_LOGD(TAG, "CLOSE frame with no payload sent correctly");
        }

        // control frame handled correctly, reset the flag indicating new header received
        ws->frame_state.header_received = false;
        int ret = esp_transport_ws_poll_connection_closed(t, timeout_ms);
        if (ret == 0) {
            ESP_LOGW(TAG, "Connection cannot be terminated gracefully within timeout=%d", timeout_ms);
            return -1;
        }
        if (ret < 0) {
            ESP_LOGW(TAG, "Connection terminated while waiting for clean TCP close");
            return -1;
        }
        ESP_LOGI(TAG, "Connection terminated gracefully");
        return 1;
    } else if (ws->frame_state.opcode == WS_OPCODE_PONG) {
        // handle PONG: just indicate return code
        ESP_LOGD(TAG, "Received PONG frame with payload=%d", payload_len);
        // control frame handled correctly, reset the flag indicating new header received
        ws->frame_state.header_received = false;
        return 2;
    }
    return 0;
}
803 
/*
 * Wait up to `timeout_ms` for the peer to close the TCP connection.
 * Returns:
 *   -  1: the socket reads EOF, or reports ENOTCONN/ECONNRESET/ECONNABORTED;
 *   -  0: select() timed out;
 *   - -1: invalid socket, unexpected data or errno, or a select() error.
 */
int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms)
{
    struct timeval timeout;
    int sock = esp_transport_get_socket(t);
    if (sock < 0) {
        // FD_SET on a negative descriptor is undefined behaviour
        ESP_LOGE(TAG, "esp_transport_ws_poll_connection_closed: invalid socket=%d", sock);
        return -1;
    }
    fd_set readset;
    fd_set errset;
    FD_ZERO(&readset);
    FD_ZERO(&errset);
    FD_SET(sock, &readset);
    FD_SET(sock, &errset);

    int ret = select(sock + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
    if (ret > 0) {
        if (FD_ISSET(sock, &readset)) {
            uint8_t buffer;
            if (recv(sock, &buffer, 1, MSG_PEEK) <= 0) {
                // socket is readable, but reads zero bytes -- connection cleanly closed by FIN flag
                return 1;
            }
            ESP_LOGW(TAG, "esp_transport_ws_poll_connection_closed: unexpected data readable on socket=%d", sock);
        } else if (FD_ISSET(sock, &errset)) {
            int sock_errno = 0;
            socklen_t optlen = sizeof(sock_errno);
            getsockopt(sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
            ESP_LOGD(TAG, "esp_transport_ws_poll_connection_closed select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), sock);
            if (sock_errno == ENOTCONN || sock_errno == ECONNRESET || sock_errno == ECONNABORTED) {
                // the three err codes above might be caused by connection termination by RST flag
                // which we still assume as expected closing sequence of ws-transport connection
                return 1;
            }
            ESP_LOGE(TAG, "esp_transport_ws_poll_connection_closed: unexpected errno=%d on socket=%d", sock_errno, sock);
        }
        return -1; // indicates error: socket unexpectedly reads an actual data, or unexpected errno code
    }
    return ret;
}
841