1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements a TCP CLI tool.
32 */
33
34 #include "openthread-core-config.h"
35
36 #include "cli_config.h"
37
38 #if OPENTHREAD_CONFIG_TCP_ENABLE && OPENTHREAD_CONFIG_CLI_TCP_ENABLE
39
40 #include "cli_tcp.hpp"
41
42 #include <openthread/nat64.h>
43 #include <openthread/tcp.h>
44
45 #include "cli/cli.hpp"
46 #include "common/encoding.hpp"
47 #include "common/timer.hpp"
48
49 #if OPENTHREAD_CONFIG_TLS_ENABLE
50 #include <mbedtls/debug.h>
51 #include <mbedtls/ecjpake.h>
52 #include "crypto/mbedtls.hpp"
53 #endif
54
55 namespace ot {
56 namespace Cli {
57
#if OPENTHREAD_CONFIG_TLS_ENABLE
// Cipher suites configured for TLS mode via mbedtls_ssl_conf_ciphersuites();
// the trailing 0 terminates the list.
const int TcpExample::sCipherSuites[] = {
    MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8,
    MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
    0, // end of list
};
#endif // OPENTHREAD_CONFIG_TLS_ENABLE
62
TcpExample(otInstance * aInstance,OutputImplementer & aOutputImplementer)63 TcpExample::TcpExample(otInstance *aInstance, OutputImplementer &aOutputImplementer)
64 : Output(aInstance, aOutputImplementer)
65 , mInitialized(false)
66 , mEndpointConnected(false)
67 , mEndpointConnectedFastOpen(false)
68 , mSendBusy(false)
69 , mUseCircularSendBuffer(true)
70 , mUseTls(false)
71 , mTlsHandshakeComplete(false)
72 , mBenchmarkBytesTotal(0)
73 , mBenchmarkBytesUnsent(0)
74 , mBenchmarkTimeUsed(0)
75 {
76 mEndpointAndCircularSendBuffer.mEndpoint = &mEndpoint;
77 mEndpointAndCircularSendBuffer.mSendBuffer = &mSendBuffer;
78 }
79
#if OPENTHREAD_CONFIG_TLS_ENABLE
/**
 * mbedTLS debug callback that prints one message as "file:line:level: message".
 *
 * @param[in] ctx    The TcpExample instance. Presumably registered through
 *                   mbedtls_ssl_conf_dbg(), which is currently commented out in `tcp init`.
 * @param[in] level  mbedTLS debug level of the message.
 * @param[in] file   Source file that emitted the message.
 * @param[in] line   Source line that emitted the message.
 * @param[in] str    The message text.
 */
void TcpExample::MbedTlsDebugOutput(void *ctx, int level, const char *file, int line, const char *str)
{
    static_cast<TcpExample *>(ctx)->OutputLine("%s:%d:%d: %s", file, line, level, str);
}
#endif // OPENTHREAD_CONFIG_TLS_ENABLE
87
Process(Arg aArgs[])88 template <> otError TcpExample::Process<Cmd("init")>(Arg aArgs[])
89 {
90 otError error = OT_ERROR_NONE;
91 size_t receiveBufferSize;
92
93 VerifyOrExit(!mInitialized, error = OT_ERROR_ALREADY);
94
95 if (aArgs[0].IsEmpty())
96 {
97 mUseCircularSendBuffer = true;
98 mUseTls = false;
99 receiveBufferSize = sizeof(mReceiveBufferBytes);
100 }
101 else
102 {
103 if (aArgs[0] == "linked")
104 {
105 mUseCircularSendBuffer = false;
106 mUseTls = false;
107 }
108 else if (aArgs[0] == "circular")
109 {
110 mUseCircularSendBuffer = true;
111 mUseTls = false;
112 }
113 #if OPENTHREAD_CONFIG_TLS_ENABLE
114 else if (aArgs[0] == "tls")
115 {
116 mUseCircularSendBuffer = true;
117 mUseTls = true;
118
119 // mbedtls_debug_set_threshold(0);
120
121 otPlatCryptoRandomInit();
122 mbedtls_x509_crt_init(&mSrvCert);
123 mbedtls_pk_init(&mPKey);
124
125 mbedtls_ssl_init(&mSslContext);
126 mbedtls_ssl_config_init(&mSslConfig);
127 mbedtls_ssl_conf_rng(&mSslConfig, Crypto::MbedTls::CryptoSecurePrng, nullptr);
128 // mbedtls_ssl_conf_dbg(&mSslConfig, MbedTlsDebugOutput, this);
129 mbedtls_ssl_conf_authmode(&mSslConfig, MBEDTLS_SSL_VERIFY_NONE);
130 mbedtls_ssl_conf_ciphersuites(&mSslConfig, sCipherSuites);
131
132 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
133 mbedtls_ssl_conf_min_tls_version(&mSslConfig, MBEDTLS_SSL_VERSION_TLS1_2);
134 mbedtls_ssl_conf_max_tls_version(&mSslConfig, MBEDTLS_SSL_VERSION_TLS1_2);
135 #else
136 mbedtls_ssl_conf_min_version(&mSslConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
137 mbedtls_ssl_conf_max_version(&mSslConfig, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
138 #endif
139
140 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
141 #include "crypto/mbedtls.hpp"
142 int rv = mbedtls_pk_parse_key(&mPKey, reinterpret_cast<const unsigned char *>(sSrvKey), sSrvKeyLength,
143 nullptr, 0, Crypto::MbedTls::CryptoSecurePrng, nullptr);
144 #else
145 int rv = mbedtls_pk_parse_key(&mPKey, reinterpret_cast<const unsigned char *>(sSrvKey), sSrvKeyLength,
146 nullptr, 0);
147 #endif
148 if (rv != 0)
149 {
150 OutputLine("mbedtls_pk_parse_key returned %d", rv);
151 }
152
153 rv = mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(sSrvPem), sSrvPemLength);
154 if (rv != 0)
155 {
156 OutputLine("mbedtls_x509_crt_parse (1) returned %d", rv);
157 }
158 rv = mbedtls_x509_crt_parse(&mSrvCert, reinterpret_cast<const unsigned char *>(sCasPem), sCasPemLength);
159 if (rv != 0)
160 {
161 OutputLine("mbedtls_x509_crt_parse (2) returned %d", rv);
162 }
163 rv = mbedtls_ssl_setup(&mSslContext, &mSslConfig);
164 if (rv != 0)
165 {
166 OutputLine("mbedtls_ssl_setup returned %d", rv);
167 }
168 }
169 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
170 else
171 {
172 ExitNow(error = OT_ERROR_INVALID_ARGS);
173 }
174
175 if (aArgs[1].IsEmpty())
176 {
177 receiveBufferSize = sizeof(mReceiveBufferBytes);
178 }
179 else
180 {
181 uint32_t windowSize;
182
183 SuccessOrExit(error = aArgs[1].ParseAsUint32(windowSize));
184
185 receiveBufferSize = windowSize + ((windowSize + 7) >> 3);
186 VerifyOrExit(receiveBufferSize <= sizeof(mReceiveBufferBytes) && receiveBufferSize != 0,
187 error = OT_ERROR_INVALID_ARGS);
188 }
189 }
190
191 otTcpCircularSendBufferInitialize(&mSendBuffer, mSendBufferBytes, sizeof(mSendBufferBytes));
192
193 {
194 otTcpEndpointInitializeArgs endpointArgs;
195
196 memset(&endpointArgs, 0x00, sizeof(endpointArgs));
197 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
198 if (mUseCircularSendBuffer)
199 {
200 endpointArgs.mForwardProgressCallback = HandleTcpForwardProgressCallback;
201 }
202 else
203 {
204 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
205 }
206 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
207 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
208 endpointArgs.mContext = this;
209 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
210 endpointArgs.mReceiveBufferSize = receiveBufferSize;
211
212 SuccessOrExit(error = otTcpEndpointInitialize(GetInstancePtr(), &mEndpoint, &endpointArgs));
213 }
214
215 {
216 otTcpListenerInitializeArgs listenerArgs;
217
218 memset(&listenerArgs, 0x00, sizeof(listenerArgs));
219 listenerArgs.mAcceptReadyCallback = HandleTcpAcceptReadyCallback;
220 listenerArgs.mAcceptDoneCallback = HandleTcpAcceptDoneCallback;
221 listenerArgs.mContext = this;
222
223 error = otTcpListenerInitialize(GetInstancePtr(), &mListener, &listenerArgs);
224 if (error != OT_ERROR_NONE)
225 {
226 IgnoreReturnValue(otTcpEndpointDeinitialize(&mEndpoint));
227 ExitNow();
228 }
229 }
230
231 mInitialized = true;
232
233 exit:
234 return error;
235 }
236
Process(Arg aArgs[])237 template <> otError TcpExample::Process<Cmd("deinit")>(Arg aArgs[])
238 {
239 otError error = OT_ERROR_NONE;
240 otError endpointError;
241 otError bufferError;
242 otError listenerError;
243
244 VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
245 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
246
247 #if OPENTHREAD_CONFIG_TLS_ENABLE
248 if (mUseTls)
249 {
250 otPlatCryptoRandomDeinit();
251 mbedtls_ssl_config_free(&mSslConfig);
252 mbedtls_ssl_free(&mSslContext);
253
254 mbedtls_pk_free(&mPKey);
255 mbedtls_x509_crt_free(&mSrvCert);
256 }
257 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
258
259 endpointError = otTcpEndpointDeinitialize(&mEndpoint);
260 mSendBusy = false;
261
262 otTcpCircularSendBufferForceDiscardAll(&mSendBuffer);
263 bufferError = otTcpCircularSendBufferDeinitialize(&mSendBuffer);
264
265 listenerError = otTcpListenerDeinitialize(&mListener);
266 mInitialized = false;
267
268 SuccessOrExit(error = endpointError);
269 SuccessOrExit(error = bufferError);
270 SuccessOrExit(error = listenerError);
271
272 exit:
273 return error;
274 }
275
Process(Arg aArgs[])276 template <> otError TcpExample::Process<Cmd("bind")>(Arg aArgs[])
277 {
278 otError error;
279 otSockAddr sockaddr;
280
281 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
282
283 SuccessOrExit(error = aArgs[0].ParseAsIp6Address(sockaddr.mAddress));
284 SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
285 VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
286
287 error = otTcpBind(&mEndpoint, &sockaddr);
288
289 exit:
290 return error;
291 }
292
Process(Arg aArgs[])293 template <> otError TcpExample::Process<Cmd("connect")>(Arg aArgs[])
294 {
295 otError error;
296 otSockAddr sockaddr;
297 bool nat64SynthesizedAddress;
298 uint32_t flags;
299
300 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
301
302 SuccessOrExit(
303 error = Interpreter::ParseToIp6Address(GetInstancePtr(), aArgs[0], sockaddr.mAddress, nat64SynthesizedAddress));
304 if (nat64SynthesizedAddress)
305 {
306 OutputFormat("Connecting to synthesized IPv6 address: ");
307 OutputIp6AddressLine(sockaddr.mAddress);
308 }
309
310 SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
311 if (aArgs[2].IsEmpty())
312 {
313 flags = OT_TCP_CONNECT_NO_FAST_OPEN;
314 }
315 else
316 {
317 if (aArgs[2] == "slow")
318 {
319 flags = OT_TCP_CONNECT_NO_FAST_OPEN;
320 }
321 else if (aArgs[2] == "fast")
322 {
323 flags = 0;
324 }
325 else
326 {
327 ExitNow(error = OT_ERROR_INVALID_ARGS);
328 }
329 VerifyOrExit(aArgs[3].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
330 }
331
332 #if OPENTHREAD_CONFIG_TLS_ENABLE
333 if (mUseTls)
334 {
335 int rv = mbedtls_ssl_config_defaults(&mSslConfig, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM,
336 MBEDTLS_SSL_PRESET_DEFAULT);
337 if (rv != 0)
338 {
339 OutputLine("mbedtls_ssl_config_defaults returned %d", rv);
340 }
341 }
342 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
343
344 SuccessOrExit(error = otTcpConnect(&mEndpoint, &sockaddr, flags));
345 mEndpointConnected = true;
346 mEndpointConnectedFastOpen = ((flags & OT_TCP_CONNECT_NO_FAST_OPEN) == 0);
347
348 #if OPENTHREAD_CONFIG_TLS_ENABLE
349 if (mUseTls && mEndpointConnectedFastOpen)
350 {
351 PrepareTlsHandshake();
352 ContinueTlsHandshake();
353 }
354 #endif
355
356 exit:
357 return error;
358 }
359
Process(Arg aArgs[])360 template <> otError TcpExample::Process<Cmd("send")>(Arg aArgs[])
361 {
362 otError error;
363
364 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
365 VerifyOrExit(mBenchmarkBytesTotal == 0, error = OT_ERROR_BUSY);
366 VerifyOrExit(!aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
367 VerifyOrExit(aArgs[1].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
368
369 if (mUseCircularSendBuffer)
370 {
371 #if OPENTHREAD_CONFIG_TLS_ENABLE
372 if (mUseTls)
373 {
374 int rv = mbedtls_ssl_write(&mSslContext, reinterpret_cast<unsigned char *>(aArgs[0].GetCString()),
375 aArgs[0].GetLength());
376 if (rv < 0 && rv != MBEDTLS_ERR_SSL_WANT_WRITE && rv != MBEDTLS_ERR_SSL_WANT_READ)
377 {
378 ExitNow(error = kErrorFailed);
379 }
380 error = kErrorNone;
381 }
382 else
383 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
384 {
385 size_t written;
386 SuccessOrExit(error = otTcpCircularSendBufferWrite(&mEndpoint, &mSendBuffer, aArgs[0].GetCString(),
387 aArgs[0].GetLength(), &written, 0));
388 }
389 }
390 else
391 {
392 VerifyOrExit(!mSendBusy, error = OT_ERROR_BUSY);
393
394 mSendLink.mNext = nullptr;
395 mSendLink.mData = mSendBufferBytes;
396 mSendLink.mLength = OT_MIN(aArgs[0].GetLength(), sizeof(mSendBufferBytes));
397 memcpy(mSendBufferBytes, aArgs[0].GetCString(), mSendLink.mLength);
398
399 SuccessOrExit(error = otTcpSendByReference(&mEndpoint, &mSendLink, 0));
400 mSendBusy = true;
401 }
402
403 exit:
404 return error;
405 }
406
Process(Arg aArgs[])407 template <> otError TcpExample::Process<Cmd("benchmark")>(Arg aArgs[])
408 {
409 otError error = OT_ERROR_NONE;
410
411 if (aArgs[0] == "result")
412 {
413 OutputFormat("TCP Benchmark Status: ");
414 if (mBenchmarkBytesTotal != 0)
415 {
416 OutputLine("Ongoing");
417 }
418 else if (mBenchmarkTimeUsed != 0)
419 {
420 OutputLine("Completed");
421 OutputBenchmarkResult();
422 }
423 else
424 {
425 OutputLine("Untested");
426 }
427 }
428 else if (aArgs[0] == "run")
429 {
430 VerifyOrExit(!mSendBusy, error = OT_ERROR_BUSY);
431 VerifyOrExit(mBenchmarkBytesTotal == 0, error = OT_ERROR_BUSY);
432
433 if (aArgs[1].IsEmpty())
434 {
435 mBenchmarkBytesTotal = OPENTHREAD_CONFIG_CLI_TCP_DEFAULT_BENCHMARK_SIZE;
436 }
437 else
438 {
439 SuccessOrExit(error = aArgs[1].ParseAsUint32(mBenchmarkBytesTotal));
440 VerifyOrExit(mBenchmarkBytesTotal != 0, error = OT_ERROR_INVALID_ARGS);
441 }
442 VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
443
444 mBenchmarkStart = TimerMilli::GetNow();
445 mBenchmarkBytesUnsent = mBenchmarkBytesTotal;
446
447 if (mUseCircularSendBuffer)
448 {
449 SuccessOrExit(error = ContinueBenchmarkCircularSend());
450 }
451 else
452 {
453 uint32_t benchmarkLinksLeft =
454 (mBenchmarkBytesTotal + sizeof(mSendBufferBytes) - 1) / sizeof(mSendBufferBytes);
455 uint32_t toSendOut = OT_MIN(OT_ARRAY_LENGTH(mBenchmarkLinks), benchmarkLinksLeft);
456
457 /* We could also point the linked buffers directly to sBenchmarkData. */
458 memset(mSendBufferBytes, 'a', sizeof(mSendBufferBytes));
459
460 for (uint32_t i = 0; i != toSendOut; i++)
461 {
462 mBenchmarkLinks[i].mNext = nullptr;
463 mBenchmarkLinks[i].mData = mSendBufferBytes;
464 mBenchmarkLinks[i].mLength = sizeof(mSendBufferBytes);
465 if (i == 0 && mBenchmarkBytesTotal % sizeof(mSendBufferBytes) != 0)
466 {
467 mBenchmarkLinks[i].mLength = mBenchmarkBytesTotal % sizeof(mSendBufferBytes);
468 }
469 error = otTcpSendByReference(&mEndpoint, &mBenchmarkLinks[i],
470 i == toSendOut - 1 ? 0 : OT_TCP_SEND_MORE_TO_COME);
471 VerifyOrExit(error == OT_ERROR_NONE, mBenchmarkBytesTotal = 0);
472 }
473 }
474 }
475 else
476 {
477 error = OT_ERROR_INVALID_ARGS;
478 }
479
480 exit:
481 return error;
482 }
483
Process(Arg aArgs[])484 template <> otError TcpExample::Process<Cmd("sendend")>(Arg aArgs[])
485 {
486 otError error;
487
488 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
489 VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
490
491 error = otTcpSendEndOfStream(&mEndpoint);
492
493 exit:
494 return error;
495 }
496
Process(Arg aArgs[])497 template <> otError TcpExample::Process<Cmd("abort")>(Arg aArgs[])
498 {
499 otError error;
500
501 VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
502 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
503
504 SuccessOrExit(error = otTcpAbort(&mEndpoint));
505 mEndpointConnected = false;
506 mEndpointConnectedFastOpen = false;
507
508 exit:
509 return error;
510 }
511
Process(Arg aArgs[])512 template <> otError TcpExample::Process<Cmd("listen")>(Arg aArgs[])
513 {
514 otError error;
515 otSockAddr sockaddr;
516
517 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
518
519 SuccessOrExit(error = aArgs[0].ParseAsIp6Address(sockaddr.mAddress));
520 SuccessOrExit(error = aArgs[1].ParseAsUint16(sockaddr.mPort));
521 VerifyOrExit(aArgs[2].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
522
523 SuccessOrExit(error = otTcpStopListening(&mListener));
524 error = otTcpListen(&mListener, &sockaddr);
525
526 exit:
527 return error;
528 }
529
Process(Arg aArgs[])530 template <> otError TcpExample::Process<Cmd("stoplistening")>(Arg aArgs[])
531 {
532 otError error;
533
534 VerifyOrExit(aArgs[0].IsEmpty(), error = OT_ERROR_INVALID_ARGS);
535 VerifyOrExit(mInitialized, error = OT_ERROR_INVALID_STATE);
536
537 error = otTcpStopListening(&mListener);
538
539 exit:
540 return error;
541 }
542
Process(Arg aArgs[])543 otError TcpExample::Process(Arg aArgs[])
544 {
545 #define CmdEntry(aCommandString) \
546 { \
547 aCommandString, &TcpExample::Process<Cmd(aCommandString)> \
548 }
549
550 static constexpr Command kCommands[] = {
551 CmdEntry("abort"), CmdEntry("benchmark"), CmdEntry("bind"), CmdEntry("connect"), CmdEntry("deinit"),
552 CmdEntry("init"), CmdEntry("listen"), CmdEntry("send"), CmdEntry("sendend"), CmdEntry("stoplistening"),
553 };
554
555 static_assert(BinarySearch::IsSorted(kCommands), "kCommands is not sorted");
556
557 otError error = OT_ERROR_INVALID_COMMAND;
558 const Command *command;
559
560 if (aArgs[0].IsEmpty() || (aArgs[0] == "help"))
561 {
562 OutputCommandTable(kCommands);
563 ExitNow(error = aArgs[0].IsEmpty() ? error : OT_ERROR_NONE);
564 }
565
566 command = BinarySearch::Find(aArgs[0].GetCString(), kCommands);
567 VerifyOrExit(command != nullptr);
568
569 error = (this->*command->mHandler)(aArgs + 1);
570
571 exit:
572 return error;
573 }
574
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)575 void TcpExample::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
576 {
577 static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
578 }
579
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)580 void TcpExample::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
581 {
582 static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
583 }
584
HandleTcpForwardProgressCallback(otTcpEndpoint * aEndpoint,size_t aInSendBuffer,size_t aBacklog)585 void TcpExample::HandleTcpForwardProgressCallback(otTcpEndpoint *aEndpoint, size_t aInSendBuffer, size_t aBacklog)
586 {
587 static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))
588 ->HandleTcpForwardProgress(aEndpoint, aInSendBuffer, aBacklog);
589 }
590
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)591 void TcpExample::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
592 size_t aBytesAvailable,
593 bool aEndOfStream,
594 size_t aBytesRemaining)
595 {
596 static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))
597 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
598 }
599
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)600 void TcpExample::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
601 {
602 static_cast<TcpExample *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
603 }
604
HandleTcpAcceptReadyCallback(otTcpListener * aListener,const otSockAddr * aPeer,otTcpEndpoint ** aAcceptInto)605 otTcpIncomingConnectionAction TcpExample::HandleTcpAcceptReadyCallback(otTcpListener *aListener,
606 const otSockAddr *aPeer,
607 otTcpEndpoint **aAcceptInto)
608 {
609 return static_cast<TcpExample *>(otTcpListenerGetContext(aListener))
610 ->HandleTcpAcceptReady(aListener, aPeer, aAcceptInto);
611 }
612
HandleTcpAcceptDoneCallback(otTcpListener * aListener,otTcpEndpoint * aEndpoint,const otSockAddr * aPeer)613 void TcpExample::HandleTcpAcceptDoneCallback(otTcpListener *aListener,
614 otTcpEndpoint *aEndpoint,
615 const otSockAddr *aPeer)
616 {
617 static_cast<TcpExample *>(otTcpListenerGetContext(aListener))->HandleTcpAcceptDone(aListener, aEndpoint, aPeer);
618 }
619
HandleTcpEstablished(otTcpEndpoint * aEndpoint)620 void TcpExample::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
621 {
622 OT_UNUSED_VARIABLE(aEndpoint);
623 OutputLine("TCP: Connection established");
624 #if OPENTHREAD_CONFIG_TLS_ENABLE
625 if (mUseTls && !mEndpointConnectedFastOpen)
626 {
627 PrepareTlsHandshake();
628 ContinueTlsHandshake();
629 }
630 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
631 }
632
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)633 void TcpExample::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
634 {
635 OT_UNUSED_VARIABLE(aEndpoint);
636 OT_ASSERT(!mUseCircularSendBuffer); // this callback is not used when using the circular send buffer
637
638 if (mBenchmarkBytesTotal == 0)
639 {
640 // If the benchmark encountered an error, we might end up here. So,
641 // tolerate some benchmark links finishing in this case.
642 if (aData == &mSendLink)
643 {
644 OT_ASSERT(mSendBusy);
645 mSendBusy = false;
646 }
647 }
648 else
649 {
650 OT_ASSERT(aData != &mSendLink);
651 OT_ASSERT(mBenchmarkBytesUnsent >= aData->mLength);
652 mBenchmarkBytesUnsent -= aData->mLength; // could be less than sizeof(mSendBufferBytes) for the first link
653 if (mBenchmarkBytesUnsent >= OT_ARRAY_LENGTH(mBenchmarkLinks) * sizeof(mSendBufferBytes))
654 {
655 aData->mLength = sizeof(mSendBufferBytes);
656 if (otTcpSendByReference(&mEndpoint, aData, 0) != OT_ERROR_NONE)
657 {
658 OutputLine("TCP Benchmark Failed");
659 mBenchmarkBytesTotal = 0;
660 }
661 }
662 else if (mBenchmarkBytesUnsent == 0)
663 {
664 CompleteBenchmark();
665 }
666 }
667 }
668
HandleTcpForwardProgress(otTcpEndpoint * aEndpoint,size_t aInSendBuffer,size_t aBacklog)669 void TcpExample::HandleTcpForwardProgress(otTcpEndpoint *aEndpoint, size_t aInSendBuffer, size_t aBacklog)
670 {
671 OT_UNUSED_VARIABLE(aEndpoint);
672 OT_UNUSED_VARIABLE(aBacklog);
673 OT_ASSERT(mUseCircularSendBuffer); // this callback is only used when using the circular send buffer
674
675 otTcpCircularSendBufferHandleForwardProgress(&mSendBuffer, aInSendBuffer);
676
677 #if OPENTHREAD_CONFIG_TLS_ENABLE
678 if (mUseTls)
679 {
680 ContinueTlsHandshake();
681 }
682 #endif
683
684 /* Handle case where we're in a benchmark. */
685 if (mBenchmarkBytesTotal != 0)
686 {
687 if (mBenchmarkBytesUnsent != 0)
688 {
689 /* Continue sending out data if there's data we haven't sent. */
690 IgnoreError(ContinueBenchmarkCircularSend());
691 }
692 else if (aInSendBuffer == 0)
693 {
694 /* Handle case where all data is sent out and the send buffer has drained. */
695 CompleteBenchmark();
696 }
697 }
698 }
699
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)700 void TcpExample::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
701 size_t aBytesAvailable,
702 bool aEndOfStream,
703 size_t aBytesRemaining)
704 {
705 OT_UNUSED_VARIABLE(aBytesRemaining);
706 OT_ASSERT(aEndpoint == &mEndpoint);
707
708 /* If we get data before the handshake completes, then this is a TFO connection. */
709 if (!mEndpointConnected)
710 {
711 mEndpointConnected = true;
712 mEndpointConnectedFastOpen = true;
713
714 #if OPENTHREAD_CONFIG_TLS_ENABLE
715 if (mUseTls)
716 {
717 PrepareTlsHandshake();
718 }
719 #endif
720 }
721
722 #if OPENTHREAD_CONFIG_TLS_ENABLE
723 if (mUseTls && ContinueTlsHandshake())
724 {
725 return;
726 }
727 #endif
728
729 if ((mTlsHandshakeComplete || !mUseTls) && aBytesAvailable > 0)
730 {
731 #if OPENTHREAD_CONFIG_TLS_ENABLE
732 if (mUseTls)
733 {
734 uint8_t buffer[500];
735 for (;;)
736 {
737 int rv = mbedtls_ssl_read(&mSslContext, buffer, sizeof(buffer));
738 if (rv < 0)
739 {
740 if (rv == MBEDTLS_ERR_SSL_WANT_READ)
741 {
742 break;
743 }
744 OutputLine("TLS receive failure: %d", rv);
745 }
746 else
747 {
748 OutputLine("TLS: Received %u bytes: %.*s", static_cast<unsigned>(rv), rv,
749 reinterpret_cast<const char *>(buffer));
750 }
751 }
752 OutputLine("(TCP: Received %u bytes)", static_cast<unsigned>(aBytesAvailable));
753 }
754 else
755 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
756 {
757 const otLinkedBuffer *data;
758 size_t totalReceived = 0;
759 IgnoreError(otTcpReceiveByReference(aEndpoint, &data));
760 for (; data != nullptr; data = data->mNext)
761 {
762 OutputLine("TCP: Received %u bytes: %.*s", static_cast<unsigned>(data->mLength),
763 static_cast<unsigned>(data->mLength), reinterpret_cast<const char *>(data->mData));
764 totalReceived += data->mLength;
765 }
766 OT_ASSERT(aBytesAvailable == totalReceived);
767 IgnoreReturnValue(otTcpCommitReceive(aEndpoint, totalReceived, 0));
768 }
769 }
770
771 if (aEndOfStream)
772 {
773 OutputLine("TCP: Reached end of stream");
774 }
775 }
776
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)777 void TcpExample::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
778 {
779 static const char *const kReasonStrings[] = {
780 "Disconnected", // (0) OT_TCP_DISCONNECTED_REASON_NORMAL
781 "Connection refused", // (1) OT_TCP_DISCONNECTED_REASON_REFUSED
782 "Connection reset", // (2) OT_TCP_DISCONNECTED_REASON_RESET
783 "Entered TIME-WAIT state", // (3) OT_TCP_DISCONNECTED_REASON_TIME_WAIT
784 "Connection timed out", // (4) OT_TCP_DISCONNECTED_REASON_TIMED_OUT
785 };
786
787 OT_UNUSED_VARIABLE(aEndpoint);
788
789 static_assert(0 == OT_TCP_DISCONNECTED_REASON_NORMAL, "OT_TCP_DISCONNECTED_REASON_NORMAL value is incorrect");
790 static_assert(1 == OT_TCP_DISCONNECTED_REASON_REFUSED, "OT_TCP_DISCONNECTED_REASON_REFUSED value is incorrect");
791 static_assert(2 == OT_TCP_DISCONNECTED_REASON_RESET, "OT_TCP_DISCONNECTED_REASON_RESET value is incorrect");
792 static_assert(3 == OT_TCP_DISCONNECTED_REASON_TIME_WAIT, "OT_TCP_DISCONNECTED_REASON_TIME_WAIT value is incorrect");
793 static_assert(4 == OT_TCP_DISCONNECTED_REASON_TIMED_OUT, "OT_TCP_DISCONNECTED_REASON_TIMED_OUT value is incorrect");
794
795 OutputLine("TCP: %s", Stringify(aReason, kReasonStrings));
796
797 #if OPENTHREAD_CONFIG_TLS_ENABLE
798 if (mUseTls)
799 {
800 mbedtls_ssl_session_reset(&mSslContext);
801 }
802 #endif
803
804 // We set this to false even for the TIME-WAIT state, so that we can reuse
805 // the active socket if an incoming connection comes in instead of waiting
806 // for the 2MSL timeout.
807 mEndpointConnected = false;
808 mEndpointConnectedFastOpen = false;
809 mSendBusy = false;
810
811 // Mark the benchmark as inactive if the connection was disconnected.
812 mBenchmarkBytesTotal = 0;
813 mBenchmarkBytesUnsent = 0;
814
815 otTcpCircularSendBufferForceDiscardAll(&mSendBuffer);
816 }
817
HandleTcpAcceptReady(otTcpListener * aListener,const otSockAddr * aPeer,otTcpEndpoint ** aAcceptInto)818 otTcpIncomingConnectionAction TcpExample::HandleTcpAcceptReady(otTcpListener *aListener,
819 const otSockAddr *aPeer,
820 otTcpEndpoint **aAcceptInto)
821 {
822 otTcpIncomingConnectionAction action;
823
824 OT_UNUSED_VARIABLE(aListener);
825
826 if (mEndpointConnected)
827 {
828 OutputFormat("TCP: Ignoring incoming connection request from ");
829 OutputSockAddr(*aPeer);
830 OutputLine(" (active socket is busy)");
831
832 ExitNow(action = OT_TCP_INCOMING_CONNECTION_ACTION_DEFER);
833 }
834
835 *aAcceptInto = &mEndpoint;
836 action = OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT;
837
838 #if OPENTHREAD_CONFIG_TLS_ENABLE
839 /*
840 * Natural to wait until the AcceptDone callback but with TFO we could get data before that
841 * so it doesn't make sense to wait until then.
842 */
843 if (mUseTls)
844 {
845 int rv;
846
847 rv = mbedtls_ssl_config_defaults(&mSslConfig, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
848 MBEDTLS_SSL_PRESET_DEFAULT);
849 if (rv != 0)
850 {
851 OutputLine("mbedtls_ssl_config_defaults returned %d", rv);
852 }
853 mbedtls_ssl_conf_ca_chain(&mSslConfig, mSrvCert.next, nullptr);
854 rv = mbedtls_ssl_conf_own_cert(&mSslConfig, &mSrvCert, &mPKey);
855 if (rv != 0)
856 {
857 OutputLine("mbedtls_ssl_conf_own_cert returned %d", rv);
858 }
859 }
860 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
861
862 exit:
863 return action;
864 }
865
HandleTcpAcceptDone(otTcpListener * aListener,otTcpEndpoint * aEndpoint,const otSockAddr * aPeer)866 void TcpExample::HandleTcpAcceptDone(otTcpListener *aListener, otTcpEndpoint *aEndpoint, const otSockAddr *aPeer)
867 {
868 OT_UNUSED_VARIABLE(aListener);
869 OT_UNUSED_VARIABLE(aEndpoint);
870
871 mEndpointConnected = true;
872 OutputFormat("Accepted connection from ");
873 OutputSockAddrLine(*aPeer);
874 }
875
ContinueBenchmarkCircularSend(void)876 otError TcpExample::ContinueBenchmarkCircularSend(void)
877 {
878 otError error = OT_ERROR_NONE;
879 size_t freeSpace;
880
881 while (mBenchmarkBytesUnsent != 0 && (freeSpace = otTcpCircularSendBufferGetFreeSpace(&mSendBuffer)) != 0)
882 {
883 size_t toSendThisIteration = OT_MIN(mBenchmarkBytesUnsent, sBenchmarkDataLength);
884 uint32_t flag = (toSendThisIteration < freeSpace && toSendThisIteration < mBenchmarkBytesUnsent)
885 ? OT_TCP_CIRCULAR_SEND_BUFFER_WRITE_MORE_TO_COME
886 : 0;
887 size_t written = 0;
888
889 #if OPENTHREAD_CONFIG_TLS_ENABLE
890 if (mUseTls)
891 {
892 int rv = mbedtls_ssl_write(&mSslContext, reinterpret_cast<const unsigned char *>(sBenchmarkData),
893 toSendThisIteration);
894 if (rv > 0)
895 {
896 written = static_cast<size_t>(rv);
897 OT_ASSERT(written <= mBenchmarkBytesUnsent);
898 }
899 else if (rv != MBEDTLS_ERR_SSL_WANT_WRITE && rv != MBEDTLS_ERR_SSL_WANT_READ)
900 {
901 ExitNow(error = kErrorFailed);
902 }
903 error = kErrorNone;
904 }
905 else
906 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
907 {
908 SuccessOrExit(error = otTcpCircularSendBufferWrite(&mEndpoint, &mSendBuffer, sBenchmarkData,
909 toSendThisIteration, &written, flag));
910 }
911 mBenchmarkBytesUnsent -= written;
912 }
913
914 exit:
915 if (error != OT_ERROR_NONE)
916 {
917 OutputLine("TCP Benchmark Failed");
918 mBenchmarkBytesTotal = 0;
919 mBenchmarkBytesUnsent = 0;
920 }
921
922 return error;
923 }
924
OutputBenchmarkResult(void)925 void TcpExample::OutputBenchmarkResult(void)
926 {
927 uint32_t thousandTimesGoodput =
928 (1000 * (mBenchmarkLastBytesTotal << 3) + (mBenchmarkTimeUsed >> 1)) / mBenchmarkTimeUsed;
929
930 OutputLine("TCP Benchmark Complete: Transferred %lu bytes in %lu milliseconds", ToUlong(mBenchmarkLastBytesTotal),
931 ToUlong(mBenchmarkTimeUsed));
932 OutputLine("TCP Goodput: %lu.%03u kb/s", ToUlong(thousandTimesGoodput / 1000),
933 static_cast<uint16_t>(thousandTimesGoodput % 1000));
934 }
935
CompleteBenchmark(void)936 void TcpExample::CompleteBenchmark(void)
937 {
938 mBenchmarkTimeUsed = TimerMilli::GetNow() - mBenchmarkStart;
939 mBenchmarkLastBytesTotal = mBenchmarkBytesTotal;
940
941 OutputBenchmarkResult();
942
943 mBenchmarkBytesTotal = 0;
944 }
945
946 #if OPENTHREAD_CONFIG_TLS_ENABLE
PrepareTlsHandshake(void)947 void TcpExample::PrepareTlsHandshake(void)
948 {
949 int rv;
950 rv = mbedtls_ssl_set_hostname(&mSslContext, "localhost");
951 if (rv != 0)
952 {
953 OutputLine("mbedtls_ssl_set_hostname returned %d", rv);
954 }
955 rv = mbedtls_ssl_set_hs_ecjpake_password(&mSslContext, reinterpret_cast<const unsigned char *>(sEcjpakePassword),
956 sEcjpakePasswordLength);
957 if (rv != 0)
958 {
959 OutputLine("mbedtls_ssl_set_hs_ecjpake_password returned %d", rv);
960 }
961 mbedtls_ssl_set_bio(&mSslContext, &mEndpointAndCircularSendBuffer, otTcpMbedTlsSslSendCallback,
962 otTcpMbedTlsSslRecvCallback, nullptr);
963 mTlsHandshakeComplete = false;
964 }
965
ContinueTlsHandshake(void)966 bool TcpExample::ContinueTlsHandshake(void)
967 {
968 bool wasNotAlreadyDone = false;
969 int rv;
970
971 if (!mTlsHandshakeComplete)
972 {
973 rv = mbedtls_ssl_handshake(&mSslContext);
974 if (rv == 0)
975 {
976 OutputLine("TLS Handshake Complete");
977 mTlsHandshakeComplete = true;
978 }
979 else if (rv != MBEDTLS_ERR_SSL_WANT_READ && rv != MBEDTLS_ERR_SSL_WANT_WRITE)
980 {
981 OutputLine("TLS Handshake Failed: %d", rv);
982 }
983 wasNotAlreadyDone = true;
984 }
985
986 return wasNotAlreadyDone;
987 }
988 #endif // OPENTHREAD_CONFIG_TLS_ENABLE
989
990 } // namespace Cli
991 } // namespace ot
992
993 #endif // OPENTHREAD_CONFIG_TCP_ENABLE && OPENTHREAD_CONFIG_CLI_TCP_ENABLE
994