1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements TCP/IPv6 sockets.
32  */
33 
34 #include "openthread-core-config.h"
35 
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37 
38 #include "tcp6.hpp"
39 
40 #include "common/as_core_type.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/error.hpp"
43 #include "common/instance.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/num_utils.hpp"
47 #include "common/random.hpp"
48 #include "net/checksum.hpp"
49 #include "net/ip6.hpp"
50 #include "net/netif.hpp"
51 
52 #include "../../third_party/tcplp/tcplp.h"
53 
54 namespace ot {
55 namespace Ip6 {
56 
57 using ot::Encoding::BigEndian::HostSwap16;
58 using ot::Encoding::BigEndian::HostSwap32;
59 
60 RegisterLogModule("Tcp");
61 
62 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
63 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
64               "mTcb field in otTcpEndpoint is aligned incorrectly");
65 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
66 
67 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
68               "mTcbListen field in otTcpListener is sized incorrectly");
69 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
70               "mTcbListen field in otTcpListener is aligned incorrectly");
71 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
72 
/**
 * Constructs the TCP module.
 *
 * @param[in] aInstance  The OpenThread instance; it is also passed to the module's timer and tasklet.
 *
 * Ephemeral-port allocation (see AutoBind) starts at the bottom of the dynamic port range.
 */
Tcp::Tcp(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTimer(aInstance)
    , mTasklet(aInstance)
    , mEphemeralPort(kDynamicPortMin)
{
    // NOTE(review): mEphemeralPort is used by AutoBind(); this presumably silences an
    // unused-member warning in configurations where that path is compiled out — confirm.
    OT_UNUSED_VARIABLE(mEphemeralPort);
}
81 
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)82 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
83 {
84     Error         error;
85     struct tcpcb &tp = GetTcb();
86 
87     memset(&tp, 0x00, sizeof(tp));
88 
89     SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
90 
91     mContext                  = aArgs.mContext;
92     mEstablishedCallback      = aArgs.mEstablishedCallback;
93     mSendDoneCallback         = aArgs.mSendDoneCallback;
94     mForwardProgressCallback  = aArgs.mForwardProgressCallback;
95     mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
96     mDisconnectedCallback     = aArgs.mDisconnectedCallback;
97 
98     memset(mTimers, 0x00, sizeof(mTimers));
99     memset(&mSockAddr, 0x00, sizeof(mSockAddr));
100     mPendingCallbacks = 0;
101 
102     /*
103      * Initialize buffers --- formerly in initialize_tcb.
104      */
105     {
106         uint8_t *recvbuf    = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
107         size_t   recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
108         uint8_t *reassbmp   = recvbuf + recvbuflen;
109 
110         lbuf_init(&tp.sendbuf);
111         cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
112         tp.reassbmp = reassbmp;
113         bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
114     }
115 
116     tp.accepted_from = nullptr;
117     initialize_tcb(&tp);
118 
119     /* Note that we do not need to zero-initialize mReceiveLinks. */
120 
121     tp.instance = &aInstance;
122 
123 exit:
124     return error;
125 }
126 
GetInstance(void) const127 Instance &Tcp::Endpoint::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcb().instance)); }
128 
GetLocalAddress(void) const129 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
130 {
131     const struct tcpcb &tp = GetTcb();
132 
133     static otSockAddr temp;
134 
135     memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
136     temp.mPort = HostSwap16(tp.lport);
137 
138     return AsCoreType(&temp);
139 }
140 
GetPeerAddress(void) const141 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
142 {
143     const struct tcpcb &tp = GetTcb();
144 
145     static otSockAddr temp;
146 
147     memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
148     temp.mPort = HostSwap16(tp.fport);
149 
150     return AsCoreType(&temp);
151 }
152 
Bind(const SockAddr & aSockName)153 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
154 {
155     Error         error;
156     struct tcpcb &tp = GetTcb();
157 
158     VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
159     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
160 
161     memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
162     tp.lport = HostSwap16(aSockName.mPort);
163     error    = kErrorNone;
164 
165 exit:
166     return error;
167 }
168 
Connect(const SockAddr & aSockName,uint32_t aFlags)169 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
170 {
171     Error               error = kErrorNone;
172     struct tcpcb       &tp    = GetTcb();
173     struct sockaddr_in6 sin6p;
174 
175     OT_UNUSED_VARIABLE(aFlags);
176 
177     VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
178 
179     memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
180     sin6p.sin6_port = HostSwap16(aSockName.mPort);
181     error           = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
182 
183 exit:
184     return error;
185 }
186 
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)187 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
188 {
189     Error         error;
190     struct tcpcb &tp = GetTcb();
191 
192     size_t backlogBefore = GetBacklogBytes();
193     size_t sent          = aBuffer.mLength;
194 
195     SuccessOrExit(error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0)));
196 
197     PostCallbacksAfterSend(sent, backlogBefore);
198 
199 exit:
200     return error;
201 }
202 
SendByExtension(size_t aNumBytes,uint32_t aFlags)203 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
204 {
205     Error         error;
206     bool          moreToCome    = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
207     struct tcpcb &tp            = GetTcb();
208     size_t        backlogBefore = GetBacklogBytes();
209     int           bsdError;
210 
211     VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
212 
213     bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes);
214     SuccessOrExit(error = BsdErrorToOtError(bsdError));
215 
216     PostCallbacksAfterSend(aNumBytes, backlogBefore);
217 
218 exit:
219     return error;
220 }
221 
ReceiveByReference(const otLinkedBuffer * & aBuffer)222 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
223 {
224     struct tcpcb &tp = GetTcb();
225 
226     cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
227     aBuffer = &mReceiveLinks[0];
228 
229     return kErrorNone;
230 }
231 
ReceiveContiguify(void)232 Error Tcp::Endpoint::ReceiveContiguify(void)
233 {
234     struct tcpcb &tp = GetTcb();
235 
236     cbuf_contiguify(&tp.recvbuf, tp.reassbmp);
237 
238     return kErrorNone;
239 }
240 
CommitReceive(size_t aNumBytes,uint32_t aFlags)241 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
242 {
243     Error         error = kErrorNone;
244     struct tcpcb &tp    = GetTcb();
245 
246     OT_UNUSED_VARIABLE(aFlags);
247 
248     VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
249     VerifyOrExit(aNumBytes > 0, error = kErrorNone);
250 
251     cbuf_pop(&tp.recvbuf, aNumBytes);
252     error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
253 
254 exit:
255     return error;
256 }
257 
SendEndOfStream(void)258 Error Tcp::Endpoint::SendEndOfStream(void)
259 {
260     struct tcpcb &tp = GetTcb();
261 
262     return BsdErrorToOtError(tcp_usr_shutdown(&tp));
263 }
264 
Abort(void)265 Error Tcp::Endpoint::Abort(void)
266 {
267     struct tcpcb &tp = GetTcb();
268 
269     tcp_usr_abort(&tp);
270     /* connection_lost will do any reinitialization work for this socket. */
271     return kErrorNone;
272 }
273 
Deinitialize(void)274 Error Tcp::Endpoint::Deinitialize(void)
275 {
276     Error error;
277 
278     SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
279     SetNext(nullptr);
280 
281     SuccessOrExit(error = Abort());
282 
283 exit:
284     return error;
285 }
286 
IsClosed(void) const287 bool Tcp::Endpoint::IsClosed(void) const { return GetTcb().t_state == TCP6S_CLOSED; }
288 
TimerFlagToIndex(uint8_t aTimerFlag)289 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
290 {
291     uint8_t timerIndex = 0;
292 
293     switch (aTimerFlag)
294     {
295     case TT_DELACK:
296         timerIndex = kTimerDelack;
297         break;
298     case TT_REXMT:
299     case TT_PERSIST:
300         timerIndex = kTimerRexmtPersist;
301         break;
302     case TT_KEEP:
303         timerIndex = kTimerKeep;
304         break;
305     case TT_2MSL:
306         timerIndex = kTimer2Msl;
307         break;
308     }
309 
310     return timerIndex;
311 }
312 
IsTimerActive(uint8_t aTimerIndex)313 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
314 {
315     bool          active = false;
316     struct tcpcb *tp     = &GetTcb();
317 
318     OT_ASSERT(aTimerIndex < kNumTimers);
319     switch (aTimerIndex)
320     {
321     case kTimerDelack:
322         active = tcp_timer_active(tp, TT_DELACK);
323         break;
324     case kTimerRexmtPersist:
325         active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
326         break;
327     case kTimerKeep:
328         active = tcp_timer_active(tp, TT_KEEP);
329         break;
330     case kTimer2Msl:
331         active = tcp_timer_active(tp, TT_2MSL);
332         break;
333     }
334 
335     return active;
336 }
337 
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)338 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
339 {
340     /*
341      * TCPlp has already set the flag for this timer to record that it's
342      * running. So, all that's left to do is record the expiry time and
343      * (re)set the main timer as appropriate.
344      */
345 
346     TimeMilli now         = TimerMilli::GetNow();
347     TimeMilli newFireTime = now + aDelay;
348     uint8_t   timerIndex  = TimerFlagToIndex(aTimerFlag);
349 
350     mTimers[timerIndex] = newFireTime.GetValue();
351     LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
352             static_cast<unsigned int>(aDelay));
353 
354     Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
355 }
356 
CancelTimer(uint8_t aTimerFlag)357 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
358 {
359     /*
360      * TCPlp has already cleared the timer flag before calling this. Since the
361      * main timer's callback properly handles the case where no timers are
362      * actually due, there's actually no work to be done here.
363      */
364 
365     OT_UNUSED_VARIABLE(aTimerFlag);
366 
367     LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
368             static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
369 }
370 
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)371 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
372 {
373     bool          calledUserCallback = false;
374     struct tcpcb *tp                 = &GetTcb();
375 
376     /*
377      * NOTE: Firing a timer might potentially activate/deactivate other timers.
378      * If timers x and y expire at the same time, but the callback for timer x
379      * (for x < y) cancels or postpones timer y, should timer y's callback be
380      * called? Our answer is no, since timer x's callback has updated the
381      * TCP stack's state in such a way that it no longer expects timer y's
382      * callback to to be called. Because the TCP stack thinks that timer y
383      * has been cancelled, calling timer y's callback could potentially cause
384      * problems.
385      *
386      * If the timer callbacks set other timers, then they may not be taken
387      * into account when setting aEarliestFutureExpiry. But mTimer's expiry
388      * time will be updated by those, so we can just compare against mTimer's
389      * expiry time when resetting mTimer.
390      */
391     for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
392     {
393         if (IsTimerActive(timerIndex))
394         {
395             TimeMilli expiry(mTimers[timerIndex]);
396 
397             if (expiry <= aNow)
398             {
399                 /*
400                  * If a user callback is called, then return true. For TCPlp,
401                  * this only happens if the connection is dropped (e.g., it
402                  * times out).
403                  */
404                 int dropped;
405 
406                 switch (timerIndex)
407                 {
408                 case kTimerDelack:
409                     dropped = tcp_timer_delack(tp);
410                     break;
411                 case kTimerRexmtPersist:
412                     if (tcp_timer_active(tp, TT_REXMT))
413                     {
414                         dropped = tcp_timer_rexmt(tp);
415                     }
416                     else
417                     {
418                         dropped = tcp_timer_persist(tp);
419                     }
420                     break;
421                 case kTimerKeep:
422                     dropped = tcp_timer_keep(tp);
423                     break;
424                 case kTimer2Msl:
425                     dropped = tcp_timer_2msl(tp);
426                     break;
427                 }
428                 VerifyOrExit(dropped == 0, calledUserCallback = true);
429             }
430             else
431             {
432                 aHasFutureTimer       = true;
433                 aEarliestFutureExpiry = Min(aEarliestFutureExpiry, expiry);
434             }
435         }
436     }
437 
438 exit:
439     return calledUserCallback;
440 }
441 
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)442 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
443 {
444     size_t backlogAfter = GetBacklogBytes();
445 
446     if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
447     {
448         mPendingCallbacks |= kForwardProgressCallbackFlag;
449         Get<Tcp>().mTasklet.Post();
450     }
451 }
452 
FirePendingCallbacks(void)453 bool Tcp::Endpoint::FirePendingCallbacks(void)
454 {
455     bool calledUserCallback = false;
456 
457     if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
458     {
459         mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
460         calledUserCallback = true;
461     }
462 
463     mPendingCallbacks = 0;
464 
465     return calledUserCallback;
466 }
467 
GetSendBufferBytes(void) const468 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
469 {
470     const struct tcpcb &tp = GetTcb();
471     return lbuf_used_space(&tp.sendbuf);
472 }
473 
GetInFlightBytes(void) const474 size_t Tcp::Endpoint::GetInFlightBytes(void) const
475 {
476     const struct tcpcb &tp = GetTcb();
477     return tp.snd_max - tp.snd_una;
478 }
479 
GetBacklogBytes(void) const480 size_t Tcp::Endpoint::GetBacklogBytes(void) const { return GetSendBufferBytes() - GetInFlightBytes(); }
481 
GetLocalIp6Address(void)482 Address &Tcp::Endpoint::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().laddr); }
483 
GetLocalIp6Address(void) const484 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
485 {
486     return *reinterpret_cast<const Address *>(&GetTcb().laddr);
487 }
488 
GetForeignIp6Address(void)489 Address &Tcp::Endpoint::GetForeignIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().faddr); }
490 
GetForeignIp6Address(void) const491 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
492 {
493     return *reinterpret_cast<const Address *>(&GetTcb().faddr);
494 }
495 
Matches(const MessageInfo & aMessageInfo) const496 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
497 {
498     bool                matches = false;
499     const struct tcpcb *tp      = &GetTcb();
500 
501     VerifyOrExit(tp->t_state != TCP6S_CLOSED);
502     VerifyOrExit(tp->lport == HostSwap16(aMessageInfo.GetSockPort()));
503     VerifyOrExit(tp->fport == HostSwap16(aMessageInfo.GetPeerPort()));
504     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
505     VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
506 
507     matches = true;
508 
509 exit:
510     return matches;
511 }
512 
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)513 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
514 {
515     Error                error;
516     struct tcpcb_listen *tpl = &GetTcbListen();
517 
518     SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
519 
520     mContext             = aArgs.mContext;
521     mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
522     mAcceptDoneCallback  = aArgs.mAcceptDoneCallback;
523 
524     memset(tpl, 0x00, sizeof(struct tcpcb_listen));
525     tpl->instance = &aInstance;
526 
527 exit:
528     return error;
529 }
530 
GetInstance(void) const531 Instance &Tcp::Listener::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcbListen().instance)); }
532 
Listen(const SockAddr & aSockName)533 Error Tcp::Listener::Listen(const SockAddr &aSockName)
534 {
535     Error                error;
536     uint16_t             port = HostSwap16(aSockName.mPort);
537     struct tcpcb_listen *tpl  = &GetTcbListen();
538 
539     VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
540 
541     memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
542     tpl->lport   = port;
543     tpl->t_state = TCP6S_LISTEN;
544     error        = kErrorNone;
545 
546 exit:
547     return error;
548 }
549 
StopListening(void)550 Error Tcp::Listener::StopListening(void)
551 {
552     struct tcpcb_listen *tpl = &GetTcbListen();
553 
554     memset(&tpl->laddr, 0x00, sizeof(tpl->laddr));
555     tpl->lport   = 0;
556     tpl->t_state = TCP6S_CLOSED;
557     return kErrorNone;
558 }
559 
Deinitialize(void)560 Error Tcp::Listener::Deinitialize(void)
561 {
562     Error error;
563 
564     SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
565     SetNext(nullptr);
566 
567 exit:
568     return error;
569 }
570 
IsClosed(void) const571 bool Tcp::Listener::IsClosed(void) const { return GetTcbListen().t_state == TCP6S_CLOSED; }
572 
GetLocalIp6Address(void)573 Address &Tcp::Listener::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcbListen().laddr); }
574 
GetLocalIp6Address(void) const575 const Address &Tcp::Listener::GetLocalIp6Address(void) const
576 {
577     return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
578 }
579 
Matches(const MessageInfo & aMessageInfo) const580 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
581 {
582     bool                       matches = false;
583     const struct tcpcb_listen *tpl     = &GetTcbListen();
584 
585     VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
586     VerifyOrExit(tpl->lport == HostSwap16(aMessageInfo.GetSockPort()));
587     VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
588 
589     matches = true;
590 
591 exit:
592     return matches;
593 }
594 
/**
 * Processes a received TCP segment.
 *
 * Steps:
 *  1. Check the payload length, TCP header size and checksum.
 *  2. Copy the TCP header into an aligned local buffer.
 *  3. Fill the ports into @p aMessageInfo.
 *  4. Dispatch the segment: first to a matching endpoint, otherwise to a
 *     matching listener, otherwise answer with a reset (tcp_dropwithreset,
 *     ECONNREFUSED).
 *
 * @retval kErrorParse  The length or header-size checks failed.
 * @returns The error from reading the message or verifying the checksum if
 *          either fails; otherwise kErrorNone. The initial kErrorNotImplemented
 *          is overwritten on every path before returning.
 */
Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
{
    Error error = kErrorNotImplemented;

    /*
     * The type uint32_t was chosen for alignment purposes. The size is the
     * maximum TCP header size, including options.
     */
    uint32_t header[15];

    uint16_t length = aIp6Header.GetPayloadLength();
    uint8_t  headerSize;

    struct ip6_hdr *ip6Header;
    struct tcphdr  *tcpHeader;

    Endpoint *endpoint;
    Endpoint *endpointPrev; // Required out-parameter of FindMatching; not used here.

    Listener *listener;
    Listener *listenerPrev; // Required out-parameter of FindMatching; not used here.

    // The IPv6 payload length must agree with the bytes remaining in the message.
    VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
    VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
    // Read the byte holding the data-offset field; it counts 32-bit words, hence "<< 2" to get bytes.
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
    headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
    // The header must be at least the fixed size, fit the local buffer, and fit in the payload.
    VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
                     static_cast<uint16_t>(headerSize) <= length,
                 error = kErrorParse);
    SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));

    ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
    tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
    tcp_fields_to_host(tcpHeader);

    // NOTE(review): the ports are presumably still in network byte order after
    // tcp_fields_to_host (BSD convention), so HostSwap16 yields host order here — confirm.
    aMessageInfo.mPeerPort = HostSwap16(tcpHeader->th_sport);
    aMessageInfo.mSockPort = HostSwap16(tcpHeader->th_dport);

    endpoint = mEndpoints.FindMatching(aMessageInfo, endpointPrev);
    if (endpoint != nullptr)
    {
        struct tcplp_signals sig;
        int                  nextAction;
        struct tcpcb        *tp = &endpoint->GetTcb();

        // Snapshot send-side state so ProcessSignals can tell which links were popped and whether progress was made.
        otLinkedBuffer *priorHead    = lbuf_head(&tp->sendbuf);
        size_t          priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();

        memset(&sig, 0x00, sizeof(sig));
        nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
        if (nextAction != RELOOKUP_REQUIRED)
        {
            ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
            ExitNow();
        }
        /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
    }

    listener = mListeners.FindMatching(aMessageInfo, listenerPrev);
    if (listener != nullptr)
    {
        struct tcpcb_listen *tpl = &listener->GetTcbListen();

        tcp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, nullptr);
        ExitNow();
    }

    // No endpoint or listener wants this segment: reply with a reset.
    tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
                      ECONNREFUSED);

exit:
    return error;
}
669 
ProcessSignals(Endpoint & aEndpoint,otLinkedBuffer * aPriorHead,size_t aPriorBacklog,struct tcplp_signals & aSignals) const670 void Tcp::ProcessSignals(Endpoint             &aEndpoint,
671                          otLinkedBuffer       *aPriorHead,
672                          size_t                aPriorBacklog,
673                          struct tcplp_signals &aSignals) const
674 {
675     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
676     if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
677     {
678         aEndpoint.mEstablishedCallback(&aEndpoint);
679     }
680 
681     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
682     if (aEndpoint.mSendDoneCallback != nullptr)
683     {
684         otLinkedBuffer *curr = aPriorHead;
685 
686         for (uint32_t i = 0; i != aSignals.links_popped; i++)
687         {
688             otLinkedBuffer *next = curr->mNext;
689 
690             VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));
691 
692             curr->mNext = nullptr;
693             aEndpoint.mSendDoneCallback(&aEndpoint, curr);
694             curr = next;
695         }
696     }
697 
698     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
699     if (aEndpoint.mForwardProgressCallback != nullptr)
700     {
701         size_t backlogBytes = aEndpoint.GetBacklogBytes();
702 
703         if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
704         {
705             aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
706             aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
707         }
708     }
709 
710     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
711     if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
712     {
713         aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
714                                             aEndpoint.GetTcb().reass_fin_index != -1,
715                                             cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
716     }
717 
718     VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
719     if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
720     {
721         aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
722     }
723 
724 exit:
725     return;
726 }
727 
BsdErrorToOtError(int aBsdError)728 Error Tcp::BsdErrorToOtError(int aBsdError)
729 {
730     Error error = kErrorFailed;
731 
732     switch (aBsdError)
733     {
734     case 0:
735         error = kErrorNone;
736         break;
737     }
738 
739     return error;
740 }
741 
CanBind(const SockAddr & aSockName)742 bool Tcp::CanBind(const SockAddr &aSockName)
743 {
744     uint16_t port    = HostSwap16(aSockName.mPort);
745     bool     allowed = false;
746 
747     for (Endpoint &endpoint : mEndpoints)
748     {
749         struct tcpcb *tp = &endpoint.GetTcb();
750 
751         if (tp->lport == port)
752         {
753             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
754             VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
755             VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
756         }
757     }
758 
759     for (Listener &listener : mListeners)
760     {
761         struct tcpcb_listen *tpl = &listener.GetTcbListen();
762 
763         if (tpl->lport == port)
764         {
765             VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
766             VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
767             VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
768         }
769     }
770 
771     allowed = true;
772 
773 exit:
774     return allowed;
775 }
776 
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)777 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
778 {
779     bool success;
780 
781     if (aBindAddress)
782     {
783         const Address *source;
784 
785         source = Get<Ip6>().SelectSourceAddress(aPeer.GetAddress());
786         VerifyOrExit(source != nullptr, success = false);
787         aToBind.SetAddress(*source);
788     }
789 
790     if (aBindPort)
791     {
792         /*
793          * TODO: Use a less naive algorithm to allocate ephemeral ports. For
794          * example, see RFC 6056.
795          */
796 
797         for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
798         {
799             aToBind.SetPort(mEphemeralPort);
800 
801             if (mEphemeralPort == kDynamicPortMax)
802             {
803                 mEphemeralPort = kDynamicPortMin;
804             }
805             else
806             {
807                 mEphemeralPort++;
808             }
809 
810             if (CanBind(aToBind))
811             {
812                 ExitNow(success = true);
813             }
814         }
815 
816         ExitNow(success = false);
817     }
818 
819     success = CanBind(aToBind);
820 
821 exit:
822     return success;
823 }
824 
HandleTimer(void)825 void Tcp::HandleTimer(void)
826 {
827     TimeMilli now = TimerMilli::GetNow();
828     bool      pendingTimer;
829     TimeMilli earliestPendingTimerExpiry;
830 
831     LogDebg("Main TCP timer expired");
832 
833     /*
834      * The timer callbacks could potentially set/reset/cancel timers.
835      * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
836      * this function to recompute the timer. If they did, we'd have a
837      * re-entrancy problem, where the callbacks called in this function could
838      * wind up re-entering this function in a nested call frame.
839      *
840      * In general, calling this function from Endpoint::SetTimer and
841      * Endpoint::CancelTimer could be inefficient, since those functions are
842      * called multiple times on each received TCP segment. If we want to
843      * prevent the main timer from firing except when an actual TCP timer
844      * expires, a better alternative is to reset the main timer in
845      * HandleMessage, right before processing signals. That would achieve that
846      * objective while avoiding re-entrancy issues altogether.
847      */
848 restart:
849     pendingTimer               = false;
850     earliestPendingTimerExpiry = now.GetDistantFuture();
851 
852     for (Endpoint &endpoint : mEndpoints)
853     {
854         if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
855         {
856             /*
857              * If a non-OpenThread callback is called --- which, in practice,
858              * happens if the connection times out and the user-defined
859              * connection lost callback is called --- then we might have to
860              * start over. The reason is that the user might deinitialize
861              * endpoints, changing the structure of the linked list. For
862              * example, if the user deinitializes both this endpoint and the
863              * next one in the linked list, then we can't continue traversing
864              * the linked list.
865              */
866             goto restart;
867         }
868     }
869 
870     if (pendingTimer)
871     {
872         /*
873          * We need to use Timer::FireAtIfEarlier instead of timer::FireAt
874          * because one of the earlier callbacks might have set TCP timers,
875          * in which case `mTimer` would have been set to the earliest of those
876          * timers.
877          */
878         mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
879         LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
880     }
881     else
882     {
883         LogDebg("Did not reset main TCP timer");
884     }
885 }
886 
ProcessCallbacks(void)887 void Tcp::ProcessCallbacks(void)
888 {
889     for (Endpoint &endpoint : mEndpoints)
890     {
891         if (endpoint.FirePendingCallbacks())
892         {
893             mTasklet.Post();
894             break;
895         }
896     }
897 }
898 
899 } // namespace Ip6
900 } // namespace ot
901 
902 /*
903  * Implement TCPlp system stubs declared in tcplp.h.
904  *
905  * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
908  * will get errors, even if they are in different namespaces. To avoid
909  * confusion, I've put these functions outside of any namespace.
910  */
911 
912 using namespace ot;
913 using namespace ot::Ip6;
914 
915 extern "C" {
916 
tcplp_sys_new_message(otInstance * aInstance)917 otMessage *tcplp_sys_new_message(otInstance *aInstance)
918 {
919     Instance &instance = AsCoreType(aInstance);
920     Message  *message  = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
921 
922     if (message)
923     {
924         message->SetLinkSecurityEnabled(true);
925     }
926 
927     return message;
928 }
929 
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)930 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
931 {
932     OT_UNUSED_VARIABLE(aInstance);
933     Message &message = AsCoreType(aMessage);
934     message.Free();
935 }
936 
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)937 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
938 {
939     Instance    &instance = AsCoreType(aInstance);
940     Message     &message  = AsCoreType(aMessage);
941     MessageInfo &info     = AsCoreType(aMessageInfo);
942 
943     LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
944 
945     IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
946 }
947 
tcplp_sys_get_ticks(void)948 uint32_t tcplp_sys_get_ticks(void) { return TimerMilli::GetNow().GetValue(); }
949 
tcplp_sys_get_millis(void)950 uint32_t tcplp_sys_get_millis(void) { return TimerMilli::GetNow().GetValue(); }
951 
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)952 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
953 {
954     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
955     endpoint.SetTimer(aTimerFlag, aDelay);
956 }
957 
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)958 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
959 {
960     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
961     endpoint.CancelTimer(aTimerFlag);
962 }
963 
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)964 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
965 {
966     Tcp::Listener                &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
967     Tcp                          &tcp      = listener.Get<Tcp>();
968     struct tcpcb                 *rv       = (struct tcpcb *)-1;
969     otSockAddr                    addr;
970     otTcpEndpoint                *endpointPtr;
971     otTcpIncomingConnectionAction action;
972 
973     VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
974 
975     memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
976     addr.mPort = HostSwap16(aPort);
977     action     = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
978 
979     VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
980 
981     switch (action)
982     {
983     case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
984     {
985         Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
986 
987         /*
988          * The documentation says that the user must initialize the
989          * endpoint before passing it here, so we do a sanity check to make
990          * sure the endpoint is initialized and closed. That check may not
991          * be necessary, but we do it anyway.
992          */
993         VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
994 
995         rv = &endpoint.GetTcb();
996 
997         break;
998     }
999     case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1000         rv = nullptr;
1001         break;
1002     case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1003         rv = (struct tcpcb *)-1;
1004         break;
1005     }
1006 
1007 exit:
1008     return rv;
1009 }
1010 
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1011 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1012                                    struct tcpcb        *aAccepted,
1013                                    struct in6_addr     *aAddr,
1014                                    uint16_t             aPort)
1015 {
1016     Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1017     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1018     Tcp           &tcp      = endpoint.Get<Tcp>();
1019     bool           accepted = true;
1020 
1021     if (listener.mAcceptDoneCallback != nullptr)
1022     {
1023         otSockAddr addr;
1024 
1025         memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1026         addr.mPort = HostSwap16(aPort);
1027         listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1028 
1029         if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1030         {
1031             accepted = false;
1032         }
1033     }
1034 
1035     return accepted;
1036 }
1037 
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1038 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1039 {
1040     Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1041 
1042     if (endpoint.mDisconnectedCallback != nullptr)
1043     {
1044         otTcpDisconnectedReason reason;
1045 
1046         switch (aErrNum)
1047         {
1048         case CONN_LOST_NORMAL:
1049             reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1050             break;
1051         case ECONNREFUSED:
1052             reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1053             break;
1054         case ETIMEDOUT:
1055             reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1056             break;
1057         case ECONNRESET:
1058         default:
1059             reason = OT_TCP_DISCONNECTED_REASON_RESET;
1060             break;
1061         }
1062         endpoint.mDisconnectedCallback(&endpoint, reason);
1063     }
1064 }
1065 
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1066 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1067 {
1068     if (aNewState == TCP6S_CLOSED)
1069     {
1070         /* Re-initialize the TCB. */
1071         cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1072         aTcb->accepted_from = nullptr;
1073         initialize_tcb(aTcb);
1074     }
1075     /* Any adaptive changes to the sleep interval would go here. */
1076 }
1077 
tcplp_sys_log(const char * aFormat,...)1078 void tcplp_sys_log(const char *aFormat, ...)
1079 {
1080     char    buffer[128];
1081     va_list args;
1082     va_start(args, aFormat);
1083     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1084     va_end(args);
1085 
1086     LogDebg("%s", buffer);
1087 }
1088 
tcplp_sys_panic(const char * aFormat,...)1089 void tcplp_sys_panic(const char *aFormat, ...)
1090 {
1091     char    buffer[128];
1092     va_list args;
1093     va_start(args, aFormat);
1094     vsnprintf(buffer, sizeof(buffer), aFormat, args);
1095     va_end(args);
1096 
1097     LogCrit("%s", buffer);
1098 
1099     OT_ASSERT(false);
1100 }
1101 
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1102 bool tcplp_sys_autobind(otInstance       *aInstance,
1103                         const otSockAddr *aPeer,
1104                         otSockAddr       *aToBind,
1105                         bool              aBindAddress,
1106                         bool              aBindPort)
1107 {
1108     Instance &instance = AsCoreType(aInstance);
1109 
1110     return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1111                                         aBindAddress, aBindPort);
1112 }
1113 
tcplp_sys_generate_isn()1114 uint32_t tcplp_sys_generate_isn()
1115 {
1116     uint32_t isn;
1117     IgnoreError(Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&isn), sizeof(isn)));
1118     return isn;
1119 }
1120 
tcplp_sys_hostswap16(uint16_t aHostPort)1121 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort) { return HostSwap16(aHostPort); }
1122 
tcplp_sys_hostswap32(uint32_t aHostPort)1123 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort) { return HostSwap32(aHostPort); }
1124 }
1125 
1126 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1127