1 /*
2  *  UDP proxy: emulate an unreliable UDP connection for DTLS testing
3  *
4  *  Copyright The Mbed TLS Contributors
5  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6  */
7 
8 /*
9  * Warning: this is an internal utility program we use for tests.
10  * It does break some abstractions from the NET layer, and is thus NOT an
11  * example of good general usage.
12  */
13 
14 #define MBEDTLS_ALLOW_PRIVATE_ACCESS
15 
16 #include "mbedtls/build_info.h"
17 
18 #if defined(MBEDTLS_PLATFORM_C)
19 #include "mbedtls/platform.h"
20 #else
21 #include <stdio.h>
22 #include <stdlib.h>
23 #if defined(MBEDTLS_HAVE_TIME)
24 #include <time.h>
25 #define mbedtls_time            time
26 #define mbedtls_time_t          time_t
27 #endif
28 #define mbedtls_printf          printf
29 #define mbedtls_calloc          calloc
30 #define mbedtls_free            free
31 #define mbedtls_exit            exit
32 #define MBEDTLS_EXIT_SUCCESS    EXIT_SUCCESS
33 #define MBEDTLS_EXIT_FAILURE    EXIT_FAILURE
34 #endif /* MBEDTLS_PLATFORM_C */
35 
36 #if !defined(MBEDTLS_NET_C)
int main(void)
{
    /* This build has no networking module, so there is nothing to proxy:
     * say so and terminate with status 0. */
    mbedtls_printf("MBEDTLS_NET_C not defined.\n");
    mbedtls_exit(0);
}
42 #else
43 
44 #include "mbedtls/net_sockets.h"
45 #include "mbedtls/error.h"
46 #include "mbedtls/ssl.h"
47 #include "mbedtls/timing.h"
48 
49 #include <string.h>
50 
51 /* For select() */
52 #if (defined(_WIN32) || defined(_WIN32_WCE)) && !defined(EFIX64) && \
53     !defined(EFI32)
54 #include <winsock2.h>
55 #include <windows.h>
56 #if defined(_MSC_VER)
57 #if defined(_WIN32_WCE)
58 #pragma comment( lib, "ws2.lib" )
59 #else
60 #pragma comment( lib, "ws2_32.lib" )
61 #endif
62 #endif /* _MSC_VER */
63 #else /* ( _WIN32 || _WIN32_WCE ) && !EFIX64 && !EFI32 */
64 #if defined(MBEDTLS_HAVE_TIME) || (defined(MBEDTLS_TIMING_C) && !defined(MBEDTLS_TIMING_ALT))
65 #include <sys/time.h>
66 #endif
67 #include <sys/select.h>
68 #include <sys/types.h>
69 #include <unistd.h>
70 #endif /* ( _WIN32 || _WIN32_WCE ) && !EFIX64 && !EFI32 */
71 
/* Largest record/datagram the proxy handles. Parenthesized so that the
 * macro expands safely inside any expression. */
#define MAX_MSG_SIZE            (16384 + 2048)

#define DFL_SERVER_ADDR         "localhost"
#define DFL_SERVER_PORT         "4433"
#define DFL_LISTEN_ADDR         "localhost"
#define DFL_LISTEN_PORT         "5556"
#define DFL_PACK                0

/* The pack option is only available when packets can be timed. */
#if defined(MBEDTLS_TIMING_C)
#define USAGE_PACK                                                          \
    "    pack=%%d             default: 0     (don't pack)\n"                \
    "                         options: t > 0 (pack for t milliseconds)\n"
#else
#define USAGE_PACK
#endif

/* Help text printed by exit_usage(). The address/port defaults are taken
 * from the DFL_* macros so the text cannot drift from the real defaults. */
#define USAGE                                                               \
    "\n usage: udp_proxy param=<>...\n"                                     \
    "\n acceptable parameters:\n"                                           \
    "    server_addr=%%s      default: " DFL_SERVER_ADDR "\n"               \
    "    server_port=%%d      default: " DFL_SERVER_PORT "\n"               \
    "    listen_addr=%%s      default: " DFL_LISTEN_ADDR "\n"               \
    "    listen_port=%%d      default: " DFL_LISTEN_PORT "\n"               \
    "\n"                                                                    \
    "    duplicate=%%d        default: 0 (no duplication)\n"                \
    "                        duplicate about 1:N packets randomly\n"        \
    "    delay=%%d            default: 0 (no delayed packets)\n"            \
    "                        delay about 1:N packets randomly\n"            \
    "    delay_ccs=0/1       default: 0 (don't delay ChangeCipherSpec)\n"   \
    "    delay_cli=%%s        Handshake message from client that should be\n" \
    "                        delayed. Possible values are 'ClientHello',\n" \
    "                        'Certificate', 'CertificateVerify', and\n"     \
    "                        'ClientKeyExchange'.\n"                        \
    "                        May be used multiple times, even for the same\n" \
    "                        message, in which case the respective message\n" \
    "                        gets delayed multiple times.\n"                 \
    "    delay_srv=%%s        Handshake message from server that should be\n" \
    "                        delayed. Possible values are 'HelloRequest',\n" \
    "                        'ServerHello', 'ServerHelloDone', 'Certificate',\n" \
    "                        'ServerKeyExchange', 'NewSessionTicket',\n" \
    "                        'HelloVerifyRequest' and 'CertificateRequest'.\n" \
    "                        May be used multiple times, even for the same\n" \
    "                        message, in which case the respective message\n" \
    "                        gets delayed multiple times.\n"                 \
    "    drop=%%d             default: 0 (no dropped packets)\n"            \
    "                        drop about 1:N packets randomly\n"             \
    "    mtu=%%d              default: 0 (unlimited)\n"                     \
    "                        drop packets larger than N bytes\n"            \
    "    bad_ad=0/1          default: 0 (don't add bad ApplicationData)\n"  \
    "    bad_cid=%%d          default: 0 (don't corrupt Connection IDs)\n"   \
    "                        duplicate 1:N packets containing a CID,\n" \
    "                        modifying CID in first instance of the packet.\n" \
    "    protect_hvr=0/1     default: 0 (don't protect HelloVerifyRequest)\n" \
    "    protect_len=%%d      default: (don't protect packets of this size)\n" \
    "    inject_clihlo=0/1   default: 0 (don't inject fake ClientHello)\n"  \
    "\n"                                                                    \
    "    seed=%%d             default: (use current time)\n"                \
    USAGE_PACK                                                              \
    "\n"
131 
132 /*
133  * global options
134  */
135 
#define MAX_DELAYED_HS 10

/* Command line options; filled in by get_options(). Static storage, so every
 * member not explicitly set there starts out as 0 / NULL. */
static struct options {
    const char *server_addr;    /* address to forward packets to            */
    const char *server_port;    /* port to forward packets to               */
    const char *listen_addr;    /* address for accepting client connections */
    const char *listen_port;    /* port for accepting client connections    */

    int duplicate;              /* duplicate 1 in N packets (none if 0)     */
    int delay;                  /* delay 1 packet in N (none if 0)          */
    int delay_ccs;              /* delay ChangeCipherSpec                   */
    char *delay_cli[MAX_DELAYED_HS];  /* heap-allocated names of client
                                       * handshake messages to delay once;
                                       * an entry is freed and set to NULL
                                       * when it has been used.             */
    uint8_t delay_cli_cnt;      /* Number of entries in delay_cli.          */
    char *delay_srv[MAX_DELAYED_HS];  /* same as delay_cli, for messages
                                       * sent by the server.                */
    uint8_t delay_srv_cnt;      /* Number of entries in delay_srv.          */
    int drop;                   /* drop 1 packet in N (none if 0)           */
    int mtu;                    /* drop packets larger than this            */
    int bad_ad;                 /* inject corrupted ApplicationData record  */
    unsigned bad_cid;           /* inject corrupted CID record              */
    int protect_hvr;            /* never drop or delay HelloVerifyRequest   */
    int protect_len;            /* never drop/delay packet of the given size*/
    int inject_clihlo;          /* inject fake ClientHello after handshake  */
    unsigned pack;              /* merge packets into single datagram for
                                 * at most \c pack milliseconds if > 0      */
    unsigned int seed;          /* seed for "random" events                 */
} opt;
164 
exit_usage(const char * name,const char * value)165 static void exit_usage(const char *name, const char *value)
166 {
167     if (value == NULL) {
168         mbedtls_printf(" unknown option or missing value: %s\n", name);
169     } else {
170         mbedtls_printf(" option %s: illegal value: %s\n", name, value);
171     }
172 
173     mbedtls_printf(USAGE);
174     mbedtls_exit(1);
175 }
176 
get_options(int argc,char * argv[])177 static void get_options(int argc, char *argv[])
178 {
179     int i;
180     char *p, *q;
181 
182     opt.server_addr    = DFL_SERVER_ADDR;
183     opt.server_port    = DFL_SERVER_PORT;
184     opt.listen_addr    = DFL_LISTEN_ADDR;
185     opt.listen_port    = DFL_LISTEN_PORT;
186     opt.pack           = DFL_PACK;
187     /* Other members default to 0 */
188 
189     opt.delay_cli_cnt = 0;
190     opt.delay_srv_cnt = 0;
191     memset(opt.delay_cli, 0, sizeof(opt.delay_cli));
192     memset(opt.delay_srv, 0, sizeof(opt.delay_srv));
193 
194     for (i = 1; i < argc; i++) {
195         p = argv[i];
196         if ((q = strchr(p, '=')) == NULL) {
197             exit_usage(p, NULL);
198         }
199         *q++ = '\0';
200 
201         if (strcmp(p, "server_addr") == 0) {
202             opt.server_addr = q;
203         } else if (strcmp(p, "server_port") == 0) {
204             opt.server_port = q;
205         } else if (strcmp(p, "listen_addr") == 0) {
206             opt.listen_addr = q;
207         } else if (strcmp(p, "listen_port") == 0) {
208             opt.listen_port = q;
209         } else if (strcmp(p, "duplicate") == 0) {
210             opt.duplicate = atoi(q);
211             if (opt.duplicate < 0 || opt.duplicate > 20) {
212                 exit_usage(p, q);
213             }
214         } else if (strcmp(p, "delay") == 0) {
215             opt.delay = atoi(q);
216             if (opt.delay < 0 || opt.delay > 20 || opt.delay == 1) {
217                 exit_usage(p, q);
218             }
219         } else if (strcmp(p, "delay_ccs") == 0) {
220             opt.delay_ccs = atoi(q);
221             if (opt.delay_ccs < 0 || opt.delay_ccs > 1) {
222                 exit_usage(p, q);
223             }
224         } else if (strcmp(p, "delay_cli") == 0 ||
225                    strcmp(p, "delay_srv") == 0) {
226             uint8_t *delay_cnt;
227             char **delay_list;
228             size_t len;
229             char *buf;
230 
231             if (strcmp(p, "delay_cli") == 0) {
232                 delay_cnt  = &opt.delay_cli_cnt;
233                 delay_list = opt.delay_cli;
234             } else {
235                 delay_cnt  = &opt.delay_srv_cnt;
236                 delay_list = opt.delay_srv;
237             }
238 
239             if (*delay_cnt == MAX_DELAYED_HS) {
240                 mbedtls_printf(" too many uses of %s: only %d allowed\n",
241                                p, MAX_DELAYED_HS);
242                 exit_usage(p, NULL);
243             }
244 
245             len = strlen(q);
246             buf = mbedtls_calloc(1, len + 1);
247             if (buf == NULL) {
248                 mbedtls_printf(" Allocation failure\n");
249                 exit(1);
250             }
251             memcpy(buf, q, len + 1);
252 
253             delay_list[(*delay_cnt)++] = buf;
254         } else if (strcmp(p, "drop") == 0) {
255             opt.drop = atoi(q);
256             if (opt.drop < 0 || opt.drop > 20 || opt.drop == 1) {
257                 exit_usage(p, q);
258             }
259         } else if (strcmp(p, "pack") == 0) {
260 #if defined(MBEDTLS_TIMING_C)
261             opt.pack = (unsigned) atoi(q);
262 #else
263             mbedtls_printf(" option pack only defined if MBEDTLS_TIMING_C is enabled\n");
264             exit(1);
265 #endif
266         } else if (strcmp(p, "mtu") == 0) {
267             opt.mtu = atoi(q);
268             if (opt.mtu < 0 || opt.mtu > MAX_MSG_SIZE) {
269                 exit_usage(p, q);
270             }
271         } else if (strcmp(p, "bad_ad") == 0) {
272             opt.bad_ad = atoi(q);
273             if (opt.bad_ad < 0 || opt.bad_ad > 1) {
274                 exit_usage(p, q);
275             }
276         }
277 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
278         else if (strcmp(p, "bad_cid") == 0) {
279             opt.bad_cid = (unsigned) atoi(q);
280         }
281 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
282         else if (strcmp(p, "protect_hvr") == 0) {
283             opt.protect_hvr = atoi(q);
284             if (opt.protect_hvr < 0 || opt.protect_hvr > 1) {
285                 exit_usage(p, q);
286             }
287         } else if (strcmp(p, "protect_len") == 0) {
288             opt.protect_len = atoi(q);
289             if (opt.protect_len < 0) {
290                 exit_usage(p, q);
291             }
292         } else if (strcmp(p, "inject_clihlo") == 0) {
293             opt.inject_clihlo = atoi(q);
294             if (opt.inject_clihlo < 0 || opt.inject_clihlo > 1) {
295                 exit_usage(p, q);
296             }
297         } else if (strcmp(p, "seed") == 0) {
298             opt.seed = atoi(q);
299             if (opt.seed == 0) {
300                 exit_usage(p, q);
301             }
302         } else {
303             exit_usage(p, NULL);
304         }
305     }
306 }
307 
msg_type(unsigned char * msg,size_t len)308 static const char *msg_type(unsigned char *msg, size_t len)
309 {
310     if (len < 1) {
311         return "Invalid";
312     }
313     switch (msg[0]) {
314         case MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC:    return "ChangeCipherSpec";
315         case MBEDTLS_SSL_MSG_ALERT:                 return "Alert";
316         case MBEDTLS_SSL_MSG_APPLICATION_DATA:      return "ApplicationData";
317         case MBEDTLS_SSL_MSG_CID:                   return "CID";
318         case MBEDTLS_SSL_MSG_HANDSHAKE:             break; /* See below */
319         default:                            return "Unknown";
320     }
321 
322     if (len < 13 + 12) {
323         return "Invalid handshake";
324     }
325 
326     /*
327      * Our handshake message are less than 2^16 bytes long, so they should
328      * have 0 as the first byte of length, frag_offset and frag_length.
329      * Otherwise, assume they are encrypted.
330      */
331     if (msg[14] || msg[19] || msg[22]) {
332         return "Encrypted handshake";
333     }
334 
335     switch (msg[13]) {
336         case MBEDTLS_SSL_HS_HELLO_REQUEST:          return "HelloRequest";
337         case MBEDTLS_SSL_HS_CLIENT_HELLO:           return "ClientHello";
338         case MBEDTLS_SSL_HS_SERVER_HELLO:           return "ServerHello";
339         case MBEDTLS_SSL_HS_HELLO_VERIFY_REQUEST:   return "HelloVerifyRequest";
340         case MBEDTLS_SSL_HS_NEW_SESSION_TICKET:     return "NewSessionTicket";
341         case MBEDTLS_SSL_HS_CERTIFICATE:            return "Certificate";
342         case MBEDTLS_SSL_HS_SERVER_KEY_EXCHANGE:    return "ServerKeyExchange";
343         case MBEDTLS_SSL_HS_CERTIFICATE_REQUEST:    return "CertificateRequest";
344         case MBEDTLS_SSL_HS_SERVER_HELLO_DONE:      return "ServerHelloDone";
345         case MBEDTLS_SSL_HS_CERTIFICATE_VERIFY:     return "CertificateVerify";
346         case MBEDTLS_SSL_HS_CLIENT_KEY_EXCHANGE:    return "ClientKeyExchange";
347         case MBEDTLS_SSL_HS_FINISHED:               return "Finished";
348         default:                            return "Unknown handshake";
349     }
350 }
351 
352 #if defined(MBEDTLS_TIMING_C)
353 /* Return elapsed time in milliseconds since the first call */
elapsed_time(void)354 static unsigned elapsed_time(void)
355 {
356     static int initialized = 0;
357     static struct mbedtls_timing_hr_time hires;
358 
359     if (initialized == 0) {
360         (void) mbedtls_timing_get_timer(&hires, 1);
361         initialized = 1;
362         return 0;
363     }
364 
365     return mbedtls_timing_get_timer(&hires, 0);
366 }
367 
368 typedef struct {
369     mbedtls_net_context *ctx;
370 
371     const char *description;
372 
373     unsigned packet_lifetime;
374     unsigned num_datagrams;
375 
376     unsigned char data[MAX_MSG_SIZE];
377     size_t len;
378 
379 } ctx_buffer;
380 
381 static ctx_buffer outbuf[2];
382 
/*
 * Send everything queued in `buf` as a single datagram, log it, and empty
 * the buffer. Returns whatever mbedtls_net_send() returned.
 */
static int ctx_buffer_flush(ctx_buffer *buf)
{
    int sent;

    mbedtls_printf("  %05u flush    %s: %u bytes, %u datagrams, last %u ms\n",
                   elapsed_time(), buf->description,
                   (unsigned) buf->len, buf->num_datagrams,
                   elapsed_time() - buf->packet_lifetime);

    sent = mbedtls_net_send(buf->ctx, buf->data, buf->len);

    /* The buffer is emptied whether or not the send succeeded. */
    buf->num_datagrams = 0;
    buf->len           = 0;

    return sent;
}
399 
/*
 * Milliseconds left before `buf` must be flushed: (unsigned) -1 when it is
 * empty, 0 once its oldest datagram has waited opt.pack ms or more.
 */
static unsigned ctx_buffer_time_remaining(ctx_buffer *buf)
{
    unsigned const now = elapsed_time();
    unsigned age;

    if (buf->num_datagrams == 0) {
        return (unsigned) -1;
    }

    age = now - buf->packet_lifetime;
    return (age >= opt.pack) ? 0 : opt.pack - age;
}
414 
/*
 * Queue `len` bytes of `data` in `buf`, flushing the buffer first when the
 * new datagram would not fit.
 *
 * Returns `len` on success, -1 if `len` is too large to ever be buffered,
 * or the (<= 0) result of a failed flush.
 */
static int ctx_buffer_append(ctx_buffer *buf,
                             const unsigned char *data,
                             size_t len)
{
    int ret;

    /* The byte count is returned as an int, so it must fit in one. */
    if (len > (size_t) INT_MAX) {
        return -1;
    }

    if (len > sizeof(buf->data)) {
        mbedtls_printf("  ! buffer size %u too large (max %u)\n",
                       (unsigned) len, (unsigned) sizeof(buf->data));
        return -1;
    }

    /* Not enough room left: send what is already queued first. */
    if (sizeof(buf->data) - buf->len < len) {
        if ((ret = ctx_buffer_flush(buf)) <= 0) {
            mbedtls_printf("ctx_buffer_flush failed with -%#04x\n", (unsigned int) -ret);
            return ret;
        }
    }

    memcpy(buf->data + buf->len, data, len);

    buf->len += len;
    /* The packing timer starts with the first datagram in the buffer. */
    if (++buf->num_datagrams == 1) {
        buf->packet_lifetime = elapsed_time();
    }

    return (int) len;
}
447 #endif /* MBEDTLS_TIMING_C */
448 
/*
 * Send `len` bytes of `data` on `ctx`.
 *
 * In packing mode (opt.pack > 0) the data is queued in the outbuf entry
 * bound to `ctx` instead (-1 if there is none). Otherwise it is sent
 * immediately and a send error is logged. Returns the byte count or a
 * value <= 0 on failure.
 */
static int dispatch_data(mbedtls_net_context *ctx,
                         const unsigned char *data,
                         size_t len)
{
    int sent;

#if defined(MBEDTLS_TIMING_C)
    if (opt.pack > 0) {
        ctx_buffer *target;

        if (outbuf[0].ctx == ctx) {
            target = &outbuf[0];
        } else if (outbuf[1].ctx == ctx) {
            target = &outbuf[1];
        } else {
            return -1;
        }

        return ctx_buffer_append(target, data, len);
    }
#endif /* MBEDTLS_TIMING_C */

    sent = mbedtls_net_send(ctx, data, len);
    if (sent < 0) {
        mbedtls_printf("net_send returned -%#04x\n", (unsigned int) -sent);
    }
    return sent;
}
477 
478 typedef struct {
479     mbedtls_net_context *dst;
480     const char *way;
481     const char *type;
482     unsigned len;
483     unsigned char buf[MAX_MSG_SIZE];
484 } packet;
485 
486 /* Print packet. Outgoing packets come with a reason (forward, dupl, etc.) */
print_packet(const packet * p,const char * why)487 void print_packet(const packet *p, const char *why)
488 {
489 #if defined(MBEDTLS_TIMING_C)
490     if (why == NULL) {
491         mbedtls_printf("  %05u dispatch %s %s (%u bytes)\n",
492                        elapsed_time(), p->way, p->type, p->len);
493     } else {
494         mbedtls_printf("  %05u dispatch %s %s (%u bytes): %s\n",
495                        elapsed_time(), p->way, p->type, p->len, why);
496     }
497 #else
498     if (why == NULL) {
499         mbedtls_printf("        dispatch %s %s (%u bytes)\n",
500                        p->way, p->type, p->len);
501     } else {
502         mbedtls_printf("        dispatch %s %s (%u bytes): %s\n",
503                        p->way, p->type, p->len, why);
504     }
505 #endif
506 
507     fflush(stdout);
508 }
509 
510 /*
511  * In order to test the server's behaviour when receiving a ClientHello after
512  * the connection is established (this could be a hard reset from the client,
513  * but the server must not drop the existing connection before establishing
514  * client reachability, see RFC 6347 Section 4.2.8), we memorize the first
515  * ClientHello we see (which can't have a cookie), then replay it after the
516  * first ApplicationData record - then we're done.
517  *
518  * This is controlled by the inject_clihlo option.
519  *
520  * We want an explicit state and a place to store the packet.
521  */
522 typedef enum {
523     ICH_INIT,       /* haven't seen the first ClientHello yet */
524     ICH_CACHED,     /* cached the initial ClientHello */
525     ICH_INJECTED,   /* ClientHello already injected, done */
526 } inject_clihlo_state_t;
527 
528 static inject_clihlo_state_t inject_clihlo_state;
529 static packet initial_clihlo;
530 
send_packet(const packet * p,const char * why)531 int send_packet(const packet *p, const char *why)
532 {
533     int ret;
534     mbedtls_net_context *dst = p->dst;
535 
536     /* save initial ClientHello? */
537     if (opt.inject_clihlo != 0 &&
538         inject_clihlo_state == ICH_INIT &&
539         strcmp(p->type, "ClientHello") == 0) {
540         memcpy(&initial_clihlo, p, sizeof(packet));
541         inject_clihlo_state = ICH_CACHED;
542     }
543 
544     /* insert corrupted CID record? */
545     if (opt.bad_cid != 0 &&
546         strcmp(p->type, "CID") == 0 &&
547         (rand() % opt.bad_cid) == 0) {
548         unsigned char buf[MAX_MSG_SIZE];
549         memcpy(buf, p->buf, p->len);
550 
551         /* The CID resides at offset 11 in the DTLS record header. */
552         buf[11] ^= 1;
553         print_packet(p, "modified CID");
554 
555         if ((ret = dispatch_data(dst, buf, p->len)) <= 0) {
556             mbedtls_printf("  ! dispatch returned %d\n", ret);
557             return ret;
558         }
559     }
560 
561     /* insert corrupted ApplicationData record? */
562     if (opt.bad_ad &&
563         strcmp(p->type, "ApplicationData") == 0) {
564         unsigned char buf[MAX_MSG_SIZE];
565         memcpy(buf, p->buf, p->len);
566 
567         if (p->len <= 13) {
568             mbedtls_printf("  ! can't corrupt empty AD record");
569         } else {
570             ++buf[13];
571             print_packet(p, "corrupted");
572         }
573 
574         if ((ret = dispatch_data(dst, buf, p->len)) <= 0) {
575             mbedtls_printf("  ! dispatch returned %d\n", ret);
576             return ret;
577         }
578     }
579 
580     print_packet(p, why);
581     if ((ret = dispatch_data(dst, p->buf, p->len)) <= 0) {
582         mbedtls_printf("  ! dispatch returned %d\n", ret);
583         return ret;
584     }
585 
586     /* Don't duplicate Application Data, only handshake covered */
587     if (opt.duplicate != 0 &&
588         strcmp(p->type, "ApplicationData") != 0 &&
589         rand() % opt.duplicate == 0) {
590         print_packet(p, "duplicated");
591 
592         if ((ret = dispatch_data(dst, p->buf, p->len)) <= 0) {
593             mbedtls_printf("  ! dispatch returned %d\n", ret);
594             return ret;
595         }
596     }
597 
598     /* Inject ClientHello after first ApplicationData */
599     if (opt.inject_clihlo != 0 &&
600         inject_clihlo_state == ICH_CACHED &&
601         strcmp(p->type, "ApplicationData") == 0) {
602         print_packet(&initial_clihlo, "injected");
603 
604         if ((ret = dispatch_data(dst, initial_clihlo.buf,
605                                  initial_clihlo.len)) <= 0) {
606             mbedtls_printf("  ! dispatch returned %d\n", ret);
607             return ret;
608         }
609 
610         inject_clihlo_state = ICH_INJECTED;
611     }
612 
613     return 0;
614 }
615 
616 #define MAX_DELAYED_MSG 5
617 static size_t prev_len;
618 static packet prev[MAX_DELAYED_MSG];
619 
clear_pending(void)620 void clear_pending(void)
621 {
622     memset(&prev, 0, sizeof(prev));
623     prev_len = 0;
624 }
625 
/*
 * Append a copy of `delay` to the delayed-datagram queue. When the queue
 * already holds MAX_DELAYED_MSG entries the datagram is not stored.
 */
void delay_packet(packet *delay)
{
    if (prev_len != MAX_DELAYED_MSG) {
        memcpy(&prev[prev_len], delay, sizeof(packet));
        prev_len++;
    }
}
634 
send_delayed(void)635 int send_delayed(void)
636 {
637     uint8_t offset;
638     int ret;
639     for (offset = 0; offset < prev_len; offset++) {
640         ret = send_packet(&prev[offset], "delayed");
641         if (ret != 0) {
642             return ret;
643         }
644     }
645 
646     clear_pending();
647     return 0;
648 }
649 
650 /*
651  * Avoid dropping or delaying a packet that was already dropped or delayed
652  * ("held") twice: this only results in uninteresting timeouts. We can't rely
653  * on type to identify packets, since during renegotiation they're all
654  * encrypted. So, rely on size mod 2048 (which is usually just size).
655  *
656  * We only hold packets at the level of entire datagrams, not at the level
657  * of records. In particular, if the peer changes the way it packs multiple
658  * records into a single datagram, we don't necessarily count the number of
659  * times a record has been held correctly. However, the only known reason
660  * why a peer would change datagram packing is disabling the latter on
661  * retransmission, in which case we'd hold involved records at most
662  * HOLD_MAX + 1 times.
663  */
/* held[len % 2048]: how often a datagram of that size was dropped/delayed.
 * Static storage, so every counter starts at zero. */
static unsigned char held[2048];
/* A datagram size is not held again once its counter reaches this value. */
#define HOLD_MAX 2
666 
handle_message(const char * way,mbedtls_net_context * dst,mbedtls_net_context * src)667 int handle_message(const char *way,
668                    mbedtls_net_context *dst,
669                    mbedtls_net_context *src)
670 {
671     int ret;
672     packet cur;
673     size_t id;
674 
675     uint8_t delay_idx;
676     char **delay_list;
677     uint8_t delay_list_len;
678 
679     /* receive packet */
680     if ((ret = mbedtls_net_recv(src, cur.buf, sizeof(cur.buf))) <= 0) {
681         mbedtls_printf("  ! mbedtls_net_recv returned %d\n", ret);
682         return ret;
683     }
684 
685     cur.len  = ret;
686     cur.type = msg_type(cur.buf, cur.len);
687     cur.way  = way;
688     cur.dst  = dst;
689     print_packet(&cur, NULL);
690 
691     id = cur.len % sizeof(held);
692 
693     if (strcmp(way, "S <- C") == 0) {
694         delay_list     = opt.delay_cli;
695         delay_list_len = opt.delay_cli_cnt;
696     } else {
697         delay_list     = opt.delay_srv;
698         delay_list_len = opt.delay_srv_cnt;
699     }
700 
701     /* Check if message type is in the list of messages
702      * that should be delayed */
703     for (delay_idx = 0; delay_idx < delay_list_len; delay_idx++) {
704         if (delay_list[delay_idx] == NULL) {
705             continue;
706         }
707 
708         if (strcmp(delay_list[delay_idx], cur.type) == 0) {
709             /* Delay message */
710             delay_packet(&cur);
711 
712             /* Remove entry from list */
713             mbedtls_free(delay_list[delay_idx]);
714             delay_list[delay_idx] = NULL;
715 
716             return 0;
717         }
718     }
719 
720     /* do we want to drop, delay, or forward it? */
721     if ((opt.mtu != 0 &&
722          cur.len > (unsigned) opt.mtu) ||
723         (opt.drop != 0 &&
724          strcmp(cur.type, "CID") != 0             &&
725          strcmp(cur.type, "ApplicationData") != 0 &&
726          !(opt.protect_hvr &&
727            strcmp(cur.type, "HelloVerifyRequest") == 0) &&
728          cur.len != (size_t) opt.protect_len &&
729          held[id] < HOLD_MAX &&
730          rand() % opt.drop == 0)) {
731         ++held[id];
732     } else if ((opt.delay_ccs == 1 &&
733                 strcmp(cur.type, "ChangeCipherSpec") == 0) ||
734                (opt.delay != 0 &&
735                 strcmp(cur.type, "CID") != 0             &&
736                 strcmp(cur.type, "ApplicationData") != 0 &&
737                 !(opt.protect_hvr &&
738                   strcmp(cur.type, "HelloVerifyRequest") == 0) &&
739                 cur.len != (size_t) opt.protect_len &&
740                 held[id] < HOLD_MAX &&
741                 rand() % opt.delay == 0)) {
742         ++held[id];
743         delay_packet(&cur);
744     } else {
745         /* forward and possibly duplicate */
746         if ((ret = send_packet(&cur, "forwarded")) != 0) {
747             return ret;
748         }
749 
750         /* send previously delayed messages if any */
751         ret = send_delayed();
752         if (ret != 0) {
753             return ret;
754         }
755     }
756 
757     return 0;
758 }
759 
main(int argc,char * argv[])760 int main(int argc, char *argv[])
761 {
762     int ret = 1;
763     int exit_code = MBEDTLS_EXIT_FAILURE;
764     uint8_t delay_idx;
765 
766     mbedtls_net_context listen_fd, client_fd, server_fd;
767 
768 #if defined(MBEDTLS_TIMING_C)
769     struct timeval tm;
770 #endif
771 
772     struct timeval *tm_ptr = NULL;
773 
774     int nb_fds;
775     fd_set read_fds;
776 
777     mbedtls_net_init(&listen_fd);
778     mbedtls_net_init(&client_fd);
779     mbedtls_net_init(&server_fd);
780 
781     get_options(argc, argv);
782 
783     /*
784      * Decisions to drop/delay/duplicate packets are pseudo-random: dropping
785      * exactly 1 in N packets would lead to problems when a flight has exactly
786      * N packets: the same packet would be dropped on every resend.
787      *
788      * In order to be able to reproduce problems reliably, the seed may be
789      * specified explicitly.
790      */
791     if (opt.seed == 0) {
792 #if defined(MBEDTLS_HAVE_TIME)
793         opt.seed = (unsigned int) mbedtls_time(NULL);
794 #else
795         opt.seed = 1;
796 #endif /* MBEDTLS_HAVE_TIME */
797         mbedtls_printf("  . Pseudo-random seed: %u\n", opt.seed);
798     }
799 
800     srand(opt.seed);
801 
802     /*
803      * 0. "Connect" to the server
804      */
805     mbedtls_printf("  . Connect to server on UDP/%s/%s ...",
806                    opt.server_addr, opt.server_port);
807     fflush(stdout);
808 
809     if ((ret = mbedtls_net_connect(&server_fd, opt.server_addr, opt.server_port,
810                                    MBEDTLS_NET_PROTO_UDP)) != 0) {
811         mbedtls_printf(" failed\n  ! mbedtls_net_connect returned %d\n\n", ret);
812         goto exit;
813     }
814 
815     mbedtls_printf(" ok\n");
816 
817     /*
818      * 1. Setup the "listening" UDP socket
819      */
820     mbedtls_printf("  . Bind on UDP/%s/%s ...",
821                    opt.listen_addr, opt.listen_port);
822     fflush(stdout);
823 
824     if ((ret = mbedtls_net_bind(&listen_fd, opt.listen_addr, opt.listen_port,
825                                 MBEDTLS_NET_PROTO_UDP)) != 0) {
826         mbedtls_printf(" failed\n  ! mbedtls_net_bind returned %d\n\n", ret);
827         goto exit;
828     }
829 
830     mbedtls_printf(" ok\n");
831 
832     /*
833      * 2. Wait until a client connects
834      */
835 accept:
836     mbedtls_net_free(&client_fd);
837 
838     mbedtls_printf("  . Waiting for a remote connection ...");
839     fflush(stdout);
840 
841     if ((ret = mbedtls_net_accept(&listen_fd, &client_fd,
842                                   NULL, 0, NULL)) != 0) {
843         mbedtls_printf(" failed\n  ! mbedtls_net_accept returned %d\n\n", ret);
844         goto exit;
845     }
846 
847     mbedtls_printf(" ok\n");
848 
849     /*
850      * 3. Forward packets forever (kill the process to terminate it)
851      */
852     clear_pending();
853     memset(held, 0, sizeof(held));
854 
855     nb_fds = client_fd.fd;
856     if (nb_fds < server_fd.fd) {
857         nb_fds = server_fd.fd;
858     }
859     if (nb_fds < listen_fd.fd) {
860         nb_fds = listen_fd.fd;
861     }
862     ++nb_fds;
863 
864 #if defined(MBEDTLS_TIMING_C)
865     if (opt.pack > 0) {
866         outbuf[0].ctx = &server_fd;
867         outbuf[0].description = "S <- C";
868         outbuf[0].num_datagrams = 0;
869         outbuf[0].len = 0;
870 
871         outbuf[1].ctx = &client_fd;
872         outbuf[1].description = "S -> C";
873         outbuf[1].num_datagrams = 0;
874         outbuf[1].len = 0;
875     }
876 #endif /* MBEDTLS_TIMING_C */
877 
878     while (1) {
879 #if defined(MBEDTLS_TIMING_C)
880         if (opt.pack > 0) {
881             unsigned max_wait_server, max_wait_client, max_wait;
882             max_wait_server = ctx_buffer_time_remaining(&outbuf[0]);
883             max_wait_client = ctx_buffer_time_remaining(&outbuf[1]);
884 
885             max_wait = (unsigned) -1;
886 
887             if (max_wait_server == 0) {
888                 ctx_buffer_flush(&outbuf[0]);
889             } else {
890                 max_wait = max_wait_server;
891             }
892 
893             if (max_wait_client == 0) {
894                 ctx_buffer_flush(&outbuf[1]);
895             } else {
896                 if (max_wait_client < max_wait) {
897                     max_wait = max_wait_client;
898                 }
899             }
900 
901             if (max_wait != (unsigned) -1) {
902                 tm.tv_sec  = max_wait / 1000;
903                 tm.tv_usec = (max_wait % 1000) * 1000;
904 
905                 tm_ptr = &tm;
906             } else {
907                 tm_ptr = NULL;
908             }
909         }
910 #endif /* MBEDTLS_TIMING_C */
911 
912         FD_ZERO(&read_fds);
913         FD_SET(server_fd.fd, &read_fds);
914         FD_SET(client_fd.fd, &read_fds);
915         FD_SET(listen_fd.fd, &read_fds);
916 
917         if ((ret = select(nb_fds, &read_fds, NULL, NULL, tm_ptr)) < 0) {
918             perror("select");
919             goto exit;
920         }
921 
922         if (FD_ISSET(listen_fd.fd, &read_fds)) {
923             goto accept;
924         }
925 
926         if (FD_ISSET(client_fd.fd, &read_fds)) {
927             if ((ret = handle_message("S <- C",
928                                       &server_fd, &client_fd)) != 0) {
929                 goto accept;
930             }
931         }
932 
933         if (FD_ISSET(server_fd.fd, &read_fds)) {
934             if ((ret = handle_message("S -> C",
935                                       &client_fd, &server_fd)) != 0) {
936                 goto accept;
937             }
938         }
939 
940     }
941 
942     exit_code = MBEDTLS_EXIT_SUCCESS;
943 
944 exit:
945 
946 #ifdef MBEDTLS_ERROR_C
947     if (exit_code != MBEDTLS_EXIT_SUCCESS) {
948         char error_buf[100];
949         mbedtls_strerror(ret, error_buf, 100);
950         mbedtls_printf("Last error was: -0x%04X - %s\n\n", (unsigned int) -ret, error_buf);
951         fflush(stdout);
952     }
953 #endif
954 
955     for (delay_idx = 0; delay_idx < MAX_DELAYED_HS; delay_idx++) {
956         mbedtls_free(opt.delay_cli[delay_idx]);
957         mbedtls_free(opt.delay_srv[delay_idx]);
958     }
959 
960     mbedtls_net_free(&client_fd);
961     mbedtls_net_free(&server_fd);
962     mbedtls_net_free(&listen_fd);
963 
964     mbedtls_exit(exit_code);
965 }
966 
967 #endif /* MBEDTLS_NET_C */
968