1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements TCP/IPv6 sockets.
32 */
33
34 #include "openthread-core-config.h"
35
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37
38 #include "tcp6.hpp"
39
40 #include "common/as_core_type.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/error.hpp"
43 #include "common/instance.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/num_utils.hpp"
47 #include "common/random.hpp"
48 #include "net/checksum.hpp"
49 #include "net/ip6.hpp"
50 #include "net/netif.hpp"
51
52 #include "../../third_party/tcplp/tcplp.h"
53
54 namespace ot {
55 namespace Ip6 {
56
57 using ot::Encoding::BigEndian::HostSwap16;
58 using ot::Encoding::BigEndian::HostSwap32;
59
60 RegisterLogModule("Tcp");
61
62 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
63 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
64 "mTcb field in otTcpEndpoint is aligned incorrectly");
65 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
66
67 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
68 "mTcbListen field in otTcpListener is sized incorrectly");
69 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
70 "mTcbListen field in otTcpListener is aligned incorrectly");
71 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
72
Tcp(Instance & aInstance)73 Tcp::Tcp(Instance &aInstance)
74 : InstanceLocator(aInstance)
75 , mTimer(aInstance)
76 , mTasklet(aInstance)
77 , mEphemeralPort(kDynamicPortMin)
78 {
79 OT_UNUSED_VARIABLE(mEphemeralPort);
80 }
81
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)82 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
83 {
84 Error error;
85 struct tcpcb &tp = GetTcb();
86
87 memset(&tp, 0x00, sizeof(tp));
88
89 SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
90
91 mContext = aArgs.mContext;
92 mEstablishedCallback = aArgs.mEstablishedCallback;
93 mSendDoneCallback = aArgs.mSendDoneCallback;
94 mForwardProgressCallback = aArgs.mForwardProgressCallback;
95 mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
96 mDisconnectedCallback = aArgs.mDisconnectedCallback;
97
98 memset(mTimers, 0x00, sizeof(mTimers));
99 memset(&mSockAddr, 0x00, sizeof(mSockAddr));
100 mPendingCallbacks = 0;
101
102 /*
103 * Initialize buffers --- formerly in initialize_tcb.
104 */
105 {
106 uint8_t *recvbuf = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
107 size_t recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
108 uint8_t *reassbmp = recvbuf + recvbuflen;
109
110 lbuf_init(&tp.sendbuf);
111 cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
112 tp.reassbmp = reassbmp;
113 bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
114 }
115
116 tp.accepted_from = nullptr;
117 initialize_tcb(&tp);
118
119 /* Note that we do not need to zero-initialize mReceiveLinks. */
120
121 tp.instance = &aInstance;
122
123 exit:
124 return error;
125 }
126
GetInstance(void) const127 Instance &Tcp::Endpoint::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcb().instance)); }
128
GetLocalAddress(void) const129 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
130 {
131 const struct tcpcb &tp = GetTcb();
132
133 static otSockAddr temp;
134
135 memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
136 temp.mPort = HostSwap16(tp.lport);
137
138 return AsCoreType(&temp);
139 }
140
GetPeerAddress(void) const141 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
142 {
143 const struct tcpcb &tp = GetTcb();
144
145 static otSockAddr temp;
146
147 memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
148 temp.mPort = HostSwap16(tp.fport);
149
150 return AsCoreType(&temp);
151 }
152
Bind(const SockAddr & aSockName)153 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
154 {
155 Error error;
156 struct tcpcb &tp = GetTcb();
157
158 VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
159 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
160
161 memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
162 tp.lport = HostSwap16(aSockName.mPort);
163 error = kErrorNone;
164
165 exit:
166 return error;
167 }
168
Connect(const SockAddr & aSockName,uint32_t aFlags)169 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
170 {
171 Error error = kErrorNone;
172 struct tcpcb &tp = GetTcb();
173
174 VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
175
176 if (aFlags & OT_TCP_CONNECT_NO_FAST_OPEN)
177 {
178 struct sockaddr_in6 sin6p;
179
180 tp.t_flags &= ~TF_FASTOPEN;
181 memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
182 sin6p.sin6_port = HostSwap16(aSockName.mPort);
183 error = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
184 }
185 else
186 {
187 tp.t_flags |= TF_FASTOPEN;
188
189 /* Stash the destination address in tp. */
190 memcpy(&tp.faddr, &aSockName.mAddress, sizeof(tp.faddr));
191 tp.fport = HostSwap16(aSockName.mPort);
192 }
193
194 exit:
195 return error;
196 }
197
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)198 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
199 {
200 Error error;
201 struct tcpcb &tp = GetTcb();
202
203 size_t backlogBefore = GetBacklogBytes();
204 size_t sent = aBuffer.mLength;
205
206 struct sockaddr_in6 sin6p;
207 struct sockaddr_in6 *name = nullptr;
208
209 if (IS_FASTOPEN(tp.t_flags))
210 {
211 memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
212 sin6p.sin6_port = tp.fport;
213 name = &sin6p;
214 }
215 SuccessOrExit(
216 error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0, name)));
217
218 PostCallbacksAfterSend(sent, backlogBefore);
219
220 exit:
221 return error;
222 }
223
SendByExtension(size_t aNumBytes,uint32_t aFlags)224 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
225 {
226 Error error;
227 bool moreToCome = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
228 struct tcpcb &tp = GetTcb();
229 size_t backlogBefore = GetBacklogBytes();
230 int bsdError;
231
232 struct sockaddr_in6 sin6p;
233 struct sockaddr_in6 *name = nullptr;
234
235 VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
236
237 if (IS_FASTOPEN(tp.t_flags))
238 {
239 memcpy(&sin6p.sin6_addr, &tp.faddr, sizeof(sin6p.sin6_addr));
240 sin6p.sin6_port = tp.fport;
241 name = &sin6p;
242 }
243
244 bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes, name);
245 SuccessOrExit(error = BsdErrorToOtError(bsdError));
246
247 PostCallbacksAfterSend(aNumBytes, backlogBefore);
248
249 exit:
250 return error;
251 }
252
ReceiveByReference(const otLinkedBuffer * & aBuffer)253 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
254 {
255 struct tcpcb &tp = GetTcb();
256
257 cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
258 aBuffer = &mReceiveLinks[0];
259
260 return kErrorNone;
261 }
262
ReceiveContiguify(void)263 Error Tcp::Endpoint::ReceiveContiguify(void)
264 {
265 struct tcpcb &tp = GetTcb();
266
267 cbuf_contiguify(&tp.recvbuf, tp.reassbmp);
268
269 return kErrorNone;
270 }
271
CommitReceive(size_t aNumBytes,uint32_t aFlags)272 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
273 {
274 Error error = kErrorNone;
275 struct tcpcb &tp = GetTcb();
276
277 OT_UNUSED_VARIABLE(aFlags);
278
279 VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
280 VerifyOrExit(aNumBytes > 0, error = kErrorNone);
281
282 cbuf_pop(&tp.recvbuf, aNumBytes);
283 error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
284
285 exit:
286 return error;
287 }
288
SendEndOfStream(void)289 Error Tcp::Endpoint::SendEndOfStream(void)
290 {
291 struct tcpcb &tp = GetTcb();
292
293 return BsdErrorToOtError(tcp_usr_shutdown(&tp));
294 }
295
Abort(void)296 Error Tcp::Endpoint::Abort(void)
297 {
298 struct tcpcb &tp = GetTcb();
299
300 tcp_usr_abort(&tp);
301 /* connection_lost will do any reinitialization work for this socket. */
302 return kErrorNone;
303 }
304
Deinitialize(void)305 Error Tcp::Endpoint::Deinitialize(void)
306 {
307 Error error;
308
309 SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
310 SetNext(nullptr);
311
312 SuccessOrExit(error = Abort());
313
314 exit:
315 return error;
316 }
317
IsClosed(void) const318 bool Tcp::Endpoint::IsClosed(void) const { return GetTcb().t_state == TCP6S_CLOSED; }
319
TimerFlagToIndex(uint8_t aTimerFlag)320 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
321 {
322 uint8_t timerIndex = 0;
323
324 switch (aTimerFlag)
325 {
326 case TT_DELACK:
327 timerIndex = kTimerDelack;
328 break;
329 case TT_REXMT:
330 case TT_PERSIST:
331 timerIndex = kTimerRexmtPersist;
332 break;
333 case TT_KEEP:
334 timerIndex = kTimerKeep;
335 break;
336 case TT_2MSL:
337 timerIndex = kTimer2Msl;
338 break;
339 }
340
341 return timerIndex;
342 }
343
IsTimerActive(uint8_t aTimerIndex)344 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
345 {
346 bool active = false;
347 struct tcpcb *tp = &GetTcb();
348
349 OT_ASSERT(aTimerIndex < kNumTimers);
350 switch (aTimerIndex)
351 {
352 case kTimerDelack:
353 active = tcp_timer_active(tp, TT_DELACK);
354 break;
355 case kTimerRexmtPersist:
356 active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
357 break;
358 case kTimerKeep:
359 active = tcp_timer_active(tp, TT_KEEP);
360 break;
361 case kTimer2Msl:
362 active = tcp_timer_active(tp, TT_2MSL);
363 break;
364 }
365
366 return active;
367 }
368
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)369 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
370 {
371 /*
372 * TCPlp has already set the flag for this timer to record that it's
373 * running. So, all that's left to do is record the expiry time and
374 * (re)set the main timer as appropriate.
375 */
376
377 TimeMilli now = TimerMilli::GetNow();
378 TimeMilli newFireTime = now + aDelay;
379 uint8_t timerIndex = TimerFlagToIndex(aTimerFlag);
380
381 mTimers[timerIndex] = newFireTime.GetValue();
382 LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
383 static_cast<unsigned int>(aDelay));
384
385 Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
386 }
387
CancelTimer(uint8_t aTimerFlag)388 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
389 {
390 /*
391 * TCPlp has already cleared the timer flag before calling this. Since the
392 * main timer's callback properly handles the case where no timers are
393 * actually due, there's actually no work to be done here.
394 */
395
396 OT_UNUSED_VARIABLE(aTimerFlag);
397
398 LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
399 static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
400 }
401
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)402 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
403 {
404 bool calledUserCallback = false;
405 struct tcpcb *tp = &GetTcb();
406
407 /*
408 * NOTE: Firing a timer might potentially activate/deactivate other timers.
409 * If timers x and y expire at the same time, but the callback for timer x
410 * (for x < y) cancels or postpones timer y, should timer y's callback be
411 * called? Our answer is no, since timer x's callback has updated the
412 * TCP stack's state in such a way that it no longer expects timer y's
413 * callback to to be called. Because the TCP stack thinks that timer y
414 * has been cancelled, calling timer y's callback could potentially cause
415 * problems.
416 *
417 * If the timer callbacks set other timers, then they may not be taken
418 * into account when setting aEarliestFutureExpiry. But mTimer's expiry
419 * time will be updated by those, so we can just compare against mTimer's
420 * expiry time when resetting mTimer.
421 */
422 for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
423 {
424 if (IsTimerActive(timerIndex))
425 {
426 TimeMilli expiry(mTimers[timerIndex]);
427
428 if (expiry <= aNow)
429 {
430 /*
431 * If a user callback is called, then return true. For TCPlp,
432 * this only happens if the connection is dropped (e.g., it
433 * times out).
434 */
435 int dropped;
436
437 switch (timerIndex)
438 {
439 case kTimerDelack:
440 dropped = tcp_timer_delack(tp);
441 break;
442 case kTimerRexmtPersist:
443 if (tcp_timer_active(tp, TT_REXMT))
444 {
445 dropped = tcp_timer_rexmt(tp);
446 }
447 else
448 {
449 dropped = tcp_timer_persist(tp);
450 }
451 break;
452 case kTimerKeep:
453 dropped = tcp_timer_keep(tp);
454 break;
455 case kTimer2Msl:
456 dropped = tcp_timer_2msl(tp);
457 break;
458 }
459 VerifyOrExit(dropped == 0, calledUserCallback = true);
460 }
461 else
462 {
463 aHasFutureTimer = true;
464 aEarliestFutureExpiry = Min(aEarliestFutureExpiry, expiry);
465 }
466 }
467 }
468
469 exit:
470 return calledUserCallback;
471 }
472
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)473 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
474 {
475 size_t backlogAfter = GetBacklogBytes();
476
477 if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
478 {
479 mPendingCallbacks |= kForwardProgressCallbackFlag;
480 Get<Tcp>().mTasklet.Post();
481 }
482 }
483
FirePendingCallbacks(void)484 bool Tcp::Endpoint::FirePendingCallbacks(void)
485 {
486 bool calledUserCallback = false;
487
488 if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
489 {
490 mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
491 calledUserCallback = true;
492 }
493
494 mPendingCallbacks = 0;
495
496 return calledUserCallback;
497 }
498
GetSendBufferBytes(void) const499 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
500 {
501 const struct tcpcb &tp = GetTcb();
502 return lbuf_used_space(&tp.sendbuf);
503 }
504
GetInFlightBytes(void) const505 size_t Tcp::Endpoint::GetInFlightBytes(void) const
506 {
507 const struct tcpcb &tp = GetTcb();
508 return tp.snd_max - tp.snd_una;
509 }
510
GetBacklogBytes(void) const511 size_t Tcp::Endpoint::GetBacklogBytes(void) const { return GetSendBufferBytes() - GetInFlightBytes(); }
512
GetLocalIp6Address(void)513 Address &Tcp::Endpoint::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().laddr); }
514
GetLocalIp6Address(void) const515 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
516 {
517 return *reinterpret_cast<const Address *>(&GetTcb().laddr);
518 }
519
GetForeignIp6Address(void)520 Address &Tcp::Endpoint::GetForeignIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().faddr); }
521
GetForeignIp6Address(void) const522 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
523 {
524 return *reinterpret_cast<const Address *>(&GetTcb().faddr);
525 }
526
Matches(const MessageInfo & aMessageInfo) const527 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
528 {
529 bool matches = false;
530 const struct tcpcb *tp = &GetTcb();
531
532 VerifyOrExit(tp->t_state != TCP6S_CLOSED);
533 VerifyOrExit(tp->lport == HostSwap16(aMessageInfo.GetSockPort()));
534 VerifyOrExit(tp->fport == HostSwap16(aMessageInfo.GetPeerPort()));
535 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
536 VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
537
538 matches = true;
539
540 exit:
541 return matches;
542 }
543
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)544 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
545 {
546 Error error;
547 struct tcpcb_listen *tpl = &GetTcbListen();
548
549 SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
550
551 mContext = aArgs.mContext;
552 mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
553 mAcceptDoneCallback = aArgs.mAcceptDoneCallback;
554
555 memset(tpl, 0x00, sizeof(struct tcpcb_listen));
556 tpl->instance = &aInstance;
557
558 exit:
559 return error;
560 }
561
GetInstance(void) const562 Instance &Tcp::Listener::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcbListen().instance)); }
563
Listen(const SockAddr & aSockName)564 Error Tcp::Listener::Listen(const SockAddr &aSockName)
565 {
566 Error error;
567 uint16_t port = HostSwap16(aSockName.mPort);
568 struct tcpcb_listen *tpl = &GetTcbListen();
569
570 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
571
572 memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
573 tpl->lport = port;
574 tpl->t_state = TCP6S_LISTEN;
575 error = kErrorNone;
576
577 exit:
578 return error;
579 }
580
StopListening(void)581 Error Tcp::Listener::StopListening(void)
582 {
583 struct tcpcb_listen *tpl = &GetTcbListen();
584
585 memset(&tpl->laddr, 0x00, sizeof(tpl->laddr));
586 tpl->lport = 0;
587 tpl->t_state = TCP6S_CLOSED;
588 return kErrorNone;
589 }
590
Deinitialize(void)591 Error Tcp::Listener::Deinitialize(void)
592 {
593 Error error;
594
595 SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
596 SetNext(nullptr);
597
598 exit:
599 return error;
600 }
601
IsClosed(void) const602 bool Tcp::Listener::IsClosed(void) const { return GetTcbListen().t_state == TCP6S_CLOSED; }
603
GetLocalIp6Address(void)604 Address &Tcp::Listener::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcbListen().laddr); }
605
GetLocalIp6Address(void) const606 const Address &Tcp::Listener::GetLocalIp6Address(void) const
607 {
608 return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
609 }
610
Matches(const MessageInfo & aMessageInfo) const611 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
612 {
613 bool matches = false;
614 const struct tcpcb_listen *tpl = &GetTcbListen();
615
616 VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
617 VerifyOrExit(tpl->lport == HostSwap16(aMessageInfo.GetSockPort()));
618 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
619
620 matches = true;
621
622 exit:
623 return matches;
624 }
625
HandleMessage(ot::Ip6::Header & aIp6Header,Message & aMessage,MessageInfo & aMessageInfo)626 Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
627 {
628 Error error = kErrorNotImplemented;
629
630 /*
631 * The type uint32_t was chosen for alignment purposes. The size is the
632 * maximum TCP header size, including options.
633 */
634 uint32_t header[15];
635
636 uint16_t length = aIp6Header.GetPayloadLength();
637 uint8_t headerSize;
638
639 struct ip6_hdr *ip6Header;
640 struct tcphdr *tcpHeader;
641
642 Endpoint *endpoint;
643 Endpoint *endpointPrev;
644
645 Listener *listener;
646 Listener *listenerPrev;
647
648 struct tcplp_signals sig;
649 int nextAction;
650
651 VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
652 VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
653 SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
654 headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
655 VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
656 static_cast<uint16_t>(headerSize) <= length,
657 error = kErrorParse);
658 SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
659 SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));
660
661 ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
662 tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
663 tcp_fields_to_host(tcpHeader);
664
665 aMessageInfo.mPeerPort = HostSwap16(tcpHeader->th_sport);
666 aMessageInfo.mSockPort = HostSwap16(tcpHeader->th_dport);
667
668 endpoint = mEndpoints.FindMatching(aMessageInfo, endpointPrev);
669 if (endpoint != nullptr)
670 {
671 struct tcpcb *tp = &endpoint->GetTcb();
672
673 otLinkedBuffer *priorHead = lbuf_head(&tp->sendbuf);
674 size_t priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();
675
676 memset(&sig, 0x00, sizeof(sig));
677 nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
678 if (nextAction != RELOOKUP_REQUIRED)
679 {
680 ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
681 ExitNow();
682 }
683 /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
684 }
685
686 listener = mListeners.FindMatching(aMessageInfo, listenerPrev);
687 if (listener != nullptr)
688 {
689 struct tcpcb_listen *tpl = &listener->GetTcbListen();
690
691 memset(&sig, 0x00, sizeof(sig));
692 nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, &sig);
693 OT_ASSERT(nextAction != RELOOKUP_REQUIRED);
694 if (sig.accepted_connection != nullptr)
695 {
696 ProcessSignals(Tcp::Endpoint::FromTcb(*sig.accepted_connection), nullptr, 0, sig);
697 }
698 ExitNow();
699 }
700
701 tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
702 ECONNREFUSED);
703
704 exit:
705 return error;
706 }
707
ProcessSignals(Endpoint & aEndpoint,otLinkedBuffer * aPriorHead,size_t aPriorBacklog,struct tcplp_signals & aSignals) const708 void Tcp::ProcessSignals(Endpoint &aEndpoint,
709 otLinkedBuffer *aPriorHead,
710 size_t aPriorBacklog,
711 struct tcplp_signals &aSignals) const
712 {
713 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
714 if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
715 {
716 aEndpoint.mEstablishedCallback(&aEndpoint);
717 }
718
719 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
720 if (aEndpoint.mSendDoneCallback != nullptr)
721 {
722 otLinkedBuffer *curr = aPriorHead;
723
724 OT_ASSERT(curr != nullptr || aSignals.links_popped == 0);
725
726 for (uint32_t i = 0; i != aSignals.links_popped; i++)
727 {
728 otLinkedBuffer *next = curr->mNext;
729
730 VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));
731
732 curr->mNext = nullptr;
733 aEndpoint.mSendDoneCallback(&aEndpoint, curr);
734 curr = next;
735 }
736 }
737
738 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
739 if (aEndpoint.mForwardProgressCallback != nullptr)
740 {
741 size_t backlogBytes = aEndpoint.GetBacklogBytes();
742
743 if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
744 {
745 aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
746 aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
747 }
748 }
749
750 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
751 if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
752 {
753 aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
754 aEndpoint.GetTcb().reass_fin_index != -1,
755 cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
756 }
757
758 VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
759 if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
760 {
761 aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
762 }
763
764 exit:
765 return;
766 }
767
BsdErrorToOtError(int aBsdError)768 Error Tcp::BsdErrorToOtError(int aBsdError)
769 {
770 Error error = kErrorFailed;
771
772 switch (aBsdError)
773 {
774 case 0:
775 error = kErrorNone;
776 break;
777 }
778
779 return error;
780 }
781
CanBind(const SockAddr & aSockName)782 bool Tcp::CanBind(const SockAddr &aSockName)
783 {
784 uint16_t port = HostSwap16(aSockName.mPort);
785 bool allowed = false;
786
787 for (Endpoint &endpoint : mEndpoints)
788 {
789 struct tcpcb *tp = &endpoint.GetTcb();
790
791 if (tp->lport == port)
792 {
793 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
794 VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
795 VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
796 }
797 }
798
799 for (Listener &listener : mListeners)
800 {
801 struct tcpcb_listen *tpl = &listener.GetTcbListen();
802
803 if (tpl->lport == port)
804 {
805 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
806 VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
807 VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
808 }
809 }
810
811 allowed = true;
812
813 exit:
814 return allowed;
815 }
816
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)817 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
818 {
819 bool success;
820
821 if (aBindAddress)
822 {
823 const Address *source;
824
825 source = Get<Ip6>().SelectSourceAddress(aPeer.GetAddress());
826 VerifyOrExit(source != nullptr, success = false);
827 aToBind.SetAddress(*source);
828 }
829
830 if (aBindPort)
831 {
832 /*
833 * TODO: Use a less naive algorithm to allocate ephemeral ports. For
834 * example, see RFC 6056.
835 */
836
837 for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
838 {
839 aToBind.SetPort(mEphemeralPort);
840
841 if (mEphemeralPort == kDynamicPortMax)
842 {
843 mEphemeralPort = kDynamicPortMin;
844 }
845 else
846 {
847 mEphemeralPort++;
848 }
849
850 if (CanBind(aToBind))
851 {
852 ExitNow(success = true);
853 }
854 }
855
856 ExitNow(success = false);
857 }
858
859 success = CanBind(aToBind);
860
861 exit:
862 return success;
863 }
864
HandleTimer(void)865 void Tcp::HandleTimer(void)
866 {
867 TimeMilli now = TimerMilli::GetNow();
868 bool pendingTimer;
869 TimeMilli earliestPendingTimerExpiry;
870
871 LogDebg("Main TCP timer expired");
872
873 /*
874 * The timer callbacks could potentially set/reset/cancel timers.
875 * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
876 * this function to recompute the timer. If they did, we'd have a
877 * re-entrancy problem, where the callbacks called in this function could
878 * wind up re-entering this function in a nested call frame.
879 *
880 * In general, calling this function from Endpoint::SetTimer and
881 * Endpoint::CancelTimer could be inefficient, since those functions are
882 * called multiple times on each received TCP segment. If we want to
883 * prevent the main timer from firing except when an actual TCP timer
884 * expires, a better alternative is to reset the main timer in
885 * HandleMessage, right before processing signals. That would achieve that
886 * objective while avoiding re-entrancy issues altogether.
887 */
888 restart:
889 pendingTimer = false;
890 earliestPendingTimerExpiry = now.GetDistantFuture();
891
892 for (Endpoint &endpoint : mEndpoints)
893 {
894 if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
895 {
896 /*
897 * If a non-OpenThread callback is called --- which, in practice,
898 * happens if the connection times out and the user-defined
899 * connection lost callback is called --- then we might have to
900 * start over. The reason is that the user might deinitialize
901 * endpoints, changing the structure of the linked list. For
902 * example, if the user deinitializes both this endpoint and the
903 * next one in the linked list, then we can't continue traversing
904 * the linked list.
905 */
906 goto restart;
907 }
908 }
909
910 if (pendingTimer)
911 {
912 /*
913 * We need to use Timer::FireAtIfEarlier instead of timer::FireAt
914 * because one of the earlier callbacks might have set TCP timers,
915 * in which case `mTimer` would have been set to the earliest of those
916 * timers.
917 */
918 mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
919 LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
920 }
921 else
922 {
923 LogDebg("Did not reset main TCP timer");
924 }
925 }
926
ProcessCallbacks(void)927 void Tcp::ProcessCallbacks(void)
928 {
929 for (Endpoint &endpoint : mEndpoints)
930 {
931 if (endpoint.FirePendingCallbacks())
932 {
933 mTasklet.Post();
934 break;
935 }
936 }
937 }
938
939 } // namespace Ip6
940 } // namespace ot
941
/*
 * Implement TCPlp system stubs declared in tcplp.h.
 *
 * Because these functions have C linkage, each function name must be defined
 * only once, regardless of the namespace the definition is in. For example,
 * giving two definitions of tcplp_sys_new_message causes errors even if they
 * are in different namespaces. To avoid confusion, these functions are defined
 * outside of any namespace.
 */
951
952 using namespace ot;
953 using namespace ot::Ip6;
954
955 extern "C" {
956
tcplp_sys_new_message(otInstance * aInstance)957 otMessage *tcplp_sys_new_message(otInstance *aInstance)
958 {
959 Instance &instance = AsCoreType(aInstance);
960 Message *message = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
961
962 if (message)
963 {
964 message->SetLinkSecurityEnabled(true);
965 }
966
967 return message;
968 }
969
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)970 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
971 {
972 OT_UNUSED_VARIABLE(aInstance);
973 Message &message = AsCoreType(aMessage);
974 message.Free();
975 }
976
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)977 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
978 {
979 Instance &instance = AsCoreType(aInstance);
980 Message &message = AsCoreType(aMessage);
981 MessageInfo &info = AsCoreType(aMessageInfo);
982
983 LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
984
985 IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
986 }
987
tcplp_sys_get_ticks(void)988 uint32_t tcplp_sys_get_ticks(void) { return TimerMilli::GetNow().GetValue(); }
989
tcplp_sys_get_millis(void)990 uint32_t tcplp_sys_get_millis(void) { return TimerMilli::GetNow().GetValue(); }
991
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)992 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
993 {
994 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
995 endpoint.SetTimer(aTimerFlag, aDelay);
996 }
997
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)998 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
999 {
1000 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1001 endpoint.CancelTimer(aTimerFlag);
1002 }
1003
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)1004 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
1005 {
1006 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1007 Tcp &tcp = listener.Get<Tcp>();
1008 struct tcpcb *rv = (struct tcpcb *)-1;
1009 otSockAddr addr;
1010 otTcpEndpoint *endpointPtr;
1011 otTcpIncomingConnectionAction action;
1012
1013 VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
1014
1015 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1016 addr.mPort = HostSwap16(aPort);
1017 action = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
1018
1019 VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
1020
1021 switch (action)
1022 {
1023 case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
1024 {
1025 Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
1026
1027 /*
1028 * The documentation says that the user must initialize the
1029 * endpoint before passing it here, so we do a sanity check to make
1030 * sure the endpoint is initialized and closed. That check may not
1031 * be necessary, but we do it anyway.
1032 */
1033 VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
1034
1035 rv = &endpoint.GetTcb();
1036
1037 break;
1038 }
1039 case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1040 rv = nullptr;
1041 break;
1042 case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1043 rv = (struct tcpcb *)-1;
1044 break;
1045 }
1046
1047 exit:
1048 return rv;
1049 }
1050
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1051 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1052 struct tcpcb *aAccepted,
1053 struct in6_addr *aAddr,
1054 uint16_t aPort)
1055 {
1056 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1057 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1058 Tcp &tcp = endpoint.Get<Tcp>();
1059 bool accepted = true;
1060
1061 if (listener.mAcceptDoneCallback != nullptr)
1062 {
1063 otSockAddr addr;
1064
1065 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1066 addr.mPort = HostSwap16(aPort);
1067 listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1068
1069 if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1070 {
1071 accepted = false;
1072 }
1073 }
1074
1075 return accepted;
1076 }
1077
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1078 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1079 {
1080 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1081
1082 if (endpoint.mDisconnectedCallback != nullptr)
1083 {
1084 otTcpDisconnectedReason reason;
1085
1086 switch (aErrNum)
1087 {
1088 case CONN_LOST_NORMAL:
1089 reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1090 break;
1091 case ECONNREFUSED:
1092 reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1093 break;
1094 case ETIMEDOUT:
1095 reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1096 break;
1097 case ECONNRESET:
1098 default:
1099 reason = OT_TCP_DISCONNECTED_REASON_RESET;
1100 break;
1101 }
1102 endpoint.mDisconnectedCallback(&endpoint, reason);
1103 }
1104 }
1105
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1106 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1107 {
1108 if (aNewState == TCP6S_CLOSED)
1109 {
1110 /* Re-initialize the TCB. */
1111 cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1112 aTcb->accepted_from = nullptr;
1113 initialize_tcb(aTcb);
1114 }
1115 /* Any adaptive changes to the sleep interval would go here. */
1116 }
1117
tcplp_sys_log(const char * aFormat,...)1118 void tcplp_sys_log(const char *aFormat, ...)
1119 {
1120 char buffer[128];
1121 va_list args;
1122 va_start(args, aFormat);
1123 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1124 va_end(args);
1125
1126 LogDebg("%s", buffer);
1127 }
1128
tcplp_sys_panic(const char * aFormat,...)1129 void tcplp_sys_panic(const char *aFormat, ...)
1130 {
1131 char buffer[128];
1132 va_list args;
1133 va_start(args, aFormat);
1134 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1135 va_end(args);
1136
1137 LogCrit("%s", buffer);
1138
1139 OT_ASSERT(false);
1140 }
1141
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1142 bool tcplp_sys_autobind(otInstance *aInstance,
1143 const otSockAddr *aPeer,
1144 otSockAddr *aToBind,
1145 bool aBindAddress,
1146 bool aBindPort)
1147 {
1148 Instance &instance = AsCoreType(aInstance);
1149
1150 return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1151 aBindAddress, aBindPort);
1152 }
1153
tcplp_sys_generate_isn()1154 uint32_t tcplp_sys_generate_isn()
1155 {
1156 uint32_t isn;
1157 IgnoreError(Random::Crypto::Fill(isn));
1158 return isn;
1159 }
1160
tcplp_sys_hostswap16(uint16_t aHostPort)1161 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort) { return HostSwap16(aHostPort); }
1162
tcplp_sys_hostswap32(uint32_t aHostPort)1163 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort) { return HostSwap32(aHostPort); }
1164 }
1165
1166 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1167