1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_dso.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "common/num_utils.hpp"
41 #include "common/random.hpp"
42 
43 /**
44  * @file
45  *   This file implements the DNS Stateful Operations (DSO) per RFC 8490.
46  */
47 
48 namespace ot {
49 namespace Dns {
50 
// Registers "DnsDso" as the log module name for this file; the `LogInfo()` calls
// below presumably tag their output with it (confirm against common/log.hpp).
RegisterLogModule("DnsDso");
52 
53 //---------------------------------------------------------------------------------------------------------------------
54 // otPlatDso transport callbacks
55 
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)56 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
57 {
58     return &AsCoreType(aConnection).GetInstance();
59 }
60 
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)61 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
62 {
63     return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
64 }
65 
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)66 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
67 {
68     AsCoreType(aConnection).HandleConnected();
69 }
70 
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)71 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
72 {
73     AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
74 }
75 
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)76 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
77 {
78     AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
79 }
80 
81 //---------------------------------------------------------------------------------------------------------------------
82 // Dso::Connection
83 
/**
 * Constructs a DSO connection object for a given peer.
 *
 * The object starts in `kStateDisconnected` with `mNext` cleared. `Init()` then
 * resets the per-connection bookkeeping (message ID counter, flags, retry delay
 * info, disconnect reason) for the client role. `Connect()` or `Accept()`
 * re-initializes for the actual role later.
 *
 * @param[in] aInstance           The OpenThread instance.
 * @param[in] aPeerSockAddr       The peer's socket address.
 * @param[in] aCallbacks          Callbacks used to report state changes and deliver received messages.
 * @param[in] aInactivityTimeout  Initial inactivity timeout value.
 * @param[in] aKeepAliveInterval  Initial keep alive interval. MUST be at least `kMinKeepAliveInterval`.
 */
Dso::Connection::Connection(Instance            &aInstance,
                            const Ip6::SockAddr &aPeerSockAddr,
                            Callbacks           &aCallbacks,
                            uint32_t             aInactivityTimeout,
                            uint32_t             aKeepAliveInterval)
    : InstanceLocator(aInstance)
    , mNext(nullptr)
    , mCallbacks(aCallbacks)
    , mPeerSockAddr(aPeerSockAddr)
    , mState(kStateDisconnected)
    , mIsServer(false)
    , mInactivity(aInactivityTimeout)
    , mKeepAlive(aKeepAliveInterval)
{
    // Same lower bound that `SetTimeouts()` enforces at runtime.
    OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
    Init(/* aIsServer */ false);
}
101 
Init(bool aIsServer)102 void Dso::Connection::Init(bool aIsServer)
103 {
104     mNextMessageId       = 1;
105     mIsServer            = aIsServer;
106     mStateDidChange      = false;
107     mLongLivedOperation  = false;
108     mRetryDelay          = 0;
109     mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
110     mDisconnectReason    = kReasonUnknown;
111 }
112 
SetState(State aState)113 void Dso::Connection::SetState(State aState)
114 {
115     VerifyOrExit(mState != aState);
116 
117     LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
118             mPeerSockAddr.ToString().AsCString());
119 
120     mState          = aState;
121     mStateDidChange = true;
122 
123 exit:
124     return;
125 }
126 
SignalAnyStateChange(void)127 void Dso::Connection::SignalAnyStateChange(void)
128 {
129     VerifyOrExit(mStateDidChange);
130     mStateDidChange = false;
131 
132     switch (mState)
133     {
134     case kStateDisconnected:
135         mCallbacks.mHandleDisconnected(*this);
136         break;
137 
138     case kStateConnectedButSessionless:
139         mCallbacks.mHandleConnected(*this);
140         break;
141 
142     case kStateSessionEstablished:
143         mCallbacks.mHandleSessionEstablished(*this);
144         break;
145 
146     case kStateConnecting:
147     case kStateEstablishingSession:
148         break;
149     };
150 
151 exit:
152     return;
153 }
154 
NewMessage(void)155 Message *Dso::Connection::NewMessage(void)
156 {
157     return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
158                                        Message::Settings(Message::kPriorityNormal));
159 }
160 
Connect(void)161 void Dso::Connection::Connect(void)
162 {
163     OT_ASSERT(mState == kStateDisconnected);
164 
165     Init(/* aIsServer */ false);
166     Get<Dso>().mClientConnections.Push(*this);
167     MarkAsConnecting();
168     otPlatDsoConnect(this, &mPeerSockAddr);
169 }
170 
Accept(void)171 void Dso::Connection::Accept(void)
172 {
173     OT_ASSERT(mState == kStateDisconnected);
174 
175     Init(/* aIsServer */ true);
176     Get<Dso>().mServerConnections.Push(*this);
177     MarkAsConnecting();
178 }
179 
MarkAsConnecting(void)180 void Dso::Connection::MarkAsConnecting(void)
181 {
182     SetState(kStateConnecting);
183 
184     // While in `kStateConnecting` state we use the `mKeepAlive` to
185     // track the `kConnectingTimeout` (if connection is not established
186     // within the timeout, we consider it as failure and close it).
187 
188     mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
189     Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
190 
191     // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
192     // or timeout.
193 }
194 
HandleConnected(void)195 void Dso::Connection::HandleConnected(void)
196 {
197     OT_ASSERT(mState == kStateConnecting);
198 
199     SetState(kStateConnectedButSessionless);
200     ResetTimeouts(/* aIsKeepAliveMessage */ false);
201 
202     SignalAnyStateChange();
203 }
204 
Disconnect(DisconnectMode aMode,DisconnectReason aReason)205 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
206 {
207     VerifyOrExit(mState != kStateDisconnected);
208 
209     mDisconnectReason = aReason;
210     MarkAsDisconnected();
211 
212     otPlatDsoDisconnect(this, MapEnum(aMode));
213 
214 exit:
215     return;
216 }
217 
HandleDisconnected(DisconnectMode aMode)218 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
219 {
220     VerifyOrExit(mState != kStateDisconnected);
221 
222     if (mState == kStateConnecting)
223     {
224         mDisconnectReason = kReasonFailedToConnect;
225     }
226     else
227     {
228         switch (aMode)
229         {
230         case kGracefullyClose:
231             mDisconnectReason = kReasonPeerClosed;
232             break;
233 
234         case kForciblyAbort:
235             mDisconnectReason = kReasonPeerAborted;
236         }
237     }
238 
239     MarkAsDisconnected();
240     SignalAnyStateChange();
241 
242 exit:
243     return;
244 }
245 
MarkAsDisconnected(void)246 void Dso::Connection::MarkAsDisconnected(void)
247 {
248     if (IsClient())
249     {
250         IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
251     }
252     else
253     {
254         IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
255     }
256 
257     mPendingRequests.Clear();
258     SetState(kStateDisconnected);
259 
260     LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
261 }
262 
MarkSessionEstablished(void)263 void Dso::Connection::MarkSessionEstablished(void)
264 {
265     switch (mState)
266     {
267     case kStateConnectedButSessionless:
268     case kStateEstablishingSession:
269     case kStateSessionEstablished:
270         break;
271 
272     case kStateDisconnected:
273     case kStateConnecting:
274         OT_ASSERT(false);
275     }
276 
277     SetState(kStateSessionEstablished);
278 }
279 
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)280 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
281 {
282     return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
283 }
284 
SendUnidirectionalMessage(Message & aMessage)285 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
286 {
287     MessageId messageId = 0;
288 
289     return SendMessage(aMessage, kUnidirectionalMessage, messageId);
290 }
291 
SendResponseMessage(Message & aMessage,MessageId aResponseId)292 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
293 {
294     return SendMessage(aMessage, kResponseMessage, aResponseId);
295 }
296 
SetLongLivedOperation(bool aLongLivedOperation)297 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
298 {
299     VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
300 
301     mLongLivedOperation = aLongLivedOperation;
302 
303     LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
304 
305     if (!mLongLivedOperation)
306     {
307         TimeMilli now = TimerMilli::GetNow();
308         TimeMilli nextTime;
309 
310         nextTime = GetNextFireTime(now);
311 
312         if (nextTime != now.GetDistantFuture())
313         {
314             Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
315         }
316     }
317 
318 exit:
319     return;
320 }
321 
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)322 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
323 {
324     Error         error   = kErrorNone;
325     Message      *message = nullptr;
326     RetryDelayTlv retryDelayTlv;
327     MessageId     messageId;
328 
329     switch (mState)
330     {
331     case kStateSessionEstablished:
332         OT_ASSERT(IsServer());
333         break;
334 
335     case kStateConnectedButSessionless:
336     case kStateEstablishingSession:
337     case kStateDisconnected:
338     case kStateConnecting:
339         OT_ASSERT(false);
340     }
341 
342     message = NewMessage();
343     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
344 
345     retryDelayTlv.Init();
346     retryDelayTlv.SetRetryDelay(aDelay);
347     SuccessOrExit(error = message->Append(retryDelayTlv));
348     error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
349 
350 exit:
351     FreeMessageOnError(message, error);
352     return error;
353 }
354 
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)355 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
356 {
357     Error error = kErrorNone;
358 
359     VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
360 
361     // If acting as server, the timeout values are the ones we grant
362     // to a connecting clients. If acting as client, the timeout
363     // values are what to request when sending Keep Alive message.
364     // If in `kStateDisconnected` we set both (since we don't know
365     // yet whether we are going to connect as client or server).
366 
367     if ((mState == kStateDisconnected) || IsServer())
368     {
369         mKeepAlive.SetInterval(aKeepAliveInterval);
370         AdjustInactivityTimeout(aInactivityTimeout);
371     }
372 
373     if ((mState == kStateDisconnected) || IsClient())
374     {
375         mKeepAlive.SetRequestInterval(aKeepAliveInterval);
376         mInactivity.SetRequestInterval(aInactivityTimeout);
377     }
378 
379     switch (mState)
380     {
381     case kStateDisconnected:
382     case kStateConnecting:
383         break;
384 
385     case kStateConnectedButSessionless:
386     case kStateEstablishingSession:
387         if (IsServer())
388         {
389             break;
390         }
391 
392         OT_FALL_THROUGH;
393 
394     case kStateSessionEstablished:
395         error = SendKeepAliveMessage();
396     }
397 
398 exit:
399     return error;
400 }
401 
SendKeepAliveMessage(void)402 Error Dso::Connection::SendKeepAliveMessage(void)
403 {
404     return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
405 }
406 
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)407 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
408 {
409     // Sends a Keep Alive message of a given type. This is a common
410     // method used by both client and server. `aResponseId` is
411     // applicable and used only when the message type is
412     // `kResponseMessage`.
413 
414     Error        error   = kErrorNone;
415     Message     *message = nullptr;
416     KeepAliveTlv keepAliveTlv;
417 
418     switch (mState)
419     {
420     case kStateConnectedButSessionless:
421     case kStateEstablishingSession:
422         if (IsServer())
423         {
424             // While session is being established, server is only allowed
425             // to send a Keep Alive response to a request from client.
426             OT_ASSERT(aMessageType == kResponseMessage);
427         }
428         break;
429 
430     case kStateSessionEstablished:
431         break;
432 
433     case kStateDisconnected:
434     case kStateConnecting:
435         OT_ASSERT(false);
436     }
437 
438     // Server can send Keep Alive response (to a request from client)
439     // or a unidirectional Keep Alive message. Client can send
440     // KeepAlive request message.
441 
442     if (IsServer())
443     {
444         if (aMessageType == kResponseMessage)
445         {
446             OT_ASSERT(aResponseId != 0);
447         }
448         else
449         {
450             OT_ASSERT(aMessageType == kUnidirectionalMessage);
451         }
452     }
453     else
454     {
455         OT_ASSERT(aMessageType == kRequestMessage);
456     }
457 
458     message = NewMessage();
459     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
460 
461     keepAliveTlv.Init();
462 
463     if (IsServer())
464     {
465         keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
466         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
467     }
468     else
469     {
470         keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
471         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
472     }
473 
474     SuccessOrExit(error = message->Append(keepAliveTlv));
475 
476     error = SendMessage(*message, aMessageType, aResponseId);
477 
478 exit:
479     FreeMessageOnError(message, error);
480     return error;
481 }
482 
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)483 Error Dso::Connection::SendMessage(Message              &aMessage,
484                                    MessageType           aMessageType,
485                                    MessageId            &aMessageId,
486                                    Dns::Header::Response aResponseCode,
487                                    uint32_t              aResponseTimeout)
488 {
489     Error       error          = kErrorNone;
490     Tlv::Type   primaryTlvType = Tlv::kReservedType;
491     Dns::Header header;
492 
493     switch (mState)
494     {
495     case kStateConnectedButSessionless:
496         // To establish session, client MUST send a request message.
497         // Server is not allowed to send any messages. Unidirectional
498         // messages are not allowed before session is established.
499         OT_ASSERT(IsClient());
500         OT_ASSERT(aMessageType == kRequestMessage);
501         break;
502 
503     case kStateEstablishingSession:
504         // During session establishment, client is allowed to send
505         // additional request messages, server is only allowed to
506         // send response.
507         if (IsClient())
508         {
509             OT_ASSERT(aMessageType == kRequestMessage);
510         }
511         else
512         {
513             OT_ASSERT(aMessageType == kResponseMessage);
514         }
515         break;
516 
517     case kStateSessionEstablished:
518         // All message types are allowed.
519         break;
520 
521     case kStateDisconnected:
522     case kStateConnecting:
523         OT_ASSERT(false);
524     }
525 
526     // A DSO request or unidirectional message MUST contain at
527     // least one TLV. The first TLV is the "Primary TLV" and
528     // determines the nature of the operation being performed.
529     // A DSO response message may contain no TLVs, or may contain
530     // one or more TLVs. Response Primary TLV(s) MUST appear first
531     // in a DSO response message.
532 
533     aMessage.SetOffset(0);
534     IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
535 
536     switch (aMessageType)
537     {
538     case kResponseMessage:
539         break;
540     case kRequestMessage:
541     case kUnidirectionalMessage:
542         OT_ASSERT(primaryTlvType != Tlv::kReservedType);
543     }
544 
545     // `header` is cleared from its constructor call so all fields
546     // start as zero.
547 
548     switch (aMessageType)
549     {
550     case kRequestMessage:
551         header.SetType(Dns::Header::kTypeQuery);
552         aMessageId = mNextMessageId;
553         break;
554 
555     case kResponseMessage:
556         header.SetType(Dns::Header::kTypeResponse);
557         break;
558 
559     case kUnidirectionalMessage:
560         header.SetType(Dns::Header::kTypeQuery);
561         aMessageId = 0;
562         break;
563     }
564 
565     header.SetMessageId(aMessageId);
566     header.SetQueryType(Dns::Header::kQueryTypeDso);
567     header.SetResponseCode(aResponseCode);
568     SuccessOrExit(error = aMessage.Prepend(header));
569 
570     SuccessOrExit(error = AppendPadding(aMessage));
571 
572     // Update `mPendingRequests` list with the new request info
573 
574     if (aMessageType == kRequestMessage)
575     {
576         SuccessOrExit(
577             error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
578 
579         if (++mNextMessageId == 0)
580         {
581             mNextMessageId = 1;
582         }
583     }
584 
585     LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
586             mPeerSockAddr.ToString().AsCString());
587 
588     switch (mState)
589     {
590     case kStateConnectedButSessionless:
591         // On client we transition from "connected" state to
592         // "establishing session" state on successfully sending a
593         // request message.
594         if (IsClient())
595         {
596             SetState(kStateEstablishingSession);
597         }
598         break;
599 
600     case kStateEstablishingSession:
601         // On server we transition from "establishing session" state
602         // to "established" on sending a response with success
603         // response code.
604         if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
605         {
606             SetState(kStateSessionEstablished);
607         }
608 
609     default:
610         break;
611     }
612 
613     ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
614 
615     otPlatDsoSend(this, &aMessage);
616 
617     // Signal any state changes. This is done at the very end when the
618     // `SendMessage()` is fully processed (all state and local
619     // variables are updated) to ensure that we do not have any
620     // reentrancy issues (e.g., if the callback signalling state
621     // change triggers another tx).
622 
623     SignalAnyStateChange();
624 
625 exit:
626     return error;
627 }
628 
AppendPadding(Message & aMessage)629 Error Dso::Connection::AppendPadding(Message &aMessage)
630 {
631     // This method appends Encryption Padding TLV to a DSO message.
632     // It uses the padding policy "Random-Block-Length Padding" from
633     // RFC 8467.
634 
635     static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
636 
637     Error                error = kErrorNone;
638     uint16_t             blockLength;
639     EncryptionPaddingTlv paddingTlv;
640 
641     // We pick a random block length. The random selection can be
642     // based on a "weak" source of randomness (so the use of
643     // `NonCrypto` is fine). We add padding to the message such
644     // that its padded length is a multiple of the chosen block
645     // length.
646 
647     blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
648 
649     paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
650 
651     SuccessOrExit(error = aMessage.Append(paddingTlv));
652 
653     for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
654     {
655         SuccessOrExit(error = aMessage.Append<uint8_t>(0));
656     }
657 
658 exit:
659     return error;
660 }
661 
HandleReceive(Message & aMessage)662 void Dso::Connection::HandleReceive(Message &aMessage)
663 {
664     Error       error          = kErrorAbort;
665     Tlv::Type   primaryTlvType = Tlv::kReservedType;
666     Dns::Header header;
667 
668     SuccessOrExit(aMessage.Read(0, header));
669 
670     if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
671     {
672         if (header.GetType() == Dns::Header::kTypeQuery)
673         {
674             SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
675             error = kErrorNone;
676         }
677 
678         ExitNow();
679     }
680 
681     switch (mState)
682     {
683     case kStateConnectedButSessionless:
684         // After connection is established, client should initiate
685         // establishing session (by sending a request). So no rx is
686         // allowed before this. On server, we allow rx of a request
687         // message only.
688         VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
689         break;
690 
691     case kStateEstablishingSession:
692         // Unidirectional message are allowed after session is
693         // established. While session is being established, on client,
694         // we allow rx on response message. On server we can rx
695         // request or response.
696 
697         VerifyOrExit(header.GetMessageId() != 0);
698 
699         if (IsClient())
700         {
701             VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
702         }
703         break;
704 
705     case kStateSessionEstablished:
706         // All message types are allowed.
707         break;
708 
709     case kStateDisconnected:
710     case kStateConnecting:
711         ExitNow();
712     }
713 
714     // All count fields MUST be set to zero in the header.
715     VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
716                  (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));
717 
718     aMessage.SetOffset(sizeof(header));
719 
720     switch (ReadPrimaryTlv(aMessage, primaryTlvType))
721     {
722     case kErrorNone:
723         VerifyOrExit(primaryTlvType != Tlv::kReservedType);
724         break;
725 
726     case kErrorNotFound:
727         // The `primaryTlvType` is set to `Tlv::kReservedType`
728         // (value zero) to indicate that there is no primary TLV.
729         break;
730 
731     default:
732         ExitNow();
733     }
734 
735     switch (header.GetType())
736     {
737     case Dns::Header::kTypeQuery:
738         error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
739         break;
740 
741     case Dns::Header::kTypeResponse:
742         error = ProcessResponseMessage(header, aMessage, primaryTlvType);
743         break;
744     }
745 
746 exit:
747     aMessage.Free();
748 
749     if (error == kErrorNone)
750     {
751         ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
752     }
753     else
754     {
755         Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
756     }
757 
758     // We signal any state change at the very end when the received
759     // message is fully processed (all state and local variables are
760     // updated) to ensure that we do not have any reentrancy issues
761     // (e.g., if a `Connection` method happens to be called from the
762     // callback).
763 
764     SignalAnyStateChange();
765 }
766 
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const767 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
768 {
769     // Read and validate the primary TLV (first TLV  after the header).
770     // The `aMessage.GetOffset()` must point to the first TLV. If no
771     // TLV then `kErrorNotFound` is returned. If TLV in message is not
772     // well-formed `kErrorParse` is returned. The read TLV type is
773     // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
774     // (value zero) when `kErrorNotFound`).
775 
776     Error error = kErrorNotFound;
777     Tlv   tlv;
778 
779     aPrimaryTlvType = Tlv::kReservedType;
780 
781     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
782     VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
783     aPrimaryTlvType = tlv.GetType();
784     error           = kErrorNone;
785 
786 exit:
787     return error;
788 }
789 
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)790 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
791                                                              const Message     &aMessage,
792                                                              Tlv::Type          aPrimaryTlvType)
793 {
794     Error error = kErrorAbort;
795 
796     if (IsServer() && (mState == kStateConnectedButSessionless))
797     {
798         SetState(kStateEstablishingSession);
799     }
800 
801     // A DSO request or unidirectional message MUST contain at
802     // least one TLV which is the "Primary TLV" and determines
803     // the nature of the operation being performed.
804 
805     switch (aPrimaryTlvType)
806     {
807     case KeepAliveTlv::kType:
808         error = ProcessKeepAliveMessage(aHeader, aMessage);
809         break;
810 
811     case RetryDelayTlv::kType:
812         error = ProcessRetryDelayMessage(aHeader, aMessage);
813         break;
814 
815     case Tlv::kReservedType:
816     case EncryptionPaddingTlv::kType:
817         // Misbehavior by peer.
818         break;
819 
820     default:
821         if (aHeader.GetMessageId() == 0)
822         {
823             LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
824 
825             error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
826         }
827         else
828         {
829             MessageId messageId = aHeader.GetMessageId();
830 
831             LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
832 
833             error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
834 
835             // `kErrorNotFound` indicates that TLV type is not known.
836 
837             if (error == kErrorNotFound)
838             {
839                 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
840                 error = kErrorNone;
841             }
842         }
843         break;
844     }
845 
846     return error;
847 }
848 
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)849 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
850                                               const Message     &aMessage,
851                                               Tlv::Type          aPrimaryTlvType)
852 {
853     Error     error = kErrorAbort;
854     Tlv::Type requestPrimaryTlvType;
855 
856     // If a client or server receives a response where the message
857     // ID is zero, or is any other value that does not match the
858     // message ID of any of its outstanding operations, this is a
859     // fatal error and the recipient MUST forcibly abort the
860     // connection immediately.
861 
862     VerifyOrExit(aHeader.GetMessageId() != 0);
863     VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
864 
865     // If the response has no error and contains a primary TLV, it
866     // MUST match the request primary TLV.
867 
868     if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
869     {
870         VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
871     }
872 
873     mPendingRequests.Remove(aHeader.GetMessageId());
874 
875     switch (requestPrimaryTlvType)
876     {
877     case KeepAliveTlv::kType:
878         SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
879         break;
880 
881     default:
882         SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
883                                                                  requestPrimaryTlvType));
884         break;
885     }
886 
887     // DSO session is established when client sends a request message
888     // and receives a response from server with no error code.
889 
890     if (IsClient() && (mState == kStateEstablishingSession) &&
891         (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
892     {
893         SetState(kStateSessionEstablished);
894     }
895 
896 exit:
897     return error;
898 }
899 
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)900 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
901 {
902     Error        error  = kErrorAbort;
903     uint16_t     offset = aMessage.GetOffset();
904     Tlv          tlv;
905     KeepAliveTlv keepAliveTlv;
906 
907     if (aHeader.GetType() == Dns::Header::kTypeResponse)
908     {
909         // A Keep Alive response message is allowed on a client from a sever.
910 
911         VerifyOrExit(IsClient());
912 
913         if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
914         {
915             // We got an error response code from server for our
916             // Keep Alive request message. If this happens while
917             // establishing the DSO session, it indicates that server
918             // does not support DSO, so we close the connection. If
919             // this happens while session is already established, it
920             // is a misbehavior (fatal error) by server.
921 
922             if (mState == kStateEstablishingSession)
923             {
924                 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
925                 error = kErrorNone;
926             }
927 
928             ExitNow();
929         }
930     }
931 
932     // Parse and validate the Keep Alive Message
933 
934     SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
935     offset += keepAliveTlv.GetSize();
936 
937     VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
938 
939     // Keep Alive message MUST contain only one Keep Alive TLV.
940 
941     while (offset < aMessage.GetLength())
942     {
943         SuccessOrExit(aMessage.Read(offset, tlv));
944         offset += tlv.GetSize();
945 
946         VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
947     }
948 
949     VerifyOrExit(offset == aMessage.GetLength());
950 
951     if (aHeader.GetType() == Dns::Header::kTypeQuery)
952     {
953         if (IsServer())
954         {
955             // Received a Keep Alive message from client. It MUST
956             // be a request message (not unidirectional). We prepare
957             // and send a Keep Alive response.
958 
959             VerifyOrExit(aHeader.GetMessageId() != 0);
960 
961             LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
962 
963             IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
964             error = kErrorNone;
965             ExitNow();
966         }
967 
968         // Received a Keep Alive message on client from server. Server
969         // Keep Alive message MUST be unidirectional (message ID
970         // zero).
971 
972         VerifyOrExit(aHeader.GetMessageId() == 0);
973     }
974 
975     LogInfo("Received Keep Alive %s message from server %s",
976             (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
977 
978     // Receiving a Keep Alive interval value from server less than the
979     // minimum (ten seconds) is a fatal error and client MUST then
980     // abort the connection.
981 
982     VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
983 
984     // Update the timeout intervals on the connection from
985     // the new values we got from the server. The receive
986     // of the Keep Alive message does not itself reset the
987     // inactivity timer. So we use `AdjustInactivityTimeout`
988     // which takes into account the time elapsed since the
989     // last activity.
990 
991     AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
992     mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
993 
994     LogInfo("Timeouts Inactivity:%lu, KeepAlive:%lu", ToUlong(mInactivity.GetInterval()),
995             ToUlong(mKeepAlive.GetInterval()));
996 
997     error = kErrorNone;
998 
999 exit:
1000     return error;
1001 }
1002 
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)1003 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
1004 
1005 {
1006     Error         error = kErrorAbort;
1007     RetryDelayTlv retryDelayTlv;
1008 
1009     // Retry Delay TLV can be used as the Primary TLV only in
1010     // a unidirectional message sent from server to client.
1011     // It is used by the server to instruct the client to
1012     // close the session and its underlying connection, and not
1013     // to reconnect for the indicated time interval.
1014 
1015     VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1016 
1017     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1018     VerifyOrExit(retryDelayTlv.IsValid());
1019 
1020     mRetryDelayErrorCode = aHeader.GetResponseCode();
1021     mRetryDelay          = retryDelayTlv.GetRetryDelay();
1022 
1023     LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1024     LogInfo("   RetryDelay:%lu ms, ResponseCode:%d", ToUlong(mRetryDelay), mRetryDelayErrorCode);
1025 
1026     Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1027 
1028 exit:
1029     return error;
1030 }
1031 
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1032 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1033 {
1034     Message    *response = NewMessage();
1035     Dns::Header header;
1036 
1037     VerifyOrExit(response != nullptr);
1038 
1039     header.SetMessageId(aHeader.GetMessageId());
1040     header.SetType(Dns::Header::kTypeResponse);
1041     header.SetQueryType(aHeader.GetQueryType());
1042     header.SetResponseCode(aResponseCode);
1043 
1044     SuccessOrExit(response->Prepend(header));
1045 
1046     otPlatDsoSend(this, response);
1047     response = nullptr;
1048 
1049 exit:
1050     FreeMessage(response);
1051 }
1052 
void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
{
    // This method sets the inactivity timeout interval to a new value
    // and updates the expiration time based on the new timeout value.
    //
    // On client, it is called on receiving a Keep Alive response or
    // unidirectional message from server. Note that the receive of
    // the Keep Alive message does not itself reset the inactivity
    // timer. So the time elapsed since the last activity should be
    // taken into account with the new inactivity timeout value.
    //
    // On server this method is called from `SetTimeouts()` when a new
    // inactivity timeout value is set.
    //
    // The statement order below matters. `start` must be derived BEFORE
    // `mInactivity.SetInterval(aNewTimeout)`, because both
    // `mInactivity.GetInterval()` and `CalculateServerInactivityWaitTime()`
    // read the current (old) interval. After `SetInterval()`, those same
    // calls reflect the new interval.

    TimeMilli now = TimerMilli::GetNow();
    TimeMilli start;
    TimeMilli newExpiration;

    // While disconnected there is no running inactivity period to adjust;
    // only the interval is recorded.
    if (mState == kStateDisconnected)
    {
        mInactivity.SetInterval(aNewTimeout);
        ExitNow();
    }

    // Nothing to do if the interval is unchanged.
    VerifyOrExit(aNewTimeout != mInactivity.GetInterval());

    // Calculate the start time (i.e., the last time inactivity timer
    // was cleared). If the previous inactivity time is set to
    // `kInfinite` value (`IsUsed()` returns `false`) then
    // `GetExpirationTime()` returns the start time. Otherwise, we
    // calculate it going back from the current expiration time with
    // the current wait interval.

    if (!mInactivity.IsUsed())
    {
        start = mInactivity.GetExpirationTime();
    }
    else if (IsClient())
    {
        start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
    }
    else
    {
        start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
    }

    mInactivity.SetInterval(aNewTimeout);

    // From here on `IsUsed()`/`GetInterval()` reflect `aNewTimeout`.
    if (!mInactivity.IsUsed())
    {
        // New interval is `kInfinite`: keep tracking the start time as the
        // "expiration time".
        newExpiration = start;
    }
    else if (IsClient())
    {
        newExpiration = start + aNewTimeout;

        // If the new timeout has already elapsed, expire right away.
        if (newExpiration < now)
        {
            newExpiration = now;
        }
    }
    else
    {
        newExpiration = start + CalculateServerInactivityWaitTime();

        if (newExpiration < now)
        {
            // If the server abruptly reduces the inactivity timeout
            // such that current elapsed time is already more than
            // twice the new inactivity timeout, then the client is
            // immediately considered delinquent (server can forcibly
            // abort the connection). So to give the client time to
            // close the connection gracefully, the server SHOULD
            // give the client an additional grace period of either
            // five seconds or one quarter of the new inactivity
            // timeout, whichever is greater [RFC 8490 - 7.1.1].

            newExpiration = now + Max(kMinServerInactivityWaitTime, aNewTimeout / 4);
        }
    }

    mInactivity.SetExpirationTime(newExpiration);

exit:
    return;
}
1139 
CalculateServerInactivityWaitTime(void) const1140 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1141 {
1142     // A server will abort an idle session after five seconds
1143     // (`kMinServerInactivityWaitTime`) or twice the inactivity
1144     // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1145 
1146     OT_ASSERT(mInactivity.IsUsed());
1147 
1148     return Max(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1149 }
1150 
ResetTimeouts(bool aIsKeepAliveMessage)1151 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1152 {
1153     TimeMilli now = TimerMilli::GetNow();
1154     TimeMilli nextTime;
1155 
1156     // At both servers and clients, the generation or reception of any
1157     // complete DNS message resets both timers for that DSO
1158     // session, with the one exception being that a DSO Keep Alive
1159     // message resets only the keep alive timer, not the inactivity
1160     // timeout timer [RFC 8490 - 6.3]
1161 
1162     if (mKeepAlive.IsUsed())
1163     {
1164         // On client, we wait for the Keep Alive interval but on server
1165         // we wait for twice the interval before considering Keep Alive
1166         // timeout.
1167         //
1168         // Note that we limit the interval to `Timeout::kMaxInterval`
1169         // (which is ~12 days). This max limit ensures that even twice
1170         // the interval is less than max OpenThread timer duration so
1171         // that the expiration time calculations below stay within the
1172         // `TimerMilli` range.
1173 
1174         mKeepAlive.SetExpirationTime(now + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1175     }
1176 
1177     if (!aIsKeepAliveMessage)
1178     {
1179         if (mInactivity.IsUsed())
1180         {
1181             mInactivity.SetExpirationTime(
1182                 now + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1183         }
1184         else
1185         {
1186             // When Inactivity timeout is not used (i.e., interval is set
1187             // to the special `kInfinite` value), we still need to track
1188             // the time so that if/when later the inactivity interval
1189             // gets changed, we can adjust the remaining time correctly
1190             // from `AdjustInactivityTimeout()`. In this case, we just
1191             // track the current time as "expiration time".
1192 
1193             mInactivity.SetExpirationTime(now);
1194         }
1195     }
1196 
1197     nextTime = GetNextFireTime(now);
1198 
1199     if (nextTime != now.GetDistantFuture())
1200     {
1201         Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1202     }
1203 }
1204 
GetNextFireTime(TimeMilli aNow) const1205 TimeMilli Dso::Connection::GetNextFireTime(TimeMilli aNow) const
1206 {
1207     TimeMilli nextTime = aNow.GetDistantFuture();
1208 
1209     switch (mState)
1210     {
1211     case kStateDisconnected:
1212         break;
1213 
1214     case kStateConnecting:
1215         // While in `kStateConnecting`, Keep Alive timer is
1216         // used for `kConnectingTimeout`.
1217         VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1218         nextTime = mKeepAlive.GetExpirationTime();
1219         break;
1220 
1221     case kStateConnectedButSessionless:
1222     case kStateEstablishingSession:
1223     case kStateSessionEstablished:
1224         nextTime = Min(nextTime, mPendingRequests.GetNextFireTime(aNow));
1225 
1226         if (mKeepAlive.IsUsed())
1227         {
1228             VerifyOrExit(mKeepAlive.GetExpirationTime() > aNow, nextTime = aNow);
1229             nextTime = Min(nextTime, mKeepAlive.GetExpirationTime());
1230         }
1231 
1232         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1233         {
1234             // An operation being active on a DSO Session includes
1235             // a request message waiting for a response, or an
1236             // active long-lived operation.
1237 
1238             VerifyOrExit(mInactivity.GetExpirationTime() > aNow, nextTime = aNow);
1239             nextTime = Min(nextTime, mInactivity.GetExpirationTime());
1240         }
1241 
1242         break;
1243     }
1244 
1245 exit:
1246     return nextTime;
1247 }
1248 
void Dso::Connection::HandleTimer(TimeMilli aNow, TimeMilli &aNextTime)
{
    // Handles the shared `Dso` timer firing for this connection at `aNow`.
    // It checks, in order: connect timeout, response timeouts, inactivity
    // timeout, then Keep Alive timeout. It may disconnect the connection or
    // (on a client) send a Keep Alive message.
    //
    // On every path (including after a disconnect), `aNextTime` is lowered
    // to this connection's next fire time and `SignalAnyStateChange()` is
    // called.

    switch (mState)
    {
    case kStateDisconnected:
        break;

    case kStateConnecting:
        // While connecting, `mKeepAlive` tracks the connecting timeout.
        if (mKeepAlive.IsExpired(aNow))
        {
            Disconnect(kGracefullyClose, kReasonFailedToConnect);
        }
        break;

    case kStateConnectedButSessionless:
    case kStateEstablishingSession:
    case kStateSessionEstablished:
        if (mPendingRequests.HasAnyTimedOut(aNow))
        {
            // If server sends no response to a request, client
            // waits for 30 seconds (`kResponseTimeout`) after which
            // client MUST forcibly abort the connection.
            Disconnect(kForciblyAbort, kReasonResponseTimeout);
            ExitNow();
        }

        // The inactivity timer is kept clear, while an operation is
        // active on the session (which includes a request waiting for
        // response or an active long-lived operation).

        if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation && mInactivity.IsExpired(aNow))
        {
            // On client, if the inactivity timeout is reached, the
            // connection is closed gracefully. On server, if too much
            // time (`CalculateServerInactivityWaitTime()`, i.e., five
            // seconds or twice the current inactivity timeout interval,
            // whichever is greater) elapses server MUST consider the
            // client delinquent and MUST forcibly abort the connection.

            Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
            ExitNow();
        }

        if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNow))
        {
            // On client, if the Keep Alive interval elapses without any
            // DNS messages being sent or received, the client MUST take
            // action and send a DSO Keep Alive message.
            //
            // On server, if twice the Keep Alive interval value elapses
            // without any messages being sent or received, the server
            // considers the client delinquent and aborts the connection.

            if (IsClient())
            {
                IgnoreError(SendKeepAliveMessage());
            }
            else
            {
                Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
                ExitNow();
            }
        }
        break;
    }

exit:
    aNextTime = Min(aNextTime, GetNextFireTime(aNow));
    SignalAnyStateChange();
}
1319 
StateToString(State aState)1320 const char *Dso::Connection::StateToString(State aState)
1321 {
1322     static const char *const kStateStrings[] = {
1323         "Disconnected",            // (0) kStateDisconnected,
1324         "Connecting",              // (1) kStateConnecting,
1325         "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1326         "EstablishingSession",     // (3) kStateEstablishingSession,
1327         "SessionEstablished",      // (4) kStateSessionEstablished,
1328     };
1329 
1330     static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1331     static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1332     static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1333     static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1334     static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1335 
1336     return kStateStrings[aState];
1337 }
1338 
MessageTypeToString(MessageType aMessageType)1339 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1340 {
1341     static const char *const kMessageTypeStrings[] = {
1342         "Request",        // (0) kRequestMessage
1343         "Response",       // (1) kResponseMessage
1344         "Unidirectional", // (2) kUnidirectionalMessage
1345     };
1346 
1347     static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1348     static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1349     static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1350 
1351     return kMessageTypeStrings[aMessageType];
1352 }
1353 
DisconnectReasonToString(DisconnectReason aReason)1354 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1355 {
1356     static const char *const kDisconnectReasonStrings[] = {
1357         "FailedToConnect",         // (0) kReasonFailedToConnect
1358         "ResponseTimeout",         // (1) kReasonResponseTimeout
1359         "PeerDoesNotSupportDso",   // (2) kReasonPeerDoesNotSupportDso
1360         "PeerClosed",              // (3) kReasonPeerClosed
1361         "PeerAborted",             // (4) kReasonPeerAborted
1362         "InactivityTimeout",       // (5) kReasonInactivityTimeout
1363         "KeepAliveTimeout",        // (6) kReasonKeepAliveTimeout
1364         "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1365         "PeerMisbehavior",         // (8) kReasonPeerMisbehavior
1366         "Unknown",                 // (9) kReasonUnknown
1367     };
1368 
1369     static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1370     static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1371     static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1372     static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1373     static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1374     static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1375     static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1376     static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1377     static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1378     static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1379 
1380     return kDisconnectReasonStrings[aReason];
1381 }
1382 
1383 //---------------------------------------------------------------------------------------------------------------------
1384 // Dso::Connection::PendingRequests
1385 
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1386 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1387 {
1388     bool         contains = true;
1389     const Entry *entry    = mRequests.FindMatching(aMessageId);
1390 
1391     VerifyOrExit(entry != nullptr, contains = false);
1392     aPrimaryTlvType = entry->mPrimaryTlvType;
1393 
1394 exit:
1395     return contains;
1396 }
1397 
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1398 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1399 {
1400     Error  error = kErrorNone;
1401     Entry *entry = mRequests.PushBack();
1402 
1403     VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1404     entry->mMessageId      = aMessageId;
1405     entry->mPrimaryTlvType = aPrimaryTlvType;
1406     entry->mTimeout        = aResponseTimeout;
1407 
1408 exit:
1409     return error;
1410 }
1411 
Remove(MessageId aMessageId)1412 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); }
1413 
HasAnyTimedOut(TimeMilli aNow) const1414 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1415 {
1416     bool timedOut = false;
1417 
1418     for (const Entry &entry : mRequests)
1419     {
1420         if (entry.mTimeout <= aNow)
1421         {
1422             timedOut = true;
1423             break;
1424         }
1425     }
1426 
1427     return timedOut;
1428 }
1429 
GetNextFireTime(TimeMilli aNow) const1430 TimeMilli Dso::Connection::PendingRequests::GetNextFireTime(TimeMilli aNow) const
1431 {
1432     TimeMilli nextTime = aNow.GetDistantFuture();
1433 
1434     for (const Entry &entry : mRequests)
1435     {
1436         VerifyOrExit(entry.mTimeout > aNow, nextTime = aNow);
1437         nextTime = Min(entry.mTimeout, nextTime);
1438     }
1439 
1440 exit:
1441     return nextTime;
1442 }
1443 
1444 //---------------------------------------------------------------------------------------------------------------------
1445 // Dso
1446 
// Constructs the DSO module. No accept handler is registered initially,
// so incoming connections are not accepted until `StartListening()` is
// called with a handler.
Dso::Dso(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mAcceptHandler(nullptr)
    , mTimer(aInstance)
{
}
1453 
StartListening(AcceptHandler aAcceptHandler)1454 void Dso::StartListening(AcceptHandler aAcceptHandler)
1455 {
1456     mAcceptHandler = aAcceptHandler;
1457     otPlatDsoEnableListening(&GetInstance(), true);
1458 }
1459 
StopListening(void)1460 void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); }
1461 
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1462 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1463 {
1464     return mClientConnections.FindMatching(aPeerSockAddr);
1465 }
1466 
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1467 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1468 {
1469     return mServerConnections.FindMatching(aPeerSockAddr);
1470 }
1471 
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1472 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1473 {
1474     Connection *connection = nullptr;
1475 
1476     VerifyOrExit(mAcceptHandler != nullptr);
1477     connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1478 
1479     VerifyOrExit(connection != nullptr);
1480     connection->Accept();
1481 
1482 exit:
1483     return connection;
1484 }
1485 
HandleTimer(void)1486 void Dso::HandleTimer(void)
1487 {
1488     TimeMilli   now      = TimerMilli::GetNow();
1489     TimeMilli   nextTime = now.GetDistantFuture();
1490     Connection *conn;
1491     Connection *next;
1492 
1493     for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1494     {
1495         next = conn->GetNext();
1496         conn->HandleTimer(now, nextTime);
1497     }
1498 
1499     for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1500     {
1501         next = conn->GetNext();
1502         conn->HandleTimer(now, nextTime);
1503     }
1504 
1505     if (nextTime != now.GetDistantFuture())
1506     {
1507         mTimer.FireAtIfEarlier(nextTime);
1508     }
1509 }
1510 
1511 } // namespace Dns
1512 } // namespace ot
1513 
1514 #if OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1515 
// Weak no-op mock of the platform DSO APIs. These are compiled only when
// `OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE` is set, and a real
// platform implementation can override them.

// Mock: ignores the request to enable/disable listening.
OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable)
{
    OT_UNUSED_VARIABLE(aInstance);
    OT_UNUSED_VARIABLE(aEnable);
}

// Mock: does not initiate any connection.
OT_TOOL_WEAK void otPlatDsoConnect(otPlatDsoConnection *aConnection, const otSockAddr *aPeerSockAddr)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aPeerSockAddr);
}

// Mock: does not transmit or free `aMessage`.
// NOTE(review): the real platform presumably takes ownership of
// `aMessage`; this stub does not free it. Confirm if used in tests.
OT_TOOL_WEAK void otPlatDsoSend(otPlatDsoConnection *aConnection, otMessage *aMessage)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aMessage);
}

// Mock: performs no disconnect action.
OT_TOOL_WEAK void otPlatDsoDisconnect(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aMode);
}
1539 
1540 #endif // OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1541 
1542 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1543