1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements a TCP CLI tool.
32  */
33 
34 #include "openthread-core-config.h"
35 
36 #include "cli_config.h"
37 
38 #if OPENTHREAD_CONFIG_TCP_ENABLE && OPENTHREAD_CONFIG_CLI_TCP_ENABLE
39 
40 #include "cli_tcp.hpp"
41 
42 #include <openthread/nat64.h>
43 #include <openthread/tcp.h>
44 
45 #include "cli/cli.hpp"
46 #include "common/encoding.hpp"
47 #include "common/timer.hpp"
48 
49 #if OPENTHREAD_CONFIG_TLS_ENABLE
50 #include <mbedtls/debug.h>
51 #include <mbedtls/ecjpake.h>
52 #include "crypto/mbedtls.hpp"
53 #endif
54 
55 namespace ot {
56 namespace Cli {
57 
#if OPENTHREAD_CONFIG_TLS_ENABLE
// Cipher suites allowed in "tls" mode; the "init" command hands this list to
// mbedtls_ssl_conf_ciphersuites(). EC J-PAKE is listed first, then ECDHE-ECDSA.
const int TcpExample::sCipherSuites[] = {
    MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8,
    MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
    0, // list terminator expected by mbedtls_ssl_conf_ciphersuites()
};
#endif
62 
TcpExample(otInstance * aInstance,OutputImplementer & aOutputImplementer)63 TcpExample::TcpExample(otInstance *aInstance, OutputImplementer &aOutputImplementer)
64     : Output(aInstance, aOutputImplementer)
65     , mInitialized(false)
66     , mEndpointConnected(false)
67     , mEndpointConnectedFastOpen(false)
68     , mSendBusy(false)
69     , mUseCircularSendBuffer(true)
70     , mUseTls(false)
71     , mTlsHandshakeComplete(false)
72     , mBenchmarkBytesTotal(0)
73     , mBenchmarkBytesUnsent(0)
74     , mBenchmarkTimeUsed(0)
75 {
76     mEndpointAndCircularSendBuffer.mEndpoint   = &mEndpoint;
77     mEndpointAndCircularSendBuffer.mSendBuffer = &mSendBuffer;
78 }
79 
#if OPENTHREAD_CONFIG_TLS_ENABLE
/*
 * mbedtls debug hook: prints one debug message as "file:line:level: text".
 * Meant to be registered with mbedtls_ssl_conf_dbg() with the TcpExample as ctx
 * (that registration is currently commented out in the "init" command).
 */
void TcpExample::MbedTlsDebugOutput(void *ctx, int level, const char *file, int line, const char *str)
{
    static_cast<TcpExample *>(ctx)->OutputLine("%s:%d:%d: %s", file, line, level, str);
}
#endif // OPENTHREAD_CONFIG_TLS_ENABLE
87 
Process(Arg aArgs[])88 template <> otError TcpExample::Process<Cmd("init")>(Arg aArgs[])
89 {
90     otError error = OT_ERROR_NONE;
91     size_t  receiveBufferSize;
92 
93     VerifyOrExit(!mInitialized, error = OT_ERROR_ALREADY);
94 
95     if (aArgs[0].IsEmpty())
96     {
97         mUseCircularSendBuffer = true;
98         mUseTls                = false;
99         receiveBufferSize      = sizeof(mReceiveBufferBytes);
100     }
101     else
102     {
103         if (aArgs[0] == "linked")
104         {
105             mUseCircularSendBuffer = false;
106             mUseTls                = false;
107         }
108         else if (aArgs[0] == "circular")
109         {
110             mUseCircularSendBuffer = true;
111             mUseTls                = false;
112         }
113 #if OPENTHREAD_CONFIG_TLS_ENABLE
114         else if (aArgs[0] == "tls")
115         {
116             mUseCircularSendBuffer = true;
117             mUseTls                = true;
118 
119             // mbedtls_debug_set_threshold(0);
120 
121             otPlatCryptoRandomInit();
122             mbedtls_x509_crt_init(&mSrvCert);
123             mbedtls_pk_init(&mPKey);
124 
125             mbedtls_ssl_init(&mSslContext);
126             mbedtls_ssl_config_init(&mSslConfig);
127             mbedtls_ssl_conf_rng(&mSslConfig, Crypto::MbedTls::CryptoSecurePrng, nullptr);
128             // mbedtls_ssl_conf_dbg(&mSslConfig, MbedTlsDebugOutput, this);
129             mbedtls_ssl_conf_authmode(&mSslConfig, MBEDTLS_SSL_VERIFY_NONE);
130             mbedtls_ssl_conf_ciphersuites(&mSslConfig, sCipherSuites);
131 
132 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
133             mbedtls_ssl_conf_min_tls_version(&mSslConfig, MBEDTLS_SSL_VERSION_TLS1_2);
134             mbedtls_ssl_conf_max_tls_version(&mSslConfig, MBEDTLS_SSL_VERSION_TLS1_2);
135 #else
136             mbedtls_ssl_conf_min_version(&mSslConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
137             mbedtls_ssl_conf_max_version(&mSslConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
138 #endif
139 
140 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
141 #include "crypto/mbedtls.hpp"
142             int rv = mbedtls_pk_parse_key(&mPKey, reinterpret_cast<const unsigned char *>(sSrvKey), sSrvKeyLength,
143                                           nullptr, 0, Crypto::MbedTls::CryptoSecurePrng, nullptr);
144 #else
145             int rv = mbedtls_pk_parse_key(&mPKey, reinterpret_cast<const unsigned char *>(sSrvKey), sSrvKeyLength,
146                                           nullptr, 0);
147 #endif
148             if (rv != 0)
149             {
150                 OutputLine("mbedtls_pk_parse_key returned %d", rv);
151             }
152 
153             rv = mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(sSrvPem), sSrvPemLength);
154             if (rv != 0)
155             {
156                 OutputLine("mbedtls_x509_crt_parse (1) returned %d", rv);
157             }
158             rv = mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(sCasPem), sCasPemLength);
159             if (rv != 0)
160             {
161                 OutputLine("mbedtls_x509_crt_parse (2) returned %d", rv);
162             }
163             rv = mbedtls_ssl_setup(&mSslContext, &mSslConfig);
164             if (rv != 0)
165             {
166                 OutputLine("mbedtls_ssl_setup returned %d", rv);
167             }
168         }
169 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
170         else
171         {
172             ExitNow(error = OT_ERROR_INVALID_ARGS);
173         }
174 
175         if (aArgs[1].IsEmpty())
176         {
177             receiveBufferSize = sizeof(mReceiveBufferBytes);
178         }
179         else
180         {
181             uint32_t windowSize;
182 
183             SuccessOrExit(error = aArgs[1].ParseAsUint32(windowSize));
184 
185             receiveBufferSize = windowSize + ((windowSize + 7) >> 3);
186             VerifyOrExit(receiveBufferSize <= sizeof(mReceiveBufferBytes) && receiveBufferSize != 0,
187                          error = OT_ERROR_INVALID_ARGS);
188         }
189     }
190 
191     otTcpCircularSendBufferInitialize(&mSendBuffer, mSendBufferBytes, sizeof(mSendBufferBytes));
192 
193     {
194         otTcpEndpointInitializeArgs endpointArgs;
195 
196         memset(&endpointArgs, 0x00, sizeof(endpointArgs));
197         endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
198         if (mUseCircularSendBuffer)
199         {
200             endpointArgs.mForwardProgressCallback = HandleTcpForwardProgressCallback;
201         }
202         else
203         {
204             endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
205         }
206         endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
207         endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
208         endpointArgs.mContext                  = this;
209         endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
210         endpointArgs.mReceiveBufferSize        = receiveBufferSize;
211 
212         SuccessOrExit(error = otTcpEndpointInitialize(GetInstancePtr(), &mEndpoint, &endpointArgs));
213     }
214 
215     {
216         otTcpListenerInitializeArgs listenerArgs;
217 
218         memset(&listenerArgs, 0x00, sizeof(listenerArgs));
219         listenerArgs.mAcceptReadyCallback = HandleTcpAcceptReadyCallback;
220         listenerArgs.mAcceptDoneCallback  = HandleTcpAcceptDoneCallback;
221         listenerArgs.mContext             = this;
222 
223         error = otTcpListenerInitialize(GetInstancePtr(), &mListener, &listenerArgs);
224         if (error != OT_ERROR_NONE)
225         {
226             IgnoreReturnValue(otTcpEndpointDeinitialize(&mEndpoint));
227             ExitNow();
228         }
229     }
230 
231     mInitialized = true;
232 
233 exit:
234     return error;
235 }
236 
Process(Arg aArgs[])237 template <> otError TcpExample::Process<Cmd("deinit")>(Arg aArgs[])
238 {
239     otError error = OT_ERROR_NONE;
240     otError endpointError;
241     otError bufferError;
242     otError listenerError;
243 
244     VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
245     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
246 
247 #if OPENTHREAD_CONFIG_TLS_ENABLE
248     if (mUseTls)
249     {
250         otPlatCryptoRandomDeinit();
251         mbedtls_ssl_config_free(&mSslConfig);
252         mbedtls_ssl_free(&mSslContext);
253 
254         mbedtls_pk_free(&mPKey);
255         mbedtls_x509_crt_free(&mSrvCert);
256     }
257 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
258 
259     endpointError = otTcpEndpointDeinitialize(&mEndpoint);
260     mSendBusy     = false;
261 
262     otTcpCircularSendBufferForceDiscardAll(&mSendBuffer);
263     bufferError = otTcpCircularSendBufferDeinitialize(&mSendBuffer);
264 
265     listenerError = otTcpListenerDeinitialize(&mListener);
266     mInitialized  = false;
267 
268     SuccessOrExit(error = endpointError);
269     SuccessOrExit(error = bufferError);
270     SuccessOrExit(error = listenerError);
271 
272 exit:
273     return error;
274 }
275 
Process(Arg aArgs[])276 template <> otError TcpExample::Process<Cmd("bind")>(Arg aArgs[])
277 {
278     otError    error;
279     otSockAddr sockaddr;
280 
281     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
282 
283     SuccessOrExit(error = aArgs[0].ParseAsIp6Address(sockaddr.mAddress));
284     SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
285     VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
286 
287     error = otTcpBind(&mEndpoint, &sockaddr);
288 
289 exit:
290     return error;
291 }
292 
Process(Arg aArgs[])293 template <> otError TcpExample::Process<Cmd("connect")>(Arg aArgs[])
294 {
295     otError    error;
296     otSockAddr sockaddr;
297     bool       nat64SynthesizedAddress;
298     uint32_t   flags;
299 
300     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
301 
302     SuccessOrExit(
303         error = Interpreter::ParseToIp6Address(GetInstancePtr(), aArgs[0], sockaddr.mAddress, nat64SynthesizedAddress));
304     if (nat64SynthesizedAddress)
305     {
306         OutputFormat("Connecting to synthesized IPv6 address: ");
307         OutputIp6AddressLine(sockaddr.mAddress);
308     }
309 
310     SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
311     if (aArgs[2].IsEmpty())
312     {
313         flags = OT_TCP_CONNECT_NO_FAST_OPEN;
314     }
315     else
316     {
317         if (aArgs[2] == "slow")
318         {
319             flags = OT_TCP_CONNECT_NO_FAST_OPEN;
320         }
321         else if (aArgs[2] == "fast")
322         {
323             flags = 0;
324         }
325         else
326         {
327             ExitNow(error = OT_ERROR_INVALID_ARGS);
328         }
329         VerifyOrExit(aArgs[3].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
330     }
331 
332 #if OPENTHREAD_CONFIG_TLS_ENABLE
333     if (mUseTls)
334     {
335         int rv = mbedtls_ssl_config_defaults(&mSslConfig, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM,
336                                              MBEDTLS_SSL_PRESET_DEFAULT);
337         if (rv != 0)
338         {
339             OutputLine("mbedtls_ssl_config_defaults returned %d", rv);
340         }
341     }
342 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
343 
344     SuccessOrExit(error = otTcpConnect(&mEndpoint, &sockaddr, flags));
345     mEndpointConnected         = true;
346     mEndpointConnectedFastOpen = ((flags & OT_TCP_CONNECT_NO_FAST_OPEN) == 0);
347 
348 #if OPENTHREAD_CONFIG_TLS_ENABLE
349     if (mUseTls && mEndpointConnectedFastOpen)
350     {
351         PrepareTlsHandshake();
352         ContinueTlsHandshake();
353     }
354 #endif
355 
356 exit:
357     return error;
358 }
359 
Process(Arg aArgs[])360 template <> otError TcpExample::Process<Cmd("send")>(Arg aArgs[])
361 {
362     otError error;
363 
364     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
365     VerifyOrExit(mBenchmarkBytesTotal == 0, error = OT_ERROR_BUSY);
366     VerifyOrExit(!aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
367     VerifyOrExit(aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
368 
369     if (mUseCircularSendBuffer)
370     {
371 #if OPENTHREAD_CONFIG_TLS_ENABLE
372         if (mUseTls)
373         {
374             int rv = mbedtls_ssl_write(&mSslContext, reinterpret_cast<unsigned char *>(aArgs[0].GetCString()),
375                                        aArgs[0].GetLength());
376             if (rv < 0 && rv != MBEDTLS_ERR_SSL_WANT_WRITE && rv != MBEDTLS_ERR_SSL_WANT_READ)
377             {
378                 ExitNow(error = kErrorFailed);
379             }
380             error = kErrorNone;
381         }
382         else
383 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
384         {
385             size_t written;
386             SuccessOrExit(error = otTcpCircularSendBufferWrite(&mEndpoint, &mSendBuffer, aArgs[0].GetCString(),
387                                                                aArgs[0].GetLength(), &written, 0));
388         }
389     }
390     else
391     {
392         VerifyOrExit(!mSendBusy, error = OT_ERROR_BUSY);
393 
394         mSendLink.mNext   = nullptr;
395         mSendLink.mData   = mSendBufferBytes;
396         mSendLink.mLength = OT_MIN(aArgs[0].GetLength(), sizeof(mSendBufferBytes));
397         memcpy(mSendBufferBytes, aArgs[0].GetCString(), mSendLink.mLength);
398 
399         SuccessOrExit(error = otTcpSendByReference(&mEndpoint, &mSendLink, 0));
400         mSendBusy = true;
401     }
402 
403 exit:
404     return error;
405 }
406 
Process(Arg aArgs[])407 template <> otError TcpExample::Process<Cmd("benchmark")>(Arg aArgs[])
408 {
409     otError error = OT_ERROR_NONE;
410 
411     if (aArgs[0] == "result")
412     {
413         OutputFormat("TCP Benchmark Status: ");
414         if (mBenchmarkBytesTotal != 0)
415         {
416             OutputLine("Ongoing");
417         }
418         else if (mBenchmarkTimeUsed != 0)
419         {
420             OutputLine("Completed");
421             OutputBenchmarkResult();
422         }
423         else
424         {
425             OutputLine("Untested");
426         }
427     }
428     else if (aArgs[0] == "run")
429     {
430         VerifyOrExit(!mSendBusy, error = OT_ERROR_BUSY);
431         VerifyOrExit(mBenchmarkBytesTotal == 0, error = OT_ERROR_BUSY);
432 
433         if (aArgs[1].IsEmpty())
434         {
435             mBenchmarkBytesTotal = OPENTHREAD_CONFIG_CLI_TCP_DEFAULT_BENCHMARK_SIZE;
436         }
437         else
438         {
439             SuccessOrExit(error = aArgs[1].ParseAsUint32(mBenchmarkBytesTotal));
440             VerifyOrExit(mBenchmarkBytesTotal != 0, error = OT_ERROR_INVALID_ARGS);
441         }
442         VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
443 
444         mBenchmarkStart       = TimerMilli::GetNow();
445         mBenchmarkBytesUnsent = mBenchmarkBytesTotal;
446 
447         if (mUseCircularSendBuffer)
448         {
449             SuccessOrExit(error = ContinueBenchmarkCircularSend());
450         }
451         else
452         {
453             uint32_t benchmarkLinksLeft =
454                 (mBenchmarkBytesTotal + sizeof(mSendBufferBytes) - 1) / sizeof(mSendBufferBytes);
455             uint32_t toSendOut = OT_MIN(OT_ARRAY_LENGTH(mBenchmarkLinks), benchmarkLinksLeft);
456 
457             /* We could also point the linked buffers directly to sBenchmarkData. */
458             memset(mSendBufferBytes, 'a', sizeof(mSendBufferBytes));
459 
460             for (uint32_t i = 0; i != toSendOut; i++)
461             {
462                 mBenchmarkLinks[i].mNext   = nullptr;
463                 mBenchmarkLinks[i].mData   = mSendBufferBytes;
464                 mBenchmarkLinks[i].mLength = sizeof(mSendBufferBytes);
465                 if (i == 0 && mBenchmarkBytesTotal % sizeof(mSendBufferBytes) != 0)
466                 {
467                     mBenchmarkLinks[i].mLength = mBenchmarkBytesTotal % sizeof(mSendBufferBytes);
468                 }
469                 error = otTcpSendByReference(&mEndpoint, &mBenchmarkLinks[i],
470                                              i == toSendOut - 1 ? 0 : OT_TCP_SEND_MORE_TO_COME);
471                 VerifyOrExit(error == OT_ERROR_NONE, mBenchmarkBytesTotal = 0);
472             }
473         }
474     }
475     else
476     {
477         error = OT_ERROR_INVALID_ARGS;
478     }
479 
480 exit:
481     return error;
482 }
483 
Process(Arg aArgs[])484 template <> otError TcpExample::Process<Cmd("sendend")>(Arg aArgs[])
485 {
486     otError error;
487 
488     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
489     VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
490 
491     error = otTcpSendEndOfStream(&mEndpoint);
492 
493 exit:
494     return error;
495 }
496 
Process(Arg aArgs[])497 template <> otError TcpExample::Process<Cmd("abort")>(Arg aArgs[])
498 {
499     otError error;
500 
501     VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
502     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
503 
504     SuccessOrExit(error = otTcpAbort(&mEndpoint));
505     mEndpointConnected         = false;
506     mEndpointConnectedFastOpen = false;
507 
508 exit:
509     return error;
510 }
511 
Process(Arg aArgs[])512 template <> otError TcpExample::Process<Cmd("listen")>(Arg aArgs[])
513 {
514     otError    error;
515     otSockAddr sockaddr;
516 
517     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
518 
519     SuccessOrExit(error = aArgs[0].ParseAsIp6Address(sockaddr.mAddress));
520     SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
521     VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
522 
523     SuccessOrExit(error = otTcpStopListening(&mListener));
524     error = otTcpListen(&mListener, &sockaddr);
525 
526 exit:
527     return error;
528 }
529 
Process(Arg aArgs[])530 template <> otError TcpExample::Process<Cmd("stoplistening")>(Arg aArgs[])
531 {
532     otError error;
533 
534     VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
535     VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
536 
537     error = otTcpStopListening(&mListener);
538 
539 exit:
540     return error;
541 }
542 
Process(Arg aArgs[])543 otError TcpExample::Process(Arg aArgs[])
544 {
545 #define CmdEntry(aCommandString)                                  \
546     {                                                             \
547         aCommandString, &TcpExample::Process<Cmd(aCommandString)> \
548     }
549 
550     static constexpr Command kCommands[] = {
551         CmdEntry("abort"), CmdEntry("benchmark"), CmdEntry("bind"), CmdEntry("connect"), CmdEntry("deinit"),
552         CmdEntry("init"),  CmdEntry("listen"),    CmdEntry("send"), CmdEntry("sendend"), CmdEntry("stoplistening"),
553     };
554 
555     static_assert(BinarySearch::IsSorted(kCommands), "kCommands is not sorted");
556 
557     otError        error = OT_ERROR_INVALID_COMMAND;
558     const Command *command;
559 
560     if (aArgs[0].IsEmpty() || (aArgs[0] == "help"))
561     {
562         OutputCommandTable(kCommands);
563         ExitNow(error = aArgs[0].IsEmpty() ? error : OT_ERROR_NONE);
564     }
565 
566     command = BinarySearch::Find(aArgs[0].GetCString(), kCommands);
567     VerifyOrExit(command != nullptr);
568 
569     error = (this->*command->mHandler)(aArgs + 1);
570 
571 exit:
572     return error;
573 }
574 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)575 void TcpExample::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
576 {
577     static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
578 }
579 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)580 void TcpExample::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
581 {
582     static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
583 }
584 
HandleTcpForwardProgressCallback(otTcpEndpoint * aEndpoint,size_t aInSendBuffer,size_t aBacklog)585 void TcpExample::HandleTcpForwardProgressCallback(otTcpEndpoint *aEndpoint, size_t aInSendBuffer, size_t aBacklog)
586 {
587     static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))
588         ->HandleTcpForwardProgress(aEndpoint, aInSendBuffer, aBacklog);
589 }
590 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)591 void TcpExample::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
592                                                    size_t         aBytesAvailable,
593                                                    bool           aEndOfStream,
594                                                    size_t         aBytesRemaining)
595 {
596     static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))
597         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
598 }
599 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)600 void TcpExample::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
601 {
602     static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
603 }
604 
HandleTcpAcceptReadyCallback(otTcpListener * aListener,const otSockAddr * aPeer,otTcpEndpoint ** aAcceptInto)605 otTcpIncomingConnectionAction TcpExample::HandleTcpAcceptReadyCallback(otTcpListener    *aListener,
606                                                                        const otSockAddr *aPeer,
607                                                                        otTcpEndpoint   **aAcceptInto)
608 {
609     return static_cast<TcpExample *>(otTcpListenerGetContext(aListener))
610         ->HandleTcpAcceptReady(aListener, aPeer, aAcceptInto);
611 }
612 
HandleTcpAcceptDoneCallback(otTcpListener * aListener,otTcpEndpoint * aEndpoint,const otSockAddr * aPeer)613 void TcpExample::HandleTcpAcceptDoneCallback(otTcpListener    *aListener,
614                                              otTcpEndpoint    *aEndpoint,
615                                              const otSockAddr *aPeer)
616 {
617     static_cast<TcpExample *>(otTcpListenerGetContext(aListener))->HandleTcpAcceptDone(aListener, aEndpoint, aPeer);
618 }
619 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)620 void TcpExample::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
621 {
622     OT_UNUSED_VARIABLE(aEndpoint);
623     OutputLine("TCP: Connection established");
624 #if OPENTHREAD_CONFIG_TLS_ENABLE
625     if (mUseTls && !mEndpointConnectedFastOpen)
626     {
627         PrepareTlsHandshake();
628         ContinueTlsHandshake();
629     }
630 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
631 }
632 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)633 void TcpExample::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
634 {
635     OT_UNUSED_VARIABLE(aEndpoint);
636     OT_ASSERT(!mUseCircularSendBuffer); // this callback is not used when using the circular send buffer
637 
638     if (mBenchmarkBytesTotal == 0)
639     {
640         // If the benchmark encountered an error, we might end up here. So,
641         // tolerate some benchmark links finishing in this case.
642         if (aData == &mSendLink)
643         {
644             OT_ASSERT(mSendBusy);
645             mSendBusy = false;
646         }
647     }
648     else
649     {
650         OT_ASSERT(aData != &mSendLink);
651         OT_ASSERT(mBenchmarkBytesUnsent >= aData->mLength);
652         mBenchmarkBytesUnsent -= aData->mLength; // could be less than sizeof(mSendBufferBytes) for the first link
653         if (mBenchmarkBytesUnsent >= OT_ARRAY_LENGTH(mBenchmarkLinks) * sizeof(mSendBufferBytes))
654         {
655             aData->mLength = sizeof(mSendBufferBytes);
656             if (otTcpSendByReference(&mEndpoint, aData, 0) != OT_ERROR_NONE)
657             {
658                 OutputLine("TCP Benchmark Failed");
659                 mBenchmarkBytesTotal = 0;
660             }
661         }
662         else if (mBenchmarkBytesUnsent == 0)
663         {
664             CompleteBenchmark();
665         }
666     }
667 }
668 
HandleTcpForwardProgress(otTcpEndpoint * aEndpoint,size_t aInSendBuffer,size_t aBacklog)669 void TcpExample::HandleTcpForwardProgress(otTcpEndpoint *aEndpoint, size_t aInSendBuffer, size_t aBacklog)
670 {
671     OT_UNUSED_VARIABLE(aEndpoint);
672     OT_UNUSED_VARIABLE(aBacklog);
673     OT_ASSERT(mUseCircularSendBuffer); // this callback is only used when using the circular send buffer
674 
675     otTcpCircularSendBufferHandleForwardProgress(&mSendBuffer, aInSendBuffer);
676 
677 #if OPENTHREAD_CONFIG_TLS_ENABLE
678     if (mUseTls)
679     {
680         ContinueTlsHandshake();
681     }
682 #endif
683 
684     /* Handle case where we're in a benchmark. */
685     if (mBenchmarkBytesTotal != 0)
686     {
687         if (mBenchmarkBytesUnsent != 0)
688         {
689             /* Continue sending out data if there's data we haven't sent. */
690             IgnoreError(ContinueBenchmarkCircularSend());
691         }
692         else if (aInSendBuffer == 0)
693         {
694             /* Handle case where all data is sent out and the send buffer has drained. */
695             CompleteBenchmark();
696         }
697     }
698 }
699 
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)700 void TcpExample::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
701                                            size_t         aBytesAvailable,
702                                            bool           aEndOfStream,
703                                            size_t         aBytesRemaining)
704 {
705     OT_UNUSED_VARIABLE(aBytesRemaining);
706     OT_ASSERT(aEndpoint == &mEndpoint);
707 
708     /* If we get data before the handshake completes, then this is a TFO connection. */
709     if (!mEndpointConnected)
710     {
711         mEndpointConnected         = true;
712         mEndpointConnectedFastOpen = true;
713 
714 #if OPENTHREAD_CONFIG_TLS_ENABLE
715         if (mUseTls)
716         {
717             PrepareTlsHandshake();
718         }
719 #endif
720     }
721 
722 #if OPENTHREAD_CONFIG_TLS_ENABLE
723     if (mUseTls && ContinueTlsHandshake())
724     {
725         return;
726     }
727 #endif
728 
729     if ((mTlsHandshakeComplete || !mUseTls) && aBytesAvailable > 0)
730     {
731 #if OPENTHREAD_CONFIG_TLS_ENABLE
732         if (mUseTls)
733         {
734             uint8_t buffer[500];
735             for (;;)
736             {
737                 int rv = mbedtls_ssl_read(&mSslContext, buffer, sizeof(buffer));
738                 if (rv < 0)
739                 {
740                     if (rv == MBEDTLS_ERR_SSL_WANT_READ)
741                     {
742                         break;
743                     }
744                     OutputLine("TLS receive failure: %d", rv);
745                 }
746                 else
747                 {
748                     OutputLine("TLS: Received %u bytes: %.*s", static_cast<unsigned>(rv), rv,
749                                reinterpret_cast<const char *>(buffer));
750                 }
751             }
752             OutputLine("(TCP: Received %u bytes)", static_cast<unsigned>(aBytesAvailable));
753         }
754         else
755 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
756         {
757             const otLinkedBuffer *data;
758             size_t                totalReceived = 0;
759             IgnoreError(otTcpReceiveByReference(aEndpoint, &data));
760             for (; data != nullptr; data = data->mNext)
761             {
762                 OutputLine("TCP: Received %u bytes: %.*s", static_cast<unsigned>(data->mLength),
763                            static_cast<unsigned>(data->mLength), reinterpret_cast<const char *>(data->mData));
764                 totalReceived += data->mLength;
765             }
766             OT_ASSERT(aBytesAvailable == totalReceived);
767             IgnoreReturnValue(otTcpCommitReceive(aEndpoint, totalReceived, 0));
768         }
769     }
770 
771     if (aEndOfStream)
772     {
773         OutputLine("TCP: Reached end of stream");
774     }
775 }
776 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)777 void TcpExample::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
778 {
779     static const char *const kReasonStrings[] = {
780         "Disconnected",            // (0) OT_TCP_DISCONNECTED_REASON_NORMAL
781         "Connection refused",      // (1) OT_TCP_DISCONNECTED_REASON_REFUSED
782         "Connection reset",        // (2) OT_TCP_DISCONNECTED_REASON_RESET
783         "Entered TIME-WAIT state", // (3) OT_TCP_DISCONNECTED_REASON_TIME_WAIT
784         "Connection timed out",    // (4) OT_TCP_DISCONNECTED_REASON_TIMED_OUT
785     };
786 
787     OT_UNUSED_VARIABLE(aEndpoint);
788 
789     static_assert(0 == OT_TCP_DISCONNECTED_REASON_NORMAL, "OT_TCP_DISCONNECTED_REASON_NORMAL value is incorrect");
790     static_assert(1 == OT_TCP_DISCONNECTED_REASON_REFUSED, "OT_TCP_DISCONNECTED_REASON_REFUSED value is incorrect");
791     static_assert(2 == OT_TCP_DISCONNECTED_REASON_RESET, "OT_TCP_DISCONNECTED_REASON_RESET value is incorrect");
792     static_assert(3 == OT_TCP_DISCONNECTED_REASON_TIME_WAIT, "OT_TCP_DISCONNECTED_REASON_TIME_WAIT value is incorrect");
793     static_assert(4 == OT_TCP_DISCONNECTED_REASON_TIMED_OUT, "OT_TCP_DISCONNECTED_REASON_TIMED_OUT value is incorrect");
794 
795     OutputLine("TCP: %s", Stringify(aReason, kReasonStrings));
796 
797 #if OPENTHREAD_CONFIG_TLS_ENABLE
798     if (mUseTls)
799     {
800         mbedtls_ssl_session_reset(&mSslContext);
801     }
802 #endif
803 
804     // We set this to false even for the TIME-WAIT state, so that we can reuse
805     // the active socket if an incoming connection comes in instead of waiting
806     // for the 2MSL timeout.
807     mEndpointConnected         = false;
808     mEndpointConnectedFastOpen = false;
809     mSendBusy                  = false;
810 
811     // Mark the benchmark as inactive if the connection was disconnected.
812     mBenchmarkBytesTotal  = 0;
813     mBenchmarkBytesUnsent = 0;
814 
815     otTcpCircularSendBufferForceDiscardAll(&mSendBuffer);
816 }
817 
HandleTcpAcceptReady(otTcpListener * aListener,const otSockAddr * aPeer,otTcpEndpoint ** aAcceptInto)818 otTcpIncomingConnectionAction TcpExample::HandleTcpAcceptReady(otTcpListener    *aListener,
819                                                                const otSockAddr *aPeer,
820                                                                otTcpEndpoint   **aAcceptInto)
821 {
822     otTcpIncomingConnectionAction action;
823 
824     OT_UNUSED_VARIABLE(aListener);
825 
826     if (mEndpointConnected)
827     {
828         OutputFormat("TCP: Ignoring incoming connection request from ");
829         OutputSockAddr(*aPeer);
830         OutputLine(" (active socket is busy)");
831 
832         ExitNow(action = OT_TCP_INCOMING_CONNECTION_ACTION_DEFER);
833     }
834 
835     *aAcceptInto = &mEndpoint;
836     action       = OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT;
837 
838 #if OPENTHREAD_CONFIG_TLS_ENABLE
839     /*
840      * Natural to wait until the AcceptDone callback but with TFO we could get data before that
841      * so it doesn't make sense to wait until then.
842      */
843     if (mUseTls)
844     {
845         int rv;
846 
847         rv = mbedtls_ssl_config_defaults(&mSslConfig, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
848                                          MBEDTLS_SSL_PRESET_DEFAULT);
849         if (rv != 0)
850         {
851             OutputLine("mbedtls_ssl_config_defaults returned %d", rv);
852         }
853         mbedtls_ssl_conf_ca_chain(&mSslConfig, mSrvCert.next, nullptr);
854         rv = mbedtls_ssl_conf_own_cert(&mSslConfig, &mSrvCert, &mPKey);
855         if (rv != 0)
856         {
857             OutputLine("mbedtls_ssl_conf_own_cert returned %d", rv);
858         }
859     }
860 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
861 
862 exit:
863     return action;
864 }
865 
HandleTcpAcceptDone(otTcpListener * aListener,otTcpEndpoint * aEndpoint,const otSockAddr * aPeer)866 void TcpExample::HandleTcpAcceptDone(otTcpListener *aListener, otTcpEndpoint *aEndpoint, const otSockAddr *aPeer)
867 {
868     OT_UNUSED_VARIABLE(aListener);
869     OT_UNUSED_VARIABLE(aEndpoint);
870 
871     mEndpointConnected = true;
872     OutputFormat("Accepted connection from ");
873     OutputSockAddrLine(*aPeer);
874 }
875 
ContinueBenchmarkCircularSend(void)876 otError TcpExample::ContinueBenchmarkCircularSend(void)
877 {
878     otError error = OT_ERROR_NONE;
879     size_t  freeSpace;
880 
881     while (mBenchmarkBytesUnsent != 0 && (freeSpace = otTcpCircularSendBufferGetFreeSpace(&mSendBuffer)) != 0)
882     {
883         size_t   toSendThisIteration = OT_MIN(mBenchmarkBytesUnsent, sBenchmarkDataLength);
884         uint32_t flag                = (toSendThisIteration < freeSpace && toSendThisIteration < mBenchmarkBytesUnsent)
885                                            ? OT_TCP_CIRCULAR_SEND_BUFFER_WRITE_MORE_TO_COME
886                                            : 0;
887         size_t   written             = 0;
888 
889 #if OPENTHREAD_CONFIG_TLS_ENABLE
890         if (mUseTls)
891         {
892             int rv = mbedtls_ssl_write(&mSslContext, reinterpret_cast<const unsigned char *>(sBenchmarkData),
893                                        toSendThisIteration);
894             if (rv > 0)
895             {
896                 written = static_cast<size_t>(rv);
897                 OT_ASSERT(written <= mBenchmarkBytesUnsent);
898             }
899             else if (rv != MBEDTLS_ERR_SSL_WANT_WRITE && rv != MBEDTLS_ERR_SSL_WANT_READ)
900             {
901                 ExitNow(error = kErrorFailed);
902             }
903             error = kErrorNone;
904         }
905         else
906 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
907         {
908             SuccessOrExit(error = otTcpCircularSendBufferWrite(&mEndpoint, &mSendBuffer, sBenchmarkData,
909                                                                toSendThisIteration, &written, flag));
910         }
911         mBenchmarkBytesUnsent -= written;
912     }
913 
914 exit:
915     if (error != OT_ERROR_NONE)
916     {
917         OutputLine("TCP Benchmark Failed");
918         mBenchmarkBytesTotal  = 0;
919         mBenchmarkBytesUnsent = 0;
920     }
921 
922     return error;
923 }
924 
OutputBenchmarkResult(void)925 void TcpExample::OutputBenchmarkResult(void)
926 {
927     uint32_t thousandTimesGoodput =
928         (1000 * (mBenchmarkLastBytesTotal << 3) + (mBenchmarkTimeUsed >> 1)) / mBenchmarkTimeUsed;
929 
930     OutputLine("TCP Benchmark Complete: Transferred %lu bytes in %lu milliseconds", ToUlong(mBenchmarkLastBytesTotal),
931                ToUlong(mBenchmarkTimeUsed));
932     OutputLine("TCP Goodput: %lu.%03u kb/s", ToUlong(thousandTimesGoodput / 1000),
933                static_cast<uint16_t>(thousandTimesGoodput % 1000));
934 }
935 
CompleteBenchmark(void)936 void TcpExample::CompleteBenchmark(void)
937 {
938     mBenchmarkTimeUsed       = TimerMilli::GetNow() - mBenchmarkStart;
939     mBenchmarkLastBytesTotal = mBenchmarkBytesTotal;
940 
941     OutputBenchmarkResult();
942 
943     mBenchmarkBytesTotal = 0;
944 }
945 
946 #if OPENTHREAD_CONFIG_TLS_ENABLE
PrepareTlsHandshake(void)947 void TcpExample::PrepareTlsHandshake(void)
948 {
949     int rv;
950     rv = mbedtls_ssl_set_hostname(&mSslContext, "localhost");
951     if (rv != 0)
952     {
953         OutputLine("mbedtls_ssl_set_hostname returned %d", rv);
954     }
955     rv = mbedtls_ssl_set_hs_ecjpake_password(&mSslContext, reinterpret_cast<const unsigned char *>(sEcjpakePassword),
956                                              sEcjpakePasswordLength);
957     if (rv != 0)
958     {
959         OutputLine("mbedtls_ssl_set_hs_ecjpake_password returned %d", rv);
960     }
961     mbedtls_ssl_set_bio(&mSslContext, &mEndpointAndCircularSendBuffer, otTcpMbedTlsSslSendCallback,
962                         otTcpMbedTlsSslRecvCallback, nullptr);
963     mTlsHandshakeComplete = false;
964 }
965 
ContinueTlsHandshake(void)966 bool TcpExample::ContinueTlsHandshake(void)
967 {
968     bool wasNotAlreadyDone = false;
969     int  rv;
970 
971     if (!mTlsHandshakeComplete)
972     {
973         rv = mbedtls_ssl_handshake(&mSslContext);
974         if (rv == 0)
975         {
976             OutputLine("TLS Handshake Complete");
977             mTlsHandshakeComplete = true;
978         }
979         else if (rv != MBEDTLS_ERR_SSL_WANT_READ && rv != MBEDTLS_ERR_SSL_WANT_WRITE)
980         {
981             OutputLine("TLS Handshake Failed: %d", rv);
982         }
983         wasNotAlreadyDone = true;
984     }
985 
986     return wasNotAlreadyDone;
987 }
988 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
989 
990 } // namespace Cli
991 } // namespace ot
992 
993 #endif // OPENTHREAD_CONFIG_TCP_ENABLE && OPENTHREAD_CONFIG_CLI_TCP_ENABLE
994