1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_dso.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "common/num_utils.hpp"
41 #include "common/random.hpp"
42
43 /**
44 * @file
45 * This file implements the DNS Stateful Operations (DSO) per RFC 8490.
46 */
47
48 namespace ot {
49 namespace Dns {
50
51 RegisterLogModule("DnsDso");
52
53 //---------------------------------------------------------------------------------------------------------------------
54 // otPlatDso transport callbacks
55
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)56 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
57 {
58 return &AsCoreType(aConnection).GetInstance();
59 }
60
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)61 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
62 {
63 return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
64 }
65
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)66 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
67 {
68 AsCoreType(aConnection).HandleConnected();
69 }
70
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)71 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
72 {
73 AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
74 }
75
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)76 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
77 {
78 AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
79 }
80
81 //---------------------------------------------------------------------------------------------------------------------
82 // Dso::Connection
83
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)84 Dso::Connection::Connection(Instance &aInstance,
85 const Ip6::SockAddr &aPeerSockAddr,
86 Callbacks &aCallbacks,
87 uint32_t aInactivityTimeout,
88 uint32_t aKeepAliveInterval)
89 : InstanceLocator(aInstance)
90 , mNext(nullptr)
91 , mCallbacks(aCallbacks)
92 , mPeerSockAddr(aPeerSockAddr)
93 , mState(kStateDisconnected)
94 , mIsServer(false)
95 , mInactivity(aInactivityTimeout)
96 , mKeepAlive(aKeepAliveInterval)
97 {
98 OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
99 Init(/* aIsServer */ false);
100 }
101
Init(bool aIsServer)102 void Dso::Connection::Init(bool aIsServer)
103 {
104 mNextMessageId = 1;
105 mIsServer = aIsServer;
106 mStateDidChange = false;
107 mLongLivedOperation = false;
108 mRetryDelay = 0;
109 mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
110 mDisconnectReason = kReasonUnknown;
111 }
112
SetState(State aState)113 void Dso::Connection::SetState(State aState)
114 {
115 VerifyOrExit(mState != aState);
116
117 LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
118 mPeerSockAddr.ToString().AsCString());
119
120 mState = aState;
121 mStateDidChange = true;
122
123 exit:
124 return;
125 }
126
SignalAnyStateChange(void)127 void Dso::Connection::SignalAnyStateChange(void)
128 {
129 VerifyOrExit(mStateDidChange);
130 mStateDidChange = false;
131
132 switch (mState)
133 {
134 case kStateDisconnected:
135 mCallbacks.mHandleDisconnected(*this);
136 break;
137
138 case kStateConnectedButSessionless:
139 mCallbacks.mHandleConnected(*this);
140 break;
141
142 case kStateSessionEstablished:
143 mCallbacks.mHandleSessionEstablished(*this);
144 break;
145
146 case kStateConnecting:
147 case kStateEstablishingSession:
148 break;
149 };
150
151 exit:
152 return;
153 }
154
NewMessage(void)155 Message *Dso::Connection::NewMessage(void)
156 {
157 return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
158 Message::Settings(Message::kPriorityNormal));
159 }
160
Connect(void)161 void Dso::Connection::Connect(void)
162 {
163 OT_ASSERT(mState == kStateDisconnected);
164
165 Init(/* aIsServer */ false);
166 Get<Dso>().mClientConnections.Push(*this);
167 MarkAsConnecting();
168 otPlatDsoConnect(this, &mPeerSockAddr);
169 }
170
Accept(void)171 void Dso::Connection::Accept(void)
172 {
173 OT_ASSERT(mState == kStateDisconnected);
174
175 Init(/* aIsServer */ true);
176 Get<Dso>().mServerConnections.Push(*this);
177 MarkAsConnecting();
178 }
179
MarkAsConnecting(void)180 void Dso::Connection::MarkAsConnecting(void)
181 {
182 SetState(kStateConnecting);
183
184 // While in `kStateConnecting` state we use the `mKeepAlive` to
185 // track the `kConnectingTimeout` (if connection is not established
186 // within the timeout, we consider it as failure and close it).
187
188 mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
189 Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
190
191 // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
192 // or timeout.
193 }
194
HandleConnected(void)195 void Dso::Connection::HandleConnected(void)
196 {
197 OT_ASSERT(mState == kStateConnecting);
198
199 SetState(kStateConnectedButSessionless);
200 ResetTimeouts(/* aIsKeepAliveMessage */ false);
201
202 SignalAnyStateChange();
203 }
204
Disconnect(DisconnectMode aMode,DisconnectReason aReason)205 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
206 {
207 VerifyOrExit(mState != kStateDisconnected);
208
209 mDisconnectReason = aReason;
210 MarkAsDisconnected();
211
212 otPlatDsoDisconnect(this, MapEnum(aMode));
213
214 exit:
215 return;
216 }
217
HandleDisconnected(DisconnectMode aMode)218 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
219 {
220 VerifyOrExit(mState != kStateDisconnected);
221
222 if (mState == kStateConnecting)
223 {
224 mDisconnectReason = kReasonFailedToConnect;
225 }
226 else
227 {
228 switch (aMode)
229 {
230 case kGracefullyClose:
231 mDisconnectReason = kReasonPeerClosed;
232 break;
233
234 case kForciblyAbort:
235 mDisconnectReason = kReasonPeerAborted;
236 }
237 }
238
239 MarkAsDisconnected();
240 SignalAnyStateChange();
241
242 exit:
243 return;
244 }
245
MarkAsDisconnected(void)246 void Dso::Connection::MarkAsDisconnected(void)
247 {
248 if (IsClient())
249 {
250 IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
251 }
252 else
253 {
254 IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
255 }
256
257 mPendingRequests.Clear();
258 SetState(kStateDisconnected);
259
260 LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
261 }
262
MarkSessionEstablished(void)263 void Dso::Connection::MarkSessionEstablished(void)
264 {
265 switch (mState)
266 {
267 case kStateConnectedButSessionless:
268 case kStateEstablishingSession:
269 case kStateSessionEstablished:
270 break;
271
272 case kStateDisconnected:
273 case kStateConnecting:
274 OT_ASSERT(false);
275 }
276
277 SetState(kStateSessionEstablished);
278 }
279
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)280 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
281 {
282 return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
283 }
284
SendUnidirectionalMessage(Message & aMessage)285 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
286 {
287 MessageId messageId = 0;
288
289 return SendMessage(aMessage, kUnidirectionalMessage, messageId);
290 }
291
SendResponseMessage(Message & aMessage,MessageId aResponseId)292 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
293 {
294 return SendMessage(aMessage, kResponseMessage, aResponseId);
295 }
296
SetLongLivedOperation(bool aLongLivedOperation)297 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
298 {
299 VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
300
301 mLongLivedOperation = aLongLivedOperation;
302
303 LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
304
305 if (!mLongLivedOperation)
306 {
307 TimeMilli now = TimerMilli::GetNow();
308 TimeMilli nextTime;
309
310 nextTime = GetNextFireTime(now);
311
312 if (nextTime != now.GetDistantFuture())
313 {
314 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
315 }
316 }
317
318 exit:
319 return;
320 }
321
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)322 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
323 {
324 Error error = kErrorNone;
325 Message *message = nullptr;
326 RetryDelayTlv retryDelayTlv;
327 MessageId messageId;
328
329 switch (mState)
330 {
331 case kStateSessionEstablished:
332 OT_ASSERT(IsServer());
333 break;
334
335 case kStateConnectedButSessionless:
336 case kStateEstablishingSession:
337 case kStateDisconnected:
338 case kStateConnecting:
339 OT_ASSERT(false);
340 }
341
342 message = NewMessage();
343 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
344
345 retryDelayTlv.Init();
346 retryDelayTlv.SetRetryDelay(aDelay);
347 SuccessOrExit(error = message->Append(retryDelayTlv));
348 error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
349
350 exit:
351 FreeMessageOnError(message, error);
352 return error;
353 }
354
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)355 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
356 {
357 Error error = kErrorNone;
358
359 VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
360
361 // If acting as server, the timeout values are the ones we grant
362 // to a connecting clients. If acting as client, the timeout
363 // values are what to request when sending Keep Alive message.
364 // If in `kStateDisconnected` we set both (since we don't know
365 // yet whether we are going to connect as client or server).
366
367 if ((mState == kStateDisconnected) || IsServer())
368 {
369 mKeepAlive.SetInterval(aKeepAliveInterval);
370 AdjustInactivityTimeout(aInactivityTimeout);
371 }
372
373 if ((mState == kStateDisconnected) || IsClient())
374 {
375 mKeepAlive.SetRequestInterval(aKeepAliveInterval);
376 mInactivity.SetRequestInterval(aInactivityTimeout);
377 }
378
379 switch (mState)
380 {
381 case kStateDisconnected:
382 case kStateConnecting:
383 break;
384
385 case kStateConnectedButSessionless:
386 case kStateEstablishingSession:
387 if (IsServer())
388 {
389 break;
390 }
391
392 OT_FALL_THROUGH;
393
394 case kStateSessionEstablished:
395 error = SendKeepAliveMessage();
396 }
397
398 exit:
399 return error;
400 }
401
SendKeepAliveMessage(void)402 Error Dso::Connection::SendKeepAliveMessage(void)
403 {
404 return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
405 }
406
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)407 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
408 {
409 // Sends a Keep Alive message of a given type. This is a common
410 // method used by both client and server. `aResponseId` is
411 // applicable and used only when the message type is
412 // `kResponseMessage`.
413
414 Error error = kErrorNone;
415 Message *message = nullptr;
416 KeepAliveTlv keepAliveTlv;
417
418 switch (mState)
419 {
420 case kStateConnectedButSessionless:
421 case kStateEstablishingSession:
422 if (IsServer())
423 {
424 // While session is being established, server is only allowed
425 // to send a Keep Alive response to a request from client.
426 OT_ASSERT(aMessageType == kResponseMessage);
427 }
428 break;
429
430 case kStateSessionEstablished:
431 break;
432
433 case kStateDisconnected:
434 case kStateConnecting:
435 OT_ASSERT(false);
436 }
437
438 // Server can send Keep Alive response (to a request from client)
439 // or a unidirectional Keep Alive message. Client can send
440 // KeepAlive request message.
441
442 if (IsServer())
443 {
444 if (aMessageType == kResponseMessage)
445 {
446 OT_ASSERT(aResponseId != 0);
447 }
448 else
449 {
450 OT_ASSERT(aMessageType == kUnidirectionalMessage);
451 }
452 }
453 else
454 {
455 OT_ASSERT(aMessageType == kRequestMessage);
456 }
457
458 message = NewMessage();
459 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
460
461 keepAliveTlv.Init();
462
463 if (IsServer())
464 {
465 keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
466 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
467 }
468 else
469 {
470 keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
471 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
472 }
473
474 SuccessOrExit(error = message->Append(keepAliveTlv));
475
476 error = SendMessage(*message, aMessageType, aResponseId);
477
478 exit:
479 FreeMessageOnError(message, error);
480 return error;
481 }
482
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)483 Error Dso::Connection::SendMessage(Message &aMessage,
484 MessageType aMessageType,
485 MessageId &aMessageId,
486 Dns::Header::Response aResponseCode,
487 uint32_t aResponseTimeout)
488 {
489 Error error = kErrorNone;
490 Tlv::Type primaryTlvType = Tlv::kReservedType;
491 Dns::Header header;
492
493 switch (mState)
494 {
495 case kStateConnectedButSessionless:
496 // To establish session, client MUST send a request message.
497 // Server is not allowed to send any messages. Unidirectional
498 // messages are not allowed before session is established.
499 OT_ASSERT(IsClient());
500 OT_ASSERT(aMessageType == kRequestMessage);
501 break;
502
503 case kStateEstablishingSession:
504 // During session establishment, client is allowed to send
505 // additional request messages, server is only allowed to
506 // send response.
507 if (IsClient())
508 {
509 OT_ASSERT(aMessageType == kRequestMessage);
510 }
511 else
512 {
513 OT_ASSERT(aMessageType == kResponseMessage);
514 }
515 break;
516
517 case kStateSessionEstablished:
518 // All message types are allowed.
519 break;
520
521 case kStateDisconnected:
522 case kStateConnecting:
523 OT_ASSERT(false);
524 }
525
526 // A DSO request or unidirectional message MUST contain at
527 // least one TLV. The first TLV is the "Primary TLV" and
528 // determines the nature of the operation being performed.
529 // A DSO response message may contain no TLVs, or may contain
530 // one or more TLVs. Response Primary TLV(s) MUST appear first
531 // in a DSO response message.
532
533 aMessage.SetOffset(0);
534 IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
535
536 switch (aMessageType)
537 {
538 case kResponseMessage:
539 break;
540 case kRequestMessage:
541 case kUnidirectionalMessage:
542 OT_ASSERT(primaryTlvType != Tlv::kReservedType);
543 }
544
545 // `header` is cleared from its constructor call so all fields
546 // start as zero.
547
548 switch (aMessageType)
549 {
550 case kRequestMessage:
551 header.SetType(Dns::Header::kTypeQuery);
552 aMessageId = mNextMessageId;
553 break;
554
555 case kResponseMessage:
556 header.SetType(Dns::Header::kTypeResponse);
557 break;
558
559 case kUnidirectionalMessage:
560 header.SetType(Dns::Header::kTypeQuery);
561 aMessageId = 0;
562 break;
563 }
564
565 header.SetMessageId(aMessageId);
566 header.SetQueryType(Dns::Header::kQueryTypeDso);
567 header.SetResponseCode(aResponseCode);
568 SuccessOrExit(error = aMessage.Prepend(header));
569
570 SuccessOrExit(error = AppendPadding(aMessage));
571
572 // Update `mPendingRequests` list with the new request info
573
574 if (aMessageType == kRequestMessage)
575 {
576 SuccessOrExit(
577 error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
578
579 if (++mNextMessageId == 0)
580 {
581 mNextMessageId = 1;
582 }
583 }
584
585 LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
586 mPeerSockAddr.ToString().AsCString());
587
588 switch (mState)
589 {
590 case kStateConnectedButSessionless:
591 // On client we transition from "connected" state to
592 // "establishing session" state on successfully sending a
593 // request message.
594 if (IsClient())
595 {
596 SetState(kStateEstablishingSession);
597 }
598 break;
599
600 case kStateEstablishingSession:
601 // On server we transition from "establishing session" state
602 // to "established" on sending a response with success
603 // response code.
604 if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
605 {
606 SetState(kStateSessionEstablished);
607 }
608
609 default:
610 break;
611 }
612
613 ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
614
615 otPlatDsoSend(this, &aMessage);
616
617 // Signal any state changes. This is done at the very end when the
618 // `SendMessage()` is fully processed (all state and local
619 // variables are updated) to ensure that we do not have any
620 // reentrancy issues (e.g., if the callback signalling state
621 // change triggers another tx).
622
623 SignalAnyStateChange();
624
625 exit:
626 return error;
627 }
628
AppendPadding(Message & aMessage)629 Error Dso::Connection::AppendPadding(Message &aMessage)
630 {
631 // This method appends Encryption Padding TLV to a DSO message.
632 // It uses the padding policy "Random-Block-Length Padding" from
633 // RFC 8467.
634
635 static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
636
637 Error error = kErrorNone;
638 uint16_t blockLength;
639 EncryptionPaddingTlv paddingTlv;
640
641 // We pick a random block length. The random selection can be
642 // based on a "weak" source of randomness (so the use of
643 // `NonCrypto` is fine). We add padding to the message such
644 // that its padded length is a multiple of the chosen block
645 // length.
646
647 blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
648
649 paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
650
651 SuccessOrExit(error = aMessage.Append(paddingTlv));
652
653 for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
654 {
655 SuccessOrExit(error = aMessage.Append<uint8_t>(0));
656 }
657
658 exit:
659 return error;
660 }
661
// Processes a DNS message received on this connection from the platform
// transport. `aMessage` is always freed before this method returns.
//
// Flow:
//  1. Non-DSO messages: a query is answered with "Not Implemented" and
//     counted as success; a non-DSO response is counted as a failure.
//  2. The header type and message ID are checked against what the current
//     state and role allow, and all record count fields must be zero.
//  3. The primary TLV (if any) is read, and the message is dispatched as a
//     request/unidirectional message or as a response.
//
// Any failure (`error != kErrorNone`) is treated as peer misbehavior and
// forcibly aborts the connection. On success `ResetTimeouts()` is called,
// flagging whether the primary TLV was a Keep Alive TLV. State changes are
// signaled only at the very end.
void Dso::Connection::HandleReceive(Message &aMessage)
{
    // Starts as `kErrorAbort` so that every early `ExitNow()` /
    // `VerifyOrExit()` below is treated as a failure.
    Error error = kErrorAbort;
    Tlv::Type primaryTlvType = Tlv::kReservedType;
    Dns::Header header;

    SuccessOrExit(aMessage.Read(0, header));

    if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
    {
        if (header.GetType() == Dns::Header::kTypeQuery)
        {
            SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
            error = kErrorNone;
        }

        ExitNow();
    }

    switch (mState)
    {
    case kStateConnectedButSessionless:
        // After connection is established, client should initiate
        // establishing session (by sending a request). So no rx is
        // allowed before this. On server, we allow rx of a request
        // message only.
        VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
        break;

    case kStateEstablishingSession:
        // Unidirectional message are allowed after session is
        // established. While session is being established, on client,
        // we allow rx on response message. On server we can rx
        // request or response.

        VerifyOrExit(header.GetMessageId() != 0);

        if (IsClient())
        {
            VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
        }
        break;

    case kStateSessionEstablished:
        // All message types are allowed.
        break;

    case kStateDisconnected:
    case kStateConnecting:
        // Nothing can be received in these states. (When already
        // disconnected, the `Disconnect()` call at `exit` is a no-op.)
        ExitNow();
    }

    // All count fields MUST be set to zero in the header.
    VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
                 (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));

    // TLVs start right after the DNS header.
    aMessage.SetOffset(sizeof(header));

    switch (ReadPrimaryTlv(aMessage, primaryTlvType))
    {
    case kErrorNone:
        VerifyOrExit(primaryTlvType != Tlv::kReservedType);
        break;

    case kErrorNotFound:
        // The `primaryTlvType` is set to `Tlv::kReservedType`
        // (value zero) to indicate that there is no primary TLV.
        break;

    default:
        // Malformed TLV (e.g. `kErrorParse`).
        ExitNow();
    }

    switch (header.GetType())
    {
    case Dns::Header::kTypeQuery:
        error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
        break;

    case Dns::Header::kTypeResponse:
        error = ProcessResponseMessage(header, aMessage, primaryTlvType);
        break;
    }

exit:
    aMessage.Free();

    if (error == kErrorNone)
    {
        ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
    }
    else
    {
        Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
    }

    // We signal any state change at the very end when the received
    // message is fully processed (all state and local variables are
    // updated) to ensure that we do not have any reentrancy issues
    // (e.g., if a `Connection` method happens to be called from the
    // callback).

    SignalAnyStateChange();
}
766
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const767 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
768 {
769 // Read and validate the primary TLV (first TLV after the header).
770 // The `aMessage.GetOffset()` must point to the first TLV. If no
771 // TLV then `kErrorNotFound` is returned. If TLV in message is not
772 // well-formed `kErrorParse` is returned. The read TLV type is
773 // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
774 // (value zero) when `kErrorNotFound`).
775
776 Error error = kErrorNotFound;
777 Tlv tlv;
778
779 aPrimaryTlvType = Tlv::kReservedType;
780
781 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
782 VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
783 aPrimaryTlvType = tlv.GetType();
784 error = kErrorNone;
785
786 exit:
787 return error;
788 }
789
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)790 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
791 const Message &aMessage,
792 Tlv::Type aPrimaryTlvType)
793 {
794 Error error = kErrorAbort;
795
796 if (IsServer() && (mState == kStateConnectedButSessionless))
797 {
798 SetState(kStateEstablishingSession);
799 }
800
801 // A DSO request or unidirectional message MUST contain at
802 // least one TLV which is the "Primary TLV" and determines
803 // the nature of the operation being performed.
804
805 switch (aPrimaryTlvType)
806 {
807 case KeepAliveTlv::kType:
808 error = ProcessKeepAliveMessage(aHeader, aMessage);
809 break;
810
811 case RetryDelayTlv::kType:
812 error = ProcessRetryDelayMessage(aHeader, aMessage);
813 break;
814
815 case Tlv::kReservedType:
816 case EncryptionPaddingTlv::kType:
817 // Misbehavior by peer.
818 break;
819
820 default:
821 if (aHeader.GetMessageId() == 0)
822 {
823 LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
824
825 error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
826 }
827 else
828 {
829 MessageId messageId = aHeader.GetMessageId();
830
831 LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
832
833 error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
834
835 // `kErrorNotFound` indicates that TLV type is not known.
836
837 if (error == kErrorNotFound)
838 {
839 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
840 error = kErrorNone;
841 }
842 }
843 break;
844 }
845
846 return error;
847 }
848
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)849 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
850 const Message &aMessage,
851 Tlv::Type aPrimaryTlvType)
852 {
853 Error error = kErrorAbort;
854 Tlv::Type requestPrimaryTlvType;
855
856 // If a client or server receives a response where the message
857 // ID is zero, or is any other value that does not match the
858 // message ID of any of its outstanding operations, this is a
859 // fatal error and the recipient MUST forcibly abort the
860 // connection immediately.
861
862 VerifyOrExit(aHeader.GetMessageId() != 0);
863 VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
864
865 // If the response has no error and contains a primary TLV, it
866 // MUST match the request primary TLV.
867
868 if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
869 {
870 VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
871 }
872
873 mPendingRequests.Remove(aHeader.GetMessageId());
874
875 switch (requestPrimaryTlvType)
876 {
877 case KeepAliveTlv::kType:
878 SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
879 break;
880
881 default:
882 SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
883 requestPrimaryTlvType));
884 break;
885 }
886
887 // DSO session is established when client sends a request message
888 // and receives a response from server with no error code.
889
890 if (IsClient() && (mState == kStateEstablishingSession) &&
891 (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
892 {
893 SetState(kStateSessionEstablished);
894 }
895
896 exit:
897 return error;
898 }
899
// Validates and handles a received message whose primary TLV is Keep Alive.
// It is used both for queries (server: request from client; client:
// unidirectional from server) and for responses (client only);
// `aHeader.GetType()` distinguishes the two.
//
// Returns `kErrorNone` on success and `kErrorAbort` on any validation
// failure. Callers propagate this to `HandleReceive()`, which then aborts
// the connection. An error response code received while establishing the
// session is not a failure: the connection is closed gracefully and
// `kErrorNone` is returned.
Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
{
    Error error = kErrorAbort;
    uint16_t offset = aMessage.GetOffset();
    Tlv tlv;
    KeepAliveTlv keepAliveTlv;

    if (aHeader.GetType() == Dns::Header::kTypeResponse)
    {
        // A Keep Alive response message is allowed on a client from a server.

        VerifyOrExit(IsClient());

        if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
        {
            // We got an error response code from server for our
            // Keep Alive request message. If this happens while
            // establishing the DSO session, it indicates that server
            // does not support DSO, so we close the connection. If
            // this happens while session is already established, it
            // is a misbehavior (fatal error) by server.

            if (mState == kStateEstablishingSession)
            {
                Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
                error = kErrorNone;
            }

            ExitNow();
        }
    }

    // Parse and validate the Keep Alive Message

    SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
    offset += keepAliveTlv.GetSize();

    VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());

    // Keep Alive message MUST contain only one Keep Alive TLV.
    // Walk the remaining TLVs: none may be another Keep Alive or a
    // Retry Delay TLV, and the last one must end exactly at the end
    // of the message (checked right after the loop).

    while (offset < aMessage.GetLength())
    {
        SuccessOrExit(aMessage.Read(offset, tlv));
        offset += tlv.GetSize();

        VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
    }

    VerifyOrExit(offset == aMessage.GetLength());

    if (aHeader.GetType() == Dns::Header::kTypeQuery)
    {
        if (IsServer())
        {
            // Received a Keep Alive message from client. It MUST
            // be a request message (not unidirectional). We prepare
            // and send a Keep Alive response.

            VerifyOrExit(aHeader.GetMessageId() != 0);

            LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());

            // A failure to send the response is deliberately ignored; it
            // does not make the received request invalid.
            IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
            error = kErrorNone;
            ExitNow();
        }

        // Received a Keep Alive message on client from server. Server
        // Keep Alive message MUST be unidirectional (message ID
        // zero).

        VerifyOrExit(aHeader.GetMessageId() == 0);
    }

    // From here on this is the client handling a Keep Alive response or
    // unidirectional message from the server.

    LogInfo("Received Keep Alive %s message from server %s",
            (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());

    // Receiving a Keep Alive interval value from server less than the
    // minimum (ten seconds) is a fatal error and client MUST then
    // abort the connection.

    VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);

    // Update the timeout intervals on the connection from
    // the new values we got from the server. The receive
    // of the Keep Alive message does not itself reset the
    // inactivity timer. So we use `AdjustInactivityTimeout`
    // which takes into account the time elapsed since the
    // last activity.

    AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
    mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());

    LogInfo("Timeouts Inactivity:%lu, KeepAlive:%lu", ToUlong(mInactivity.GetInterval()),
            ToUlong(mKeepAlive.GetInterval()));

    error = kErrorNone;

exit:
    return error;
}
1002
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)1003 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
1004
1005 {
1006 Error error = kErrorAbort;
1007 RetryDelayTlv retryDelayTlv;
1008
1009 // Retry Delay TLV can be used as the Primary TLV only in
1010 // a unidirectional message sent from server to client.
1011 // It is used by the server to instruct the client to
1012 // close the session and its underlying connection, and not
1013 // to reconnect for the indicated time interval.
1014
1015 VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1016
1017 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1018 VerifyOrExit(retryDelayTlv.IsValid());
1019
1020 mRetryDelayErrorCode = aHeader.GetResponseCode();
1021 mRetryDelay = retryDelayTlv.GetRetryDelay();
1022
1023 LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1024 LogInfo(" RetryDelay:%lu ms, ResponseCode:%d", ToUlong(mRetryDelay), mRetryDelayErrorCode);
1025
1026 Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1027
1028 exit:
1029 return error;
1030 }
1031
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1032 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1033 {
1034 Message *response = NewMessage();
1035 Dns::Header header;
1036
1037 VerifyOrExit(response != nullptr);
1038
1039 header.SetMessageId(aHeader.GetMessageId());
1040 header.SetType(Dns::Header::kTypeResponse);
1041 header.SetQueryType(aHeader.GetQueryType());
1042 header.SetResponseCode(aResponseCode);
1043
1044 SuccessOrExit(response->Prepend(header));
1045
1046 otPlatDsoSend(this, response);
1047 response = nullptr;
1048
1049 exit:
1050 FreeMessage(response);
1051 }
1052
AdjustInactivityTimeout(uint32_t aNewTimeout)1053 void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
1054 {
1055 // This method sets the inactivity timeout interval to a new value
1056 // and updates the expiration time based on the new timeout value.
1057 //
1058 // On client, it is called on receiving a Keep Alive response or
1059 // unidirectional message from server. Note that the receive of
1060 // the Keep Alive message does not itself reset the inactivity
1061 // timer. So the time elapsed since the last activity should be
1062 // taken into account with the new inactivity timeout value.
1063 //
1064 // On server this method is called from `SetTimeouts()` when a new
1065 // inactivity timeout value is set.
1066
1067 TimeMilli now = TimerMilli::GetNow();
1068 TimeMilli start;
1069 TimeMilli newExpiration;
1070
1071 if (mState == kStateDisconnected)
1072 {
1073 mInactivity.SetInterval(aNewTimeout);
1074 ExitNow();
1075 }
1076
1077 VerifyOrExit(aNewTimeout != mInactivity.GetInterval());
1078
1079 // Calculate the start time (i.e., the last time inactivity timer
1080 // was cleared). If the previous inactivity time is set to
1081 // `kInfinite` value (`IsUsed()` returns `false`) then
1082 // `GetExpirationTime()` returns the start time. Otherwise, we
1083 // calculate it going back from the current expiration time with
1084 // the current wait interval.
1085
1086 if (!mInactivity.IsUsed())
1087 {
1088 start = mInactivity.GetExpirationTime();
1089 }
1090 else if (IsClient())
1091 {
1092 start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
1093 }
1094 else
1095 {
1096 start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
1097 }
1098
1099 mInactivity.SetInterval(aNewTimeout);
1100
1101 if (!mInactivity.IsUsed())
1102 {
1103 newExpiration = start;
1104 }
1105 else if (IsClient())
1106 {
1107 newExpiration = start + aNewTimeout;
1108
1109 if (newExpiration < now)
1110 {
1111 newExpiration = now;
1112 }
1113 }
1114 else
1115 {
1116 newExpiration = start + CalculateServerInactivityWaitTime();
1117
1118 if (newExpiration < now)
1119 {
1120 // If the server abruptly reduces the inactivity timeout
1121 // such that current elapsed time is already more than
1122 // twice the new inactivity timeout, then the client is
1123 // immediately considered delinquent (server can forcibly
1124 // abort the connection). So to give the client time to
1125 // close the connection gracefully, the server SHOULD
1126 // give the client an additional grace period of either
1127 // five seconds or one quarter of the new inactivity
1128 // timeout, whichever is greater [RFC 8490 - 7.1.1].
1129
1130 newExpiration = now + Max(kMinServerInactivityWaitTime, aNewTimeout / 4);
1131 }
1132 }
1133
1134 mInactivity.SetExpirationTime(newExpiration);
1135
1136 exit:
1137 return;
1138 }
1139
CalculateServerInactivityWaitTime(void) const1140 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1141 {
1142 // A server will abort an idle session after five seconds
1143 // (`kMinServerInactivityWaitTime`) or twice the inactivity
1144 // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1145
1146 OT_ASSERT(mInactivity.IsUsed());
1147
1148 return Max(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1149 }
1150
ResetTimeouts(bool aIsKeepAliveMessage)1151 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1152 {
1153 TimeMilli now = TimerMilli::GetNow();
1154 TimeMilli nextTime;
1155
1156 // At both servers and clients, the generation or reception of any
1157 // complete DNS message resets both timers for that DSO
1158 // session, with the one exception being that a DSO Keep Alive
1159 // message resets only the keep alive timer, not the inactivity
1160 // timeout timer [RFC 8490 - 6.3]
1161
1162 if (mKeepAlive.IsUsed())
1163 {
1164 // On client, we wait for the Keep Alive interval but on server
1165 // we wait for twice the interval before considering Keep Alive
1166 // timeout.
1167 //
1168 // Note that we limit the interval to `Timeout::kMaxInterval`
1169 // (which is ~12 days). This max limit ensures that even twice
1170 // the interval is less than max OpenThread timer duration so
1171 // that the expiration time calculations below stay within the
1172 // `TimerMilli` range.
1173
1174 mKeepAlive.SetExpirationTime(now + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1175 }
1176
1177 if (!aIsKeepAliveMessage)
1178 {
1179 if (mInactivity.IsUsed())
1180 {
1181 mInactivity.SetExpirationTime(
1182 now + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1183 }
1184 else
1185 {
1186 // When Inactivity timeout is not used (i.e., interval is set
1187 // to the special `kInfinite` value), we still need to track
1188 // the time so that if/when later the inactivity interval
1189 // gets changed, we can adjust the remaining time correctly
1190 // from `AdjustInactivityTimeout()`. In this case, we just
1191 // track the current time as "expiration time".
1192
1193 mInactivity.SetExpirationTime(now);
1194 }
1195 }
1196
1197 nextTime = GetNextFireTime(now);
1198
1199 if (nextTime != now.GetDistantFuture())
1200 {
1201 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1202 }
1203 }
1204
GetNextFireTime(TimeMilli aNow) const1205 TimeMilli Dso::Connection::GetNextFireTime(TimeMilli aNow) const
1206 {
1207 TimeMilli nextTime = aNow.GetDistantFuture();
1208
1209 switch (mState)
1210 {
1211 case kStateDisconnected:
1212 break;
1213
1214 case kStateConnecting:
1215 // While in `kStateConnecting`, Keep Alive timer is
1216 // used for `kConnectingTimeout`.
1217 VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1218 nextTime = mKeepAlive.GetExpirationTime();
1219 break;
1220
1221 case kStateConnectedButSessionless:
1222 case kStateEstablishingSession:
1223 case kStateSessionEstablished:
1224 nextTime = Min(nextTime, mPendingRequests.GetNextFireTime(aNow));
1225
1226 if (mKeepAlive.IsUsed())
1227 {
1228 VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1229 nextTime = Min(nextTime, mKeepAlive.GetExpirationTime());
1230 }
1231
1232 if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1233 {
1234 // An operation being active on a DSO Session includes
1235 // a request message waiting for a response, or an
1236 // active long-lived operation.
1237
1238 VerifyOrExit(mInactivity.GetExpirationTime() > aNow, nextTime = aNow);
1239 nextTime = Min(nextTime, mInactivity.GetExpirationTime());
1240 }
1241
1242 break;
1243 }
1244
1245 exit:
1246 return nextTime;
1247 }
1248
HandleTimer(TimeMilli aNow,TimeMilli & aNextTime)1249 void Dso::Connection::HandleTimer(TimeMilli aNow, TimeMilli &aNextTime)
1250 {
1251 switch (mState)
1252 {
1253 case kStateDisconnected:
1254 break;
1255
1256 case kStateConnecting:
1257 if (mKeepAlive.IsExpired(aNow))
1258 {
1259 Disconnect(kGracefullyClose, kReasonFailedToConnect);
1260 }
1261 break;
1262
1263 case kStateConnectedButSessionless:
1264 case kStateEstablishingSession:
1265 case kStateSessionEstablished:
1266 if (mPendingRequests.HasAnyTimedOut(aNow))
1267 {
1268 // If server sends no response to a request, client
1269 // waits for 30 seconds (`kResponseTimeout`) after which
1270 // client MUST forcibly abort the connection.
1271 Disconnect(kForciblyAbort, kReasonResponseTimeout);
1272 ExitNow();
1273 }
1274
1275 // The inactivity timer is kept clear, while an operation is
1276 // active on the session (which includes a request waiting for
1277 // response or an active long-lived operation).
1278
1279 if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation && mInactivity.IsExpired(aNow))
1280 {
1281 // On client, if the inactivity timeout is reached, the
1282 // connection is closed gracefully. On server, if too much
1283 // time (`CalculateServerInactivityWaitTime()`, i.e., five
1284 // seconds or twice the current inactivity timeout interval,
1285 // whichever is grater) elapses server MUST consider the
1286 // client delinquent and MUST forcibly abort the connection.
1287
1288 Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
1289 ExitNow();
1290 }
1291
1292 if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNow))
1293 {
1294 // On client, if the Keep Alive interval elapses without any
1295 // DNS messages being sent or received, the client MUST take
1296 // action and send a DSO Keep Alive message.
1297 //
1298 // On server, if twice the Keep Alive interval value elapses
1299 // without any messages being sent or received, the server
1300 // considers the client delinquent and aborts the connection.
1301
1302 if (IsClient())
1303 {
1304 IgnoreError(SendKeepAliveMessage());
1305 }
1306 else
1307 {
1308 Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
1309 ExitNow();
1310 }
1311 }
1312 break;
1313 }
1314
1315 exit:
1316 aNextTime = Min(aNextTime, GetNextFireTime(aNow));
1317 SignalAnyStateChange();
1318 }
1319
StateToString(State aState)1320 const char *Dso::Connection::StateToString(State aState)
1321 {
1322 static const char *const kStateStrings[] = {
1323 "Disconnected", // (0) kStateDisconnected,
1324 "Connecting", // (1) kStateConnecting,
1325 "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1326 "EstablishingSession", // (3) kStateEstablishingSession,
1327 "SessionEstablished", // (4) kStateSessionEstablished,
1328 };
1329
1330 static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1331 static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1332 static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1333 static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1334 static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1335
1336 return kStateStrings[aState];
1337 }
1338
MessageTypeToString(MessageType aMessageType)1339 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1340 {
1341 static const char *const kMessageTypeStrings[] = {
1342 "Request", // (0) kRequestMessage
1343 "Response", // (1) kResponseMessage
1344 "Unidirectional", // (2) kUnidirectionalMessage
1345 };
1346
1347 static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1348 static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1349 static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1350
1351 return kMessageTypeStrings[aMessageType];
1352 }
1353
DisconnectReasonToString(DisconnectReason aReason)1354 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1355 {
1356 static const char *const kDisconnectReasonStrings[] = {
1357 "FailedToConnect", // (0) kReasonFailedToConnect
1358 "ResponseTimeout", // (1) kReasonResponseTimeout
1359 "PeerDoesNotSupportDso", // (2) kReasonPeerDoesNotSupportDso
1360 "PeerClosed", // (3) kReasonPeerClosed
1361 "PeerAborted", // (4) kReasonPeerAborted
1362 "InactivityTimeout", // (5) kReasonInactivityTimeout
1363 "KeepAliveTimeout", // (6) kReasonKeepAliveTimeout
1364 "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1365 "PeerMisbehavior", // (8) kReasonPeerMisbehavior
1366 "Unknown", // (9) kReasonUnknown
1367 };
1368
1369 static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1370 static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1371 static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1372 static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1373 static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1374 static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1375 static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1376 static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1377 static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1378 static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1379
1380 return kDisconnectReasonStrings[aReason];
1381 }
1382
1383 //---------------------------------------------------------------------------------------------------------------------
1384 // Dso::Connection::PendingRequests
1385
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1386 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1387 {
1388 bool contains = true;
1389 const Entry *entry = mRequests.FindMatching(aMessageId);
1390
1391 VerifyOrExit(entry != nullptr, contains = false);
1392 aPrimaryTlvType = entry->mPrimaryTlvType;
1393
1394 exit:
1395 return contains;
1396 }
1397
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1398 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1399 {
1400 Error error = kErrorNone;
1401 Entry *entry = mRequests.PushBack();
1402
1403 VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1404 entry->mMessageId = aMessageId;
1405 entry->mPrimaryTlvType = aPrimaryTlvType;
1406 entry->mTimeout = aResponseTimeout;
1407
1408 exit:
1409 return error;
1410 }
1411
Remove(MessageId aMessageId)1412 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); }
1413
HasAnyTimedOut(TimeMilli aNow) const1414 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1415 {
1416 bool timedOut = false;
1417
1418 for (const Entry &entry : mRequests)
1419 {
1420 if (entry.mTimeout <= aNow)
1421 {
1422 timedOut = true;
1423 break;
1424 }
1425 }
1426
1427 return timedOut;
1428 }
1429
GetNextFireTime(TimeMilli aNow) const1430 TimeMilli Dso::Connection::PendingRequests::GetNextFireTime(TimeMilli aNow) const
1431 {
1432 TimeMilli nextTime = aNow.GetDistantFuture();
1433
1434 for (const Entry &entry : mRequests)
1435 {
1436 VerifyOrExit(entry.mTimeout > aNow, nextTime = aNow);
1437 nextTime = Min(entry.mTimeout, nextTime);
1438 }
1439
1440 exit:
1441 return nextTime;
1442 }
1443
1444 //---------------------------------------------------------------------------------------------------------------------
1445 // Dso
1446
Dso(Instance & aInstance)1447 Dso::Dso(Instance &aInstance)
1448 : InstanceLocator(aInstance)
1449 , mAcceptHandler(nullptr)
1450 , mTimer(aInstance)
1451 {
1452 }
1453
StartListening(AcceptHandler aAcceptHandler)1454 void Dso::StartListening(AcceptHandler aAcceptHandler)
1455 {
1456 mAcceptHandler = aAcceptHandler;
1457 otPlatDsoEnableListening(&GetInstance(), true);
1458 }
1459
StopListening(void)1460 void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); }
1461
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1462 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1463 {
1464 return mClientConnections.FindMatching(aPeerSockAddr);
1465 }
1466
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1467 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1468 {
1469 return mServerConnections.FindMatching(aPeerSockAddr);
1470 }
1471
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1472 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1473 {
1474 Connection *connection = nullptr;
1475
1476 VerifyOrExit(mAcceptHandler != nullptr);
1477 connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1478
1479 VerifyOrExit(connection != nullptr);
1480 connection->Accept();
1481
1482 exit:
1483 return connection;
1484 }
1485
HandleTimer(void)1486 void Dso::HandleTimer(void)
1487 {
1488 TimeMilli now = TimerMilli::GetNow();
1489 TimeMilli nextTime = now.GetDistantFuture();
1490 Connection *conn;
1491 Connection *next;
1492
1493 for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1494 {
1495 next = conn->GetNext();
1496 conn->HandleTimer(now, nextTime);
1497 }
1498
1499 for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1500 {
1501 next = conn->GetNext();
1502 conn->HandleTimer(now, nextTime);
1503 }
1504
1505 if (nextTime != now.GetDistantFuture())
1506 {
1507 mTimer.FireAtIfEarlier(nextTime);
1508 }
1509 }
1510
1511 } // namespace Dns
1512 } // namespace ot
1513
1514 #if OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1515
otPlatDsoEnableListening(otInstance * aInstance,bool aEnable)1516 OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable)
1517 {
1518 OT_UNUSED_VARIABLE(aInstance);
1519 OT_UNUSED_VARIABLE(aEnable);
1520 }
1521
otPlatDsoConnect(otPlatDsoConnection * aConnection,const otSockAddr * aPeerSockAddr)1522 OT_TOOL_WEAK void otPlatDsoConnect(otPlatDsoConnection *aConnection, const otSockAddr *aPeerSockAddr)
1523 {
1524 OT_UNUSED_VARIABLE(aConnection);
1525 OT_UNUSED_VARIABLE(aPeerSockAddr);
1526 }
1527
otPlatDsoSend(otPlatDsoConnection * aConnection,otMessage * aMessage)1528 OT_TOOL_WEAK void otPlatDsoSend(otPlatDsoConnection *aConnection, otMessage *aMessage)
1529 {
1530 OT_UNUSED_VARIABLE(aConnection);
1531 OT_UNUSED_VARIABLE(aMessage);
1532 }
1533
otPlatDsoDisconnect(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)1534 OT_TOOL_WEAK void otPlatDsoDisconnect(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
1535 {
1536 OT_UNUSED_VARIABLE(aConnection);
1537 OT_UNUSED_VARIABLE(aMode);
1538 }
1539
1540 #endif // OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1541
1542 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1543