1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements TCP/IPv6 sockets.
32 */
33
34 #include "openthread-core-config.h"
35
36 #if OPENTHREAD_CONFIG_TCP_ENABLE
37
38 #include "tcp6.hpp"
39
40 #include "common/as_core_type.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/error.hpp"
43 #include "common/instance.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/num_utils.hpp"
47 #include "common/random.hpp"
48 #include "net/checksum.hpp"
49 #include "net/ip6.hpp"
50 #include "net/netif.hpp"
51
52 #include "../../third_party/tcplp/tcplp.h"
53
54 namespace ot {
55 namespace Ip6 {
56
57 using ot::Encoding::BigEndian::HostSwap16;
58 using ot::Encoding::BigEndian::HostSwap32;
59
60 RegisterLogModule("Tcp");
61
62 static_assert(sizeof(struct tcpcb) == sizeof(Tcp::Endpoint::mTcb), "mTcb field in otTcpEndpoint is sized incorrectly");
63 static_assert(alignof(struct tcpcb) == alignof(decltype(Tcp::Endpoint::mTcb)),
64 "mTcb field in otTcpEndpoint is aligned incorrectly");
65 static_assert(offsetof(Tcp::Endpoint, mTcb) == 0, "mTcb field in otTcpEndpoint has nonzero offset");
66
67 static_assert(sizeof(struct tcpcb_listen) == sizeof(Tcp::Listener::mTcbListen),
68 "mTcbListen field in otTcpListener is sized incorrectly");
69 static_assert(alignof(struct tcpcb_listen) == alignof(decltype(Tcp::Listener::mTcbListen)),
70 "mTcbListen field in otTcpListener is aligned incorrectly");
71 static_assert(offsetof(Tcp::Listener, mTcbListen) == 0, "mTcbListen field in otTcpEndpoint has nonzero offset");
72
/**
 * Constructs the TCP module for @p aInstance.
 *
 * The ephemeral-port cursor used by AutoBind() starts at the bottom of the
 * dynamic port range (kDynamicPortMin).
 */
Tcp::Tcp(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTimer(aInstance)
    , mTasklet(aInstance)
    , mEphemeralPort(kDynamicPortMin)
{
    // NOTE(review): presumably silences unused-member warnings in build
    // configurations where nothing reads mEphemeralPort — confirm.
    OT_UNUSED_VARIABLE(mEphemeralPort);
}
81
Initialize(Instance & aInstance,const otTcpEndpointInitializeArgs & aArgs)82 Error Tcp::Endpoint::Initialize(Instance &aInstance, const otTcpEndpointInitializeArgs &aArgs)
83 {
84 Error error;
85 struct tcpcb &tp = GetTcb();
86
87 memset(&tp, 0x00, sizeof(tp));
88
89 SuccessOrExit(error = aInstance.Get<Tcp>().mEndpoints.Add(*this));
90
91 mContext = aArgs.mContext;
92 mEstablishedCallback = aArgs.mEstablishedCallback;
93 mSendDoneCallback = aArgs.mSendDoneCallback;
94 mForwardProgressCallback = aArgs.mForwardProgressCallback;
95 mReceiveAvailableCallback = aArgs.mReceiveAvailableCallback;
96 mDisconnectedCallback = aArgs.mDisconnectedCallback;
97
98 memset(mTimers, 0x00, sizeof(mTimers));
99 memset(&mSockAddr, 0x00, sizeof(mSockAddr));
100 mPendingCallbacks = 0;
101
102 /*
103 * Initialize buffers --- formerly in initialize_tcb.
104 */
105 {
106 uint8_t *recvbuf = static_cast<uint8_t *>(aArgs.mReceiveBuffer);
107 size_t recvbuflen = aArgs.mReceiveBufferSize - ((aArgs.mReceiveBufferSize + 8) / 9);
108 uint8_t *reassbmp = recvbuf + recvbuflen;
109
110 lbuf_init(&tp.sendbuf);
111 cbuf_init(&tp.recvbuf, recvbuf, recvbuflen);
112 tp.reassbmp = reassbmp;
113 bmp_init(tp.reassbmp, BITS_TO_BYTES(recvbuflen));
114 }
115
116 tp.accepted_from = nullptr;
117 initialize_tcb(&tp);
118
119 /* Note that we do not need to zero-initialize mReceiveLinks. */
120
121 tp.instance = &aInstance;
122
123 exit:
124 return error;
125 }
126
GetInstance(void) const127 Instance &Tcp::Endpoint::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcb().instance)); }
128
GetLocalAddress(void) const129 const SockAddr &Tcp::Endpoint::GetLocalAddress(void) const
130 {
131 const struct tcpcb &tp = GetTcb();
132
133 static otSockAddr temp;
134
135 memcpy(&temp.mAddress, &tp.laddr, sizeof(temp.mAddress));
136 temp.mPort = HostSwap16(tp.lport);
137
138 return AsCoreType(&temp);
139 }
140
GetPeerAddress(void) const141 const SockAddr &Tcp::Endpoint::GetPeerAddress(void) const
142 {
143 const struct tcpcb &tp = GetTcb();
144
145 static otSockAddr temp;
146
147 memcpy(&temp.mAddress, &tp.faddr, sizeof(temp.mAddress));
148 temp.mPort = HostSwap16(tp.fport);
149
150 return AsCoreType(&temp);
151 }
152
Bind(const SockAddr & aSockName)153 Error Tcp::Endpoint::Bind(const SockAddr &aSockName)
154 {
155 Error error;
156 struct tcpcb &tp = GetTcb();
157
158 VerifyOrExit(!AsCoreType(&aSockName.mAddress).IsUnspecified(), error = kErrorInvalidArgs);
159 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
160
161 memcpy(&tp.laddr, &aSockName.mAddress, sizeof(tp.laddr));
162 tp.lport = HostSwap16(aSockName.mPort);
163 error = kErrorNone;
164
165 exit:
166 return error;
167 }
168
Connect(const SockAddr & aSockName,uint32_t aFlags)169 Error Tcp::Endpoint::Connect(const SockAddr &aSockName, uint32_t aFlags)
170 {
171 Error error = kErrorNone;
172 struct tcpcb &tp = GetTcb();
173 struct sockaddr_in6 sin6p;
174
175 OT_UNUSED_VARIABLE(aFlags);
176
177 VerifyOrExit(tp.t_state == TCP6S_CLOSED, error = kErrorInvalidState);
178
179 memcpy(&sin6p.sin6_addr, &aSockName.mAddress, sizeof(sin6p.sin6_addr));
180 sin6p.sin6_port = HostSwap16(aSockName.mPort);
181 error = BsdErrorToOtError(tcp6_usr_connect(&tp, &sin6p));
182
183 exit:
184 return error;
185 }
186
SendByReference(otLinkedBuffer & aBuffer,uint32_t aFlags)187 Error Tcp::Endpoint::SendByReference(otLinkedBuffer &aBuffer, uint32_t aFlags)
188 {
189 Error error;
190 struct tcpcb &tp = GetTcb();
191
192 size_t backlogBefore = GetBacklogBytes();
193 size_t sent = aBuffer.mLength;
194
195 SuccessOrExit(error = BsdErrorToOtError(tcp_usr_send(&tp, (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0, &aBuffer, 0)));
196
197 PostCallbacksAfterSend(sent, backlogBefore);
198
199 exit:
200 return error;
201 }
202
SendByExtension(size_t aNumBytes,uint32_t aFlags)203 Error Tcp::Endpoint::SendByExtension(size_t aNumBytes, uint32_t aFlags)
204 {
205 Error error;
206 bool moreToCome = (aFlags & OT_TCP_SEND_MORE_TO_COME) != 0;
207 struct tcpcb &tp = GetTcb();
208 size_t backlogBefore = GetBacklogBytes();
209 int bsdError;
210
211 VerifyOrExit(lbuf_head(&tp.sendbuf) != nullptr, error = kErrorInvalidState);
212
213 bsdError = tcp_usr_send(&tp, moreToCome ? 1 : 0, nullptr, aNumBytes);
214 SuccessOrExit(error = BsdErrorToOtError(bsdError));
215
216 PostCallbacksAfterSend(aNumBytes, backlogBefore);
217
218 exit:
219 return error;
220 }
221
ReceiveByReference(const otLinkedBuffer * & aBuffer)222 Error Tcp::Endpoint::ReceiveByReference(const otLinkedBuffer *&aBuffer)
223 {
224 struct tcpcb &tp = GetTcb();
225
226 cbuf_reference(&tp.recvbuf, &mReceiveLinks[0], &mReceiveLinks[1]);
227 aBuffer = &mReceiveLinks[0];
228
229 return kErrorNone;
230 }
231
ReceiveContiguify(void)232 Error Tcp::Endpoint::ReceiveContiguify(void)
233 {
234 struct tcpcb &tp = GetTcb();
235
236 cbuf_contiguify(&tp.recvbuf, tp.reassbmp);
237
238 return kErrorNone;
239 }
240
CommitReceive(size_t aNumBytes,uint32_t aFlags)241 Error Tcp::Endpoint::CommitReceive(size_t aNumBytes, uint32_t aFlags)
242 {
243 Error error = kErrorNone;
244 struct tcpcb &tp = GetTcb();
245
246 OT_UNUSED_VARIABLE(aFlags);
247
248 VerifyOrExit(cbuf_used_space(&tp.recvbuf) >= aNumBytes, error = kErrorFailed);
249 VerifyOrExit(aNumBytes > 0, error = kErrorNone);
250
251 cbuf_pop(&tp.recvbuf, aNumBytes);
252 error = BsdErrorToOtError(tcp_usr_rcvd(&tp));
253
254 exit:
255 return error;
256 }
257
SendEndOfStream(void)258 Error Tcp::Endpoint::SendEndOfStream(void)
259 {
260 struct tcpcb &tp = GetTcb();
261
262 return BsdErrorToOtError(tcp_usr_shutdown(&tp));
263 }
264
Abort(void)265 Error Tcp::Endpoint::Abort(void)
266 {
267 struct tcpcb &tp = GetTcb();
268
269 tcp_usr_abort(&tp);
270 /* connection_lost will do any reinitialization work for this socket. */
271 return kErrorNone;
272 }
273
Deinitialize(void)274 Error Tcp::Endpoint::Deinitialize(void)
275 {
276 Error error;
277
278 SuccessOrExit(error = Get<Tcp>().mEndpoints.Remove(*this));
279 SetNext(nullptr);
280
281 SuccessOrExit(error = Abort());
282
283 exit:
284 return error;
285 }
286
IsClosed(void) const287 bool Tcp::Endpoint::IsClosed(void) const { return GetTcb().t_state == TCP6S_CLOSED; }
288
TimerFlagToIndex(uint8_t aTimerFlag)289 uint8_t Tcp::Endpoint::TimerFlagToIndex(uint8_t aTimerFlag)
290 {
291 uint8_t timerIndex = 0;
292
293 switch (aTimerFlag)
294 {
295 case TT_DELACK:
296 timerIndex = kTimerDelack;
297 break;
298 case TT_REXMT:
299 case TT_PERSIST:
300 timerIndex = kTimerRexmtPersist;
301 break;
302 case TT_KEEP:
303 timerIndex = kTimerKeep;
304 break;
305 case TT_2MSL:
306 timerIndex = kTimer2Msl;
307 break;
308 }
309
310 return timerIndex;
311 }
312
IsTimerActive(uint8_t aTimerIndex)313 bool Tcp::Endpoint::IsTimerActive(uint8_t aTimerIndex)
314 {
315 bool active = false;
316 struct tcpcb *tp = &GetTcb();
317
318 OT_ASSERT(aTimerIndex < kNumTimers);
319 switch (aTimerIndex)
320 {
321 case kTimerDelack:
322 active = tcp_timer_active(tp, TT_DELACK);
323 break;
324 case kTimerRexmtPersist:
325 active = tcp_timer_active(tp, TT_REXMT) || tcp_timer_active(tp, TT_PERSIST);
326 break;
327 case kTimerKeep:
328 active = tcp_timer_active(tp, TT_KEEP);
329 break;
330 case kTimer2Msl:
331 active = tcp_timer_active(tp, TT_2MSL);
332 break;
333 }
334
335 return active;
336 }
337
SetTimer(uint8_t aTimerFlag,uint32_t aDelay)338 void Tcp::Endpoint::SetTimer(uint8_t aTimerFlag, uint32_t aDelay)
339 {
340 /*
341 * TCPlp has already set the flag for this timer to record that it's
342 * running. So, all that's left to do is record the expiry time and
343 * (re)set the main timer as appropriate.
344 */
345
346 TimeMilli now = TimerMilli::GetNow();
347 TimeMilli newFireTime = now + aDelay;
348 uint8_t timerIndex = TimerFlagToIndex(aTimerFlag);
349
350 mTimers[timerIndex] = newFireTime.GetValue();
351 LogDebg("Endpoint %p set timer %u to %u ms", static_cast<void *>(this), static_cast<unsigned int>(timerIndex),
352 static_cast<unsigned int>(aDelay));
353
354 Get<Tcp>().mTimer.FireAtIfEarlier(newFireTime);
355 }
356
CancelTimer(uint8_t aTimerFlag)357 void Tcp::Endpoint::CancelTimer(uint8_t aTimerFlag)
358 {
359 /*
360 * TCPlp has already cleared the timer flag before calling this. Since the
361 * main timer's callback properly handles the case where no timers are
362 * actually due, there's actually no work to be done here.
363 */
364
365 OT_UNUSED_VARIABLE(aTimerFlag);
366
367 LogDebg("Endpoint %p cancelled timer %u", static_cast<void *>(this),
368 static_cast<unsigned int>(TimerFlagToIndex(aTimerFlag)));
369 }
370
FirePendingTimers(TimeMilli aNow,bool & aHasFutureTimer,TimeMilli & aEarliestFutureExpiry)371 bool Tcp::Endpoint::FirePendingTimers(TimeMilli aNow, bool &aHasFutureTimer, TimeMilli &aEarliestFutureExpiry)
372 {
373 bool calledUserCallback = false;
374 struct tcpcb *tp = &GetTcb();
375
376 /*
377 * NOTE: Firing a timer might potentially activate/deactivate other timers.
378 * If timers x and y expire at the same time, but the callback for timer x
379 * (for x < y) cancels or postpones timer y, should timer y's callback be
380 * called? Our answer is no, since timer x's callback has updated the
381 * TCP stack's state in such a way that it no longer expects timer y's
382 * callback to to be called. Because the TCP stack thinks that timer y
383 * has been cancelled, calling timer y's callback could potentially cause
384 * problems.
385 *
386 * If the timer callbacks set other timers, then they may not be taken
387 * into account when setting aEarliestFutureExpiry. But mTimer's expiry
388 * time will be updated by those, so we can just compare against mTimer's
389 * expiry time when resetting mTimer.
390 */
391 for (uint8_t timerIndex = 0; timerIndex != kNumTimers; timerIndex++)
392 {
393 if (IsTimerActive(timerIndex))
394 {
395 TimeMilli expiry(mTimers[timerIndex]);
396
397 if (expiry <= aNow)
398 {
399 /*
400 * If a user callback is called, then return true. For TCPlp,
401 * this only happens if the connection is dropped (e.g., it
402 * times out).
403 */
404 int dropped;
405
406 switch (timerIndex)
407 {
408 case kTimerDelack:
409 dropped = tcp_timer_delack(tp);
410 break;
411 case kTimerRexmtPersist:
412 if (tcp_timer_active(tp, TT_REXMT))
413 {
414 dropped = tcp_timer_rexmt(tp);
415 }
416 else
417 {
418 dropped = tcp_timer_persist(tp);
419 }
420 break;
421 case kTimerKeep:
422 dropped = tcp_timer_keep(tp);
423 break;
424 case kTimer2Msl:
425 dropped = tcp_timer_2msl(tp);
426 break;
427 }
428 VerifyOrExit(dropped == 0, calledUserCallback = true);
429 }
430 else
431 {
432 aHasFutureTimer = true;
433 aEarliestFutureExpiry = Min(aEarliestFutureExpiry, expiry);
434 }
435 }
436 }
437
438 exit:
439 return calledUserCallback;
440 }
441
PostCallbacksAfterSend(size_t aSent,size_t aBacklogBefore)442 void Tcp::Endpoint::PostCallbacksAfterSend(size_t aSent, size_t aBacklogBefore)
443 {
444 size_t backlogAfter = GetBacklogBytes();
445
446 if (backlogAfter < aBacklogBefore + aSent && mForwardProgressCallback != nullptr)
447 {
448 mPendingCallbacks |= kForwardProgressCallbackFlag;
449 Get<Tcp>().mTasklet.Post();
450 }
451 }
452
FirePendingCallbacks(void)453 bool Tcp::Endpoint::FirePendingCallbacks(void)
454 {
455 bool calledUserCallback = false;
456
457 if ((mPendingCallbacks & kForwardProgressCallbackFlag) != 0 && mForwardProgressCallback != nullptr)
458 {
459 mForwardProgressCallback(this, GetSendBufferBytes(), GetBacklogBytes());
460 calledUserCallback = true;
461 }
462
463 mPendingCallbacks = 0;
464
465 return calledUserCallback;
466 }
467
GetSendBufferBytes(void) const468 size_t Tcp::Endpoint::GetSendBufferBytes(void) const
469 {
470 const struct tcpcb &tp = GetTcb();
471 return lbuf_used_space(&tp.sendbuf);
472 }
473
GetInFlightBytes(void) const474 size_t Tcp::Endpoint::GetInFlightBytes(void) const
475 {
476 const struct tcpcb &tp = GetTcb();
477 return tp.snd_max - tp.snd_una;
478 }
479
GetBacklogBytes(void) const480 size_t Tcp::Endpoint::GetBacklogBytes(void) const { return GetSendBufferBytes() - GetInFlightBytes(); }
481
GetLocalIp6Address(void)482 Address &Tcp::Endpoint::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().laddr); }
483
GetLocalIp6Address(void) const484 const Address &Tcp::Endpoint::GetLocalIp6Address(void) const
485 {
486 return *reinterpret_cast<const Address *>(&GetTcb().laddr);
487 }
488
GetForeignIp6Address(void)489 Address &Tcp::Endpoint::GetForeignIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcb().faddr); }
490
GetForeignIp6Address(void) const491 const Address &Tcp::Endpoint::GetForeignIp6Address(void) const
492 {
493 return *reinterpret_cast<const Address *>(&GetTcb().faddr);
494 }
495
Matches(const MessageInfo & aMessageInfo) const496 bool Tcp::Endpoint::Matches(const MessageInfo &aMessageInfo) const
497 {
498 bool matches = false;
499 const struct tcpcb *tp = &GetTcb();
500
501 VerifyOrExit(tp->t_state != TCP6S_CLOSED);
502 VerifyOrExit(tp->lport == HostSwap16(aMessageInfo.GetSockPort()));
503 VerifyOrExit(tp->fport == HostSwap16(aMessageInfo.GetPeerPort()));
504 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
505 VerifyOrExit(GetForeignIp6Address() == aMessageInfo.GetPeerAddr());
506
507 matches = true;
508
509 exit:
510 return matches;
511 }
512
Initialize(Instance & aInstance,const otTcpListenerInitializeArgs & aArgs)513 Error Tcp::Listener::Initialize(Instance &aInstance, const otTcpListenerInitializeArgs &aArgs)
514 {
515 Error error;
516 struct tcpcb_listen *tpl = &GetTcbListen();
517
518 SuccessOrExit(error = aInstance.Get<Tcp>().mListeners.Add(*this));
519
520 mContext = aArgs.mContext;
521 mAcceptReadyCallback = aArgs.mAcceptReadyCallback;
522 mAcceptDoneCallback = aArgs.mAcceptDoneCallback;
523
524 memset(tpl, 0x00, sizeof(struct tcpcb_listen));
525 tpl->instance = &aInstance;
526
527 exit:
528 return error;
529 }
530
GetInstance(void) const531 Instance &Tcp::Listener::GetInstance(void) const { return AsNonConst(AsCoreType(GetTcbListen().instance)); }
532
Listen(const SockAddr & aSockName)533 Error Tcp::Listener::Listen(const SockAddr &aSockName)
534 {
535 Error error;
536 uint16_t port = HostSwap16(aSockName.mPort);
537 struct tcpcb_listen *tpl = &GetTcbListen();
538
539 VerifyOrExit(Get<Tcp>().CanBind(aSockName), error = kErrorInvalidState);
540
541 memcpy(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr));
542 tpl->lport = port;
543 tpl->t_state = TCP6S_LISTEN;
544 error = kErrorNone;
545
546 exit:
547 return error;
548 }
549
StopListening(void)550 Error Tcp::Listener::StopListening(void)
551 {
552 struct tcpcb_listen *tpl = &GetTcbListen();
553
554 memset(&tpl->laddr, 0x00, sizeof(tpl->laddr));
555 tpl->lport = 0;
556 tpl->t_state = TCP6S_CLOSED;
557 return kErrorNone;
558 }
559
Deinitialize(void)560 Error Tcp::Listener::Deinitialize(void)
561 {
562 Error error;
563
564 SuccessOrExit(error = Get<Tcp>().mListeners.Remove(*this));
565 SetNext(nullptr);
566
567 exit:
568 return error;
569 }
570
IsClosed(void) const571 bool Tcp::Listener::IsClosed(void) const { return GetTcbListen().t_state == TCP6S_CLOSED; }
572
GetLocalIp6Address(void)573 Address &Tcp::Listener::GetLocalIp6Address(void) { return *reinterpret_cast<Address *>(&GetTcbListen().laddr); }
574
GetLocalIp6Address(void) const575 const Address &Tcp::Listener::GetLocalIp6Address(void) const
576 {
577 return *reinterpret_cast<const Address *>(&GetTcbListen().laddr);
578 }
579
Matches(const MessageInfo & aMessageInfo) const580 bool Tcp::Listener::Matches(const MessageInfo &aMessageInfo) const
581 {
582 bool matches = false;
583 const struct tcpcb_listen *tpl = &GetTcbListen();
584
585 VerifyOrExit(tpl->t_state == TCP6S_LISTEN);
586 VerifyOrExit(tpl->lport == HostSwap16(aMessageInfo.GetSockPort()));
587 VerifyOrExit(GetLocalIp6Address().IsUnspecified() || GetLocalIp6Address() == aMessageInfo.GetSockAddr());
588
589 matches = true;
590
591 exit:
592 return matches;
593 }
594
/**
 * Handles an incoming IPv6 message carrying a TCP segment.
 *
 * The segment is validated (length, TCP header size, checksum), and its header
 * is copied into a local buffer. It is then dispatched in this order:
 * 1. To a matching active endpoint. If tcp_input() asks for a relookup
 *    (TIME-WAIT), fall through to step 2.
 * 2. To a matching listener.
 * 3. Otherwise, a reset is sent via tcp_dropwithreset() with ECONNREFUSED.
 *
 * @param[in]     aIp6Header    The IPv6 header of the message.
 * @param[in]     aMessage      The message; its offset points at the TCP header.
 * @param[in,out] aMessageInfo  Message info; its peer/sock ports are overwritten
 *                              with the ports from the TCP header.
 *
 * @retval kErrorParse  The length or TCP header size is inconsistent.
 * @returns The checksum-verification or read error if either fails, otherwise kErrorNone.
 */
Error Tcp::HandleMessage(ot::Ip6::Header &aIp6Header, Message &aMessage, MessageInfo &aMessageInfo)
{
    Error error = kErrorNotImplemented;

    /*
     * The type uint32_t was chosen for alignment purposes. The size is the
     * maximum TCP header size, including options.
     */
    uint32_t header[15];

    uint16_t length = aIp6Header.GetPayloadLength();
    uint8_t  headerSize;

    struct ip6_hdr *ip6Header;
    struct tcphdr  *tcpHeader;

    Endpoint *endpoint;
    Endpoint *endpointPrev;

    Listener *listener;
    Listener *listenerPrev;

    // The IPv6 payload must be exactly the rest of the message and hold at least a minimal TCP header.
    VerifyOrExit(length == aMessage.GetLength() - aMessage.GetOffset(), error = kErrorParse);
    VerifyOrExit(length >= sizeof(Tcp::Header), error = kErrorParse);
    // Read the data-offset byte and convert it (a count of 32-bit words) into a header size in bytes.
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset() + offsetof(struct tcphdr, th_off_x2), headerSize));
    headerSize = static_cast<uint8_t>((headerSize >> TH_OFF_SHIFT) << 2);
    // The header (with options) must be at least minimal, fit in `header`, and fit in the payload.
    VerifyOrExit(headerSize >= sizeof(struct tcphdr) && headerSize <= sizeof(header) &&
                     static_cast<uint16_t>(headerSize) <= length,
                 error = kErrorParse);
    SuccessOrExit(error = Checksum::VerifyMessageChecksum(aMessage, aMessageInfo, kProtoTcp));
    SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), &header[0], headerSize));

    // NOTE(review): assumes ot::Ip6::Header is layout-compatible with TCPlp's struct ip6_hdr — confirm.
    ip6Header = reinterpret_cast<struct ip6_hdr *>(&aIp6Header);
    tcpHeader = reinterpret_cast<struct tcphdr *>(&header[0]);
    tcp_fields_to_host(tcpHeader);

    // Ports are taken from the TCP header so that the endpoint/listener lookups below can match on them.
    aMessageInfo.mPeerPort = HostSwap16(tcpHeader->th_sport);
    aMessageInfo.mSockPort = HostSwap16(tcpHeader->th_dport);

    endpoint = mEndpoints.FindMatching(aMessageInfo, endpointPrev);
    if (endpoint != nullptr)
    {
        struct tcplp_signals sig;
        int                  nextAction;
        struct tcpcb        *tp = &endpoint->GetTcb();

        // Snapshot send-buffer state so ProcessSignals() can tell what tcp_input() changed.
        otLinkedBuffer *priorHead    = lbuf_head(&tp->sendbuf);
        size_t          priorBacklog = endpoint->GetSendBufferBytes() - endpoint->GetInFlightBytes();

        memset(&sig, 0x00, sizeof(sig));
        nextAction = tcp_input(ip6Header, tcpHeader, &aMessage, tp, nullptr, &sig);
        if (nextAction != RELOOKUP_REQUIRED)
        {
            ProcessSignals(*endpoint, priorHead, priorBacklog, sig);
            ExitNow();
        }
        /* If the matching socket was in the TIME-WAIT state, then we try passive sockets. */
    }

    listener = mListeners.FindMatching(aMessageInfo, listenerPrev);
    if (listener != nullptr)
    {
        struct tcpcb_listen *tpl = &listener->GetTcbListen();

        tcp_input(ip6Header, tcpHeader, &aMessage, nullptr, tpl, nullptr);
        ExitNow();
    }

    // No endpoint or listener wants this segment: answer with a reset.
    tcp_dropwithreset(ip6Header, tcpHeader, nullptr, &InstanceLocator::GetInstance(), length - headerSize,
                      ECONNREFUSED);

exit:
    return error;
}
669
/**
 * Invokes the endpoint's user callbacks for the events TCPlp reported in
 * @p aSignals while processing one incoming segment.
 *
 * Callbacks run in a fixed order:
 * 1. Established.
 * 2. Send-done, once per linked buffer popped from the send buffer.
 * 3. Forward progress.
 * 4. Receive available.
 * 5. Disconnected (TIME-WAIT).
 *
 * Any callback may deinitialize or close the endpoint, so that is re-checked
 * before each step, and processing stops as soon as the endpoint is gone.
 *
 * @param[in] aEndpoint      The endpoint that received the segment.
 * @param[in] aPriorHead     Head of the send buffer before tcp_input() ran. The first
 *                           aSignals.links_popped links starting here were popped.
 * @param[in] aPriorBacklog  Backlog bytes before tcp_input() ran.
 * @param[in] aSignals       Events reported by tcp_input().
 */
void Tcp::ProcessSignals(Endpoint             &aEndpoint,
                         otLinkedBuffer       *aPriorHead,
                         size_t                aPriorBacklog,
                         struct tcplp_signals &aSignals) const
{
    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aSignals.conn_established && aEndpoint.mEstablishedCallback != nullptr)
    {
        aEndpoint.mEstablishedCallback(&aEndpoint);
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.mSendDoneCallback != nullptr)
    {
        otLinkedBuffer *curr = aPriorHead;

        for (uint32_t i = 0; i != aSignals.links_popped; i++)
        {
            // Read the successor before the callback, which hands `curr` back to the user.
            otLinkedBuffer *next = curr->mNext;

            // The first iteration was already covered by the check above.
            VerifyOrExit(i == 0 || (IsInitialized(aEndpoint) && !aEndpoint.IsClosed()));

            curr->mNext = nullptr;
            aEndpoint.mSendDoneCallback(&aEndpoint, curr);
            curr = next;
        }
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.mForwardProgressCallback != nullptr)
    {
        size_t backlogBytes = aEndpoint.GetBacklogBytes();

        if (aSignals.bytes_acked > 0 || backlogBytes < aPriorBacklog)
        {
            aEndpoint.mForwardProgressCallback(&aEndpoint, aEndpoint.GetSendBufferBytes(), backlogBytes);
            // Progress was just reported directly, so drop any deferred forward-progress notification.
            aEndpoint.mPendingCallbacks &= ~kForwardProgressCallbackFlag;
        }
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if ((aSignals.recvbuf_added || aSignals.rcvd_fin) && aEndpoint.mReceiveAvailableCallback != nullptr)
    {
        // NOTE(review): reass_fin_index != -1 presumably means the peer's FIN has been received — confirm.
        aEndpoint.mReceiveAvailableCallback(&aEndpoint, cbuf_used_space(&aEndpoint.GetTcb().recvbuf),
                                            aEndpoint.GetTcb().reass_fin_index != -1,
                                            cbuf_free_space(&aEndpoint.GetTcb().recvbuf));
    }

    VerifyOrExit(IsInitialized(aEndpoint) && !aEndpoint.IsClosed());
    if (aEndpoint.GetTcb().t_state == TCP6S_TIME_WAIT && aEndpoint.mDisconnectedCallback != nullptr)
    {
        aEndpoint.mDisconnectedCallback(&aEndpoint, OT_TCP_DISCONNECTED_REASON_TIME_WAIT);
    }

exit:
    return;
}
727
BsdErrorToOtError(int aBsdError)728 Error Tcp::BsdErrorToOtError(int aBsdError)
729 {
730 Error error = kErrorFailed;
731
732 switch (aBsdError)
733 {
734 case 0:
735 error = kErrorNone;
736 break;
737 }
738
739 return error;
740 }
741
CanBind(const SockAddr & aSockName)742 bool Tcp::CanBind(const SockAddr &aSockName)
743 {
744 uint16_t port = HostSwap16(aSockName.mPort);
745 bool allowed = false;
746
747 for (Endpoint &endpoint : mEndpoints)
748 {
749 struct tcpcb *tp = &endpoint.GetTcb();
750
751 if (tp->lport == port)
752 {
753 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
754 VerifyOrExit(!reinterpret_cast<Address *>(&tp->laddr)->IsUnspecified());
755 VerifyOrExit(memcmp(&endpoint.GetTcb().laddr, &aSockName.mAddress, sizeof(tp->laddr)) != 0);
756 }
757 }
758
759 for (Listener &listener : mListeners)
760 {
761 struct tcpcb_listen *tpl = &listener.GetTcbListen();
762
763 if (tpl->lport == port)
764 {
765 VerifyOrExit(!aSockName.GetAddress().IsUnspecified());
766 VerifyOrExit(!reinterpret_cast<Address *>(&tpl->laddr)->IsUnspecified());
767 VerifyOrExit(memcmp(&tpl->laddr, &aSockName.mAddress, sizeof(tpl->laddr)) != 0);
768 }
769 }
770
771 allowed = true;
772
773 exit:
774 return allowed;
775 }
776
AutoBind(const SockAddr & aPeer,SockAddr & aToBind,bool aBindAddress,bool aBindPort)777 bool Tcp::AutoBind(const SockAddr &aPeer, SockAddr &aToBind, bool aBindAddress, bool aBindPort)
778 {
779 bool success;
780
781 if (aBindAddress)
782 {
783 const Address *source;
784
785 source = Get<Ip6>().SelectSourceAddress(aPeer.GetAddress());
786 VerifyOrExit(source != nullptr, success = false);
787 aToBind.SetAddress(*source);
788 }
789
790 if (aBindPort)
791 {
792 /*
793 * TODO: Use a less naive algorithm to allocate ephemeral ports. For
794 * example, see RFC 6056.
795 */
796
797 for (uint16_t i = 0; i != kDynamicPortMax - kDynamicPortMin + 1; i++)
798 {
799 aToBind.SetPort(mEphemeralPort);
800
801 if (mEphemeralPort == kDynamicPortMax)
802 {
803 mEphemeralPort = kDynamicPortMin;
804 }
805 else
806 {
807 mEphemeralPort++;
808 }
809
810 if (CanBind(aToBind))
811 {
812 ExitNow(success = true);
813 }
814 }
815
816 ExitNow(success = false);
817 }
818
819 success = CanBind(aToBind);
820
821 exit:
822 return success;
823 }
824
/**
 * Handler for the single shared TCP timer (mTimer).
 *
 * Fires every expired TCPlp timer on every endpoint, then re-arms mTimer for
 * the earliest remaining expiry, if any.
 */
void Tcp::HandleTimer(void)
{
    // NOTE(review): `now` is sampled once and reused across restarts below.
    TimeMilli now = TimerMilli::GetNow();
    bool      pendingTimer;
    TimeMilli earliestPendingTimerExpiry;

    LogDebg("Main TCP timer expired");

    /*
     * The timer callbacks could potentially set/reset/cancel timers.
     * Importantly, Endpoint::SetTimer and Endpoint::CancelTimer do not call
     * this function to recompute the timer. If they did, we'd have a
     * re-entrancy problem, where the callbacks called in this function could
     * wind up re-entering this function in a nested call frame.
     *
     * In general, calling this function from Endpoint::SetTimer and
     * Endpoint::CancelTimer could be inefficient, since those functions are
     * called multiple times on each received TCP segment. If we want to
     * prevent the main timer from firing except when an actual TCP timer
     * expires, a better alternative is to reset the main timer in
     * HandleMessage, right before processing signals. That would achieve that
     * objective while avoiding re-entrancy issues altogether.
     */
restart:
    // Reset the accumulators on every (re)start so endpoints are re-scanned from scratch.
    pendingTimer               = false;
    earliestPendingTimerExpiry = now.GetDistantFuture();

    for (Endpoint &endpoint : mEndpoints)
    {
        if (endpoint.FirePendingTimers(now, pendingTimer, earliestPendingTimerExpiry))
        {
            /*
             * If a non-OpenThread callback is called --- which, in practice,
             * happens if the connection times out and the user-defined
             * connection lost callback is called --- then we might have to
             * start over. The reason is that the user might deinitialize
             * endpoints, changing the structure of the linked list. For
             * example, if the user deinitializes both this endpoint and the
             * next one in the linked list, then we can't continue traversing
             * the linked list.
             */
            goto restart;
        }
    }

    if (pendingTimer)
    {
        /*
         * We need to use Timer::FireAtIfEarlier instead of Timer::FireAt
         * because one of the earlier callbacks might have set TCP timers,
         * in which case `mTimer` would have been set to the earliest of those
         * timers.
         */
        mTimer.FireAtIfEarlier(earliestPendingTimerExpiry);
        LogDebg("Reset main TCP timer to %u ms", static_cast<unsigned int>(earliestPendingTimerExpiry - now));
    }
    else
    {
        LogDebg("Did not reset main TCP timer");
    }
}
886
ProcessCallbacks(void)887 void Tcp::ProcessCallbacks(void)
888 {
889 for (Endpoint &endpoint : mEndpoints)
890 {
891 if (endpoint.FirePendingCallbacks())
892 {
893 mTasklet.Post();
894 break;
895 }
896 }
897 }
898
899 } // namespace Ip6
900 } // namespace ot
901
/*
 * Implement TCPlp system stubs declared in tcplp.h.
 *
 * Because these functions have C linkage, it is important that only one
 * definition is given for each function name, regardless of the namespace it
 * is in. For example, if we give two definitions of tcplp_sys_new_message, we
 * will get errors, even if they are in different namespaces. To avoid
 * confusion, I've put these functions outside of any namespace.
 */
911
912 using namespace ot;
913 using namespace ot::Ip6;
914
915 extern "C" {
916
tcplp_sys_new_message(otInstance * aInstance)917 otMessage *tcplp_sys_new_message(otInstance *aInstance)
918 {
919 Instance &instance = AsCoreType(aInstance);
920 Message *message = instance.Get<ot::Ip6::Ip6>().NewMessage(0);
921
922 if (message)
923 {
924 message->SetLinkSecurityEnabled(true);
925 }
926
927 return message;
928 }
929
tcplp_sys_free_message(otInstance * aInstance,otMessage * aMessage)930 void tcplp_sys_free_message(otInstance *aInstance, otMessage *aMessage)
931 {
932 OT_UNUSED_VARIABLE(aInstance);
933 Message &message = AsCoreType(aMessage);
934 message.Free();
935 }
936
tcplp_sys_send_message(otInstance * aInstance,otMessage * aMessage,otMessageInfo * aMessageInfo)937 void tcplp_sys_send_message(otInstance *aInstance, otMessage *aMessage, otMessageInfo *aMessageInfo)
938 {
939 Instance &instance = AsCoreType(aInstance);
940 Message &message = AsCoreType(aMessage);
941 MessageInfo &info = AsCoreType(aMessageInfo);
942
943 LogDebg("Sending TCP segment: payload_size = %d", static_cast<int>(message.GetLength()));
944
945 IgnoreError(instance.Get<ot::Ip6::Ip6>().SendDatagram(message, info, kProtoTcp));
946 }
947
tcplp_sys_get_ticks(void)948 uint32_t tcplp_sys_get_ticks(void) { return TimerMilli::GetNow().GetValue(); }
949
tcplp_sys_get_millis(void)950 uint32_t tcplp_sys_get_millis(void) { return TimerMilli::GetNow().GetValue(); }
951
tcplp_sys_set_timer(struct tcpcb * aTcb,uint8_t aTimerFlag,uint32_t aDelay)952 void tcplp_sys_set_timer(struct tcpcb *aTcb, uint8_t aTimerFlag, uint32_t aDelay)
953 {
954 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
955 endpoint.SetTimer(aTimerFlag, aDelay);
956 }
957
tcplp_sys_stop_timer(struct tcpcb * aTcb,uint8_t aTimerFlag)958 void tcplp_sys_stop_timer(struct tcpcb *aTcb, uint8_t aTimerFlag)
959 {
960 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
961 endpoint.CancelTimer(aTimerFlag);
962 }
963
tcplp_sys_accept_ready(struct tcpcb_listen * aTcbListen,struct in6_addr * aAddr,uint16_t aPort)964 struct tcpcb *tcplp_sys_accept_ready(struct tcpcb_listen *aTcbListen, struct in6_addr *aAddr, uint16_t aPort)
965 {
966 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
967 Tcp &tcp = listener.Get<Tcp>();
968 struct tcpcb *rv = (struct tcpcb *)-1;
969 otSockAddr addr;
970 otTcpEndpoint *endpointPtr;
971 otTcpIncomingConnectionAction action;
972
973 VerifyOrExit(listener.mAcceptReadyCallback != nullptr);
974
975 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
976 addr.mPort = HostSwap16(aPort);
977 action = listener.mAcceptReadyCallback(&listener, &addr, &endpointPtr);
978
979 VerifyOrExit(tcp.IsInitialized(listener) && !listener.IsClosed());
980
981 switch (action)
982 {
983 case OT_TCP_INCOMING_CONNECTION_ACTION_ACCEPT:
984 {
985 Tcp::Endpoint &endpoint = AsCoreType(endpointPtr);
986
987 /*
988 * The documentation says that the user must initialize the
989 * endpoint before passing it here, so we do a sanity check to make
990 * sure the endpoint is initialized and closed. That check may not
991 * be necessary, but we do it anyway.
992 */
993 VerifyOrExit(tcp.IsInitialized(endpoint) && endpoint.IsClosed());
994
995 rv = &endpoint.GetTcb();
996
997 break;
998 }
999 case OT_TCP_INCOMING_CONNECTION_ACTION_DEFER:
1000 rv = nullptr;
1001 break;
1002 case OT_TCP_INCOMING_CONNECTION_ACTION_REFUSE:
1003 rv = (struct tcpcb *)-1;
1004 break;
1005 }
1006
1007 exit:
1008 return rv;
1009 }
1010
tcplp_sys_accepted_connection(struct tcpcb_listen * aTcbListen,struct tcpcb * aAccepted,struct in6_addr * aAddr,uint16_t aPort)1011 bool tcplp_sys_accepted_connection(struct tcpcb_listen *aTcbListen,
1012 struct tcpcb *aAccepted,
1013 struct in6_addr *aAddr,
1014 uint16_t aPort)
1015 {
1016 Tcp::Listener &listener = Tcp::Listener::FromTcbListen(*aTcbListen);
1017 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aAccepted);
1018 Tcp &tcp = endpoint.Get<Tcp>();
1019 bool accepted = true;
1020
1021 if (listener.mAcceptDoneCallback != nullptr)
1022 {
1023 otSockAddr addr;
1024
1025 memcpy(&addr.mAddress, aAddr, sizeof(addr.mAddress));
1026 addr.mPort = HostSwap16(aPort);
1027 listener.mAcceptDoneCallback(&listener, &endpoint, &addr);
1028
1029 if (!tcp.IsInitialized(endpoint) || endpoint.IsClosed())
1030 {
1031 accepted = false;
1032 }
1033 }
1034
1035 return accepted;
1036 }
1037
tcplp_sys_connection_lost(struct tcpcb * aTcb,uint8_t aErrNum)1038 void tcplp_sys_connection_lost(struct tcpcb *aTcb, uint8_t aErrNum)
1039 {
1040 Tcp::Endpoint &endpoint = Tcp::Endpoint::FromTcb(*aTcb);
1041
1042 if (endpoint.mDisconnectedCallback != nullptr)
1043 {
1044 otTcpDisconnectedReason reason;
1045
1046 switch (aErrNum)
1047 {
1048 case CONN_LOST_NORMAL:
1049 reason = OT_TCP_DISCONNECTED_REASON_NORMAL;
1050 break;
1051 case ECONNREFUSED:
1052 reason = OT_TCP_DISCONNECTED_REASON_REFUSED;
1053 break;
1054 case ETIMEDOUT:
1055 reason = OT_TCP_DISCONNECTED_REASON_TIMED_OUT;
1056 break;
1057 case ECONNRESET:
1058 default:
1059 reason = OT_TCP_DISCONNECTED_REASON_RESET;
1060 break;
1061 }
1062 endpoint.mDisconnectedCallback(&endpoint, reason);
1063 }
1064 }
1065
tcplp_sys_on_state_change(struct tcpcb * aTcb,int aNewState)1066 void tcplp_sys_on_state_change(struct tcpcb *aTcb, int aNewState)
1067 {
1068 if (aNewState == TCP6S_CLOSED)
1069 {
1070 /* Re-initialize the TCB. */
1071 cbuf_pop(&aTcb->recvbuf, cbuf_used_space(&aTcb->recvbuf));
1072 aTcb->accepted_from = nullptr;
1073 initialize_tcb(aTcb);
1074 }
1075 /* Any adaptive changes to the sleep interval would go here. */
1076 }
1077
tcplp_sys_log(const char * aFormat,...)1078 void tcplp_sys_log(const char *aFormat, ...)
1079 {
1080 char buffer[128];
1081 va_list args;
1082 va_start(args, aFormat);
1083 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1084 va_end(args);
1085
1086 LogDebg("%s", buffer);
1087 }
1088
tcplp_sys_panic(const char * aFormat,...)1089 void tcplp_sys_panic(const char *aFormat, ...)
1090 {
1091 char buffer[128];
1092 va_list args;
1093 va_start(args, aFormat);
1094 vsnprintf(buffer, sizeof(buffer), aFormat, args);
1095 va_end(args);
1096
1097 LogCrit("%s", buffer);
1098
1099 OT_ASSERT(false);
1100 }
1101
tcplp_sys_autobind(otInstance * aInstance,const otSockAddr * aPeer,otSockAddr * aToBind,bool aBindAddress,bool aBindPort)1102 bool tcplp_sys_autobind(otInstance *aInstance,
1103 const otSockAddr *aPeer,
1104 otSockAddr *aToBind,
1105 bool aBindAddress,
1106 bool aBindPort)
1107 {
1108 Instance &instance = AsCoreType(aInstance);
1109
1110 return instance.Get<Tcp>().AutoBind(*static_cast<const SockAddr *>(aPeer), *static_cast<SockAddr *>(aToBind),
1111 aBindAddress, aBindPort);
1112 }
1113
tcplp_sys_generate_isn()1114 uint32_t tcplp_sys_generate_isn()
1115 {
1116 uint32_t isn;
1117 IgnoreError(Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&isn), sizeof(isn)));
1118 return isn;
1119 }
1120
tcplp_sys_hostswap16(uint16_t aHostPort)1121 uint16_t tcplp_sys_hostswap16(uint16_t aHostPort) { return HostSwap16(aHostPort); }
1122
tcplp_sys_hostswap32(uint32_t aHostPort)1123 uint32_t tcplp_sys_hostswap32(uint32_t aHostPort) { return HostSwap32(aHostPort); }
1124 }
1125
1126 #endif // OPENTHREAD_CONFIG_TCP_ENABLE
1127