1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_dso.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "common/num_utils.hpp"
40 #include "common/random.hpp"
41 #include "instance/instance.hpp"
42 
43 /**
44  * @file
45  *   This file implements the DNS Stateful Operations (DSO) per RFC 8490.
46  */
47 
48 namespace ot {
49 namespace Dns {
50 
51 RegisterLogModule("DnsDso");
52 
53 //---------------------------------------------------------------------------------------------------------------------
54 // otPlatDso transport callbacks
55 
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)56 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
57 {
58     return &AsCoreType(aConnection).GetInstance();
59 }
60 
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)61 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
62 {
63     return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
64 }
65 
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)66 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
67 {
68     AsCoreType(aConnection).HandleConnected();
69 }
70 
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)71 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
72 {
73     AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
74 }
75 
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)76 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
77 {
78     AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
79 }
80 
81 //---------------------------------------------------------------------------------------------------------------------
82 // Dso::Connection
83 
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)84 Dso::Connection::Connection(Instance            &aInstance,
85                             const Ip6::SockAddr &aPeerSockAddr,
86                             Callbacks           &aCallbacks,
87                             uint32_t             aInactivityTimeout,
88                             uint32_t             aKeepAliveInterval)
89     : InstanceLocator(aInstance)
90     , mNext(nullptr)
91     , mCallbacks(aCallbacks)
92     , mPeerSockAddr(aPeerSockAddr)
93     , mState(kStateDisconnected)
94     , mIsServer(false)
95     , mInactivity(aInactivityTimeout)
96     , mKeepAlive(aKeepAliveInterval)
97 {
98     OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
99     Init(/* aIsServer */ false);
100 }
101 
Init(bool aIsServer)102 void Dso::Connection::Init(bool aIsServer)
103 {
104     mNextMessageId       = 1;
105     mIsServer            = aIsServer;
106     mStateDidChange      = false;
107     mLongLivedOperation  = false;
108     mRetryDelay          = 0;
109     mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
110     mDisconnectReason    = kReasonUnknown;
111 }
112 
SetState(State aState)113 void Dso::Connection::SetState(State aState)
114 {
115     VerifyOrExit(mState != aState);
116 
117     LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
118             mPeerSockAddr.ToString().AsCString());
119 
120     mState          = aState;
121     mStateDidChange = true;
122 
123 exit:
124     return;
125 }
126 
SignalAnyStateChange(void)127 void Dso::Connection::SignalAnyStateChange(void)
128 {
129     VerifyOrExit(mStateDidChange);
130     mStateDidChange = false;
131 
132     switch (mState)
133     {
134     case kStateDisconnected:
135         mCallbacks.mHandleDisconnected(*this);
136         break;
137 
138     case kStateConnectedButSessionless:
139         mCallbacks.mHandleConnected(*this);
140         break;
141 
142     case kStateSessionEstablished:
143         mCallbacks.mHandleSessionEstablished(*this);
144         break;
145 
146     case kStateConnecting:
147     case kStateEstablishingSession:
148         break;
149     };
150 
151 exit:
152     return;
153 }
154 
NewMessage(void)155 Message *Dso::Connection::NewMessage(void)
156 {
157     return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
158                                        Message::Settings(Message::kPriorityNormal));
159 }
160 
Connect(void)161 void Dso::Connection::Connect(void)
162 {
163     OT_ASSERT(mState == kStateDisconnected);
164 
165     Init(/* aIsServer */ false);
166     Get<Dso>().mClientConnections.Push(*this);
167     MarkAsConnecting();
168     otPlatDsoConnect(this, &mPeerSockAddr);
169 }
170 
Accept(void)171 void Dso::Connection::Accept(void)
172 {
173     OT_ASSERT(mState == kStateDisconnected);
174 
175     Init(/* aIsServer */ true);
176     Get<Dso>().mServerConnections.Push(*this);
177     MarkAsConnecting();
178 }
179 
MarkAsConnecting(void)180 void Dso::Connection::MarkAsConnecting(void)
181 {
182     SetState(kStateConnecting);
183 
184     // While in `kStateConnecting` state we use the `mKeepAlive` to
185     // track the `kConnectingTimeout` (if connection is not established
186     // within the timeout, we consider it as failure and close it).
187 
188     mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
189     Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
190 
191     // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
192     // or timeout.
193 }
194 
HandleConnected(void)195 void Dso::Connection::HandleConnected(void)
196 {
197     OT_ASSERT(mState == kStateConnecting);
198 
199     SetState(kStateConnectedButSessionless);
200     ResetTimeouts(/* aIsKeepAliveMessage */ false);
201 
202     SignalAnyStateChange();
203 }
204 
Disconnect(DisconnectMode aMode,DisconnectReason aReason)205 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
206 {
207     VerifyOrExit(mState != kStateDisconnected);
208 
209     mDisconnectReason = aReason;
210     MarkAsDisconnected();
211 
212     otPlatDsoDisconnect(this, MapEnum(aMode));
213 
214 exit:
215     return;
216 }
217 
HandleDisconnected(DisconnectMode aMode)218 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
219 {
220     VerifyOrExit(mState != kStateDisconnected);
221 
222     if (mState == kStateConnecting)
223     {
224         mDisconnectReason = kReasonFailedToConnect;
225     }
226     else
227     {
228         switch (aMode)
229         {
230         case kGracefullyClose:
231             mDisconnectReason = kReasonPeerClosed;
232             break;
233 
234         case kForciblyAbort:
235             mDisconnectReason = kReasonPeerAborted;
236         }
237     }
238 
239     MarkAsDisconnected();
240     SignalAnyStateChange();
241 
242 exit:
243     return;
244 }
245 
MarkAsDisconnected(void)246 void Dso::Connection::MarkAsDisconnected(void)
247 {
248     if (IsClient())
249     {
250         IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
251     }
252     else
253     {
254         IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
255     }
256 
257     mPendingRequests.Clear();
258     SetState(kStateDisconnected);
259 
260     LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
261 }
262 
MarkSessionEstablished(void)263 void Dso::Connection::MarkSessionEstablished(void)
264 {
265     switch (mState)
266     {
267     case kStateConnectedButSessionless:
268     case kStateEstablishingSession:
269     case kStateSessionEstablished:
270         break;
271 
272     case kStateDisconnected:
273     case kStateConnecting:
274         OT_ASSERT(false);
275     }
276 
277     SetState(kStateSessionEstablished);
278 }
279 
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)280 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
281 {
282     return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
283 }
284 
SendUnidirectionalMessage(Message & aMessage)285 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
286 {
287     MessageId messageId = 0;
288 
289     return SendMessage(aMessage, kUnidirectionalMessage, messageId);
290 }
291 
SendResponseMessage(Message & aMessage,MessageId aResponseId)292 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
293 {
294     return SendMessage(aMessage, kResponseMessage, aResponseId);
295 }
296 
SetLongLivedOperation(bool aLongLivedOperation)297 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
298 {
299     VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
300 
301     mLongLivedOperation = aLongLivedOperation;
302 
303     LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
304 
305     if (!mLongLivedOperation)
306     {
307         NextFireTime nextTime;
308 
309         UpdateNextFireTime(nextTime);
310         Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
311     }
312 
313 exit:
314     return;
315 }
316 
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)317 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
318 {
319     Error         error   = kErrorNone;
320     Message      *message = nullptr;
321     RetryDelayTlv retryDelayTlv;
322     MessageId     messageId;
323 
324     switch (mState)
325     {
326     case kStateSessionEstablished:
327         OT_ASSERT(IsServer());
328         break;
329 
330     case kStateConnectedButSessionless:
331     case kStateEstablishingSession:
332     case kStateDisconnected:
333     case kStateConnecting:
334         OT_ASSERT(false);
335     }
336 
337     message = NewMessage();
338     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
339 
340     retryDelayTlv.Init();
341     retryDelayTlv.SetRetryDelay(aDelay);
342     SuccessOrExit(error = message->Append(retryDelayTlv));
343     error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
344 
345 exit:
346     FreeMessageOnError(message, error);
347     return error;
348 }
349 
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)350 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
351 {
352     Error error = kErrorNone;
353 
354     VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
355 
356     // If acting as server, the timeout values are the ones we grant
357     // to a connecting clients. If acting as client, the timeout
358     // values are what to request when sending Keep Alive message.
359     // If in `kStateDisconnected` we set both (since we don't know
360     // yet whether we are going to connect as client or server).
361 
362     if ((mState == kStateDisconnected) || IsServer())
363     {
364         mKeepAlive.SetInterval(aKeepAliveInterval);
365         AdjustInactivityTimeout(aInactivityTimeout);
366     }
367 
368     if ((mState == kStateDisconnected) || IsClient())
369     {
370         mKeepAlive.SetRequestInterval(aKeepAliveInterval);
371         mInactivity.SetRequestInterval(aInactivityTimeout);
372     }
373 
374     switch (mState)
375     {
376     case kStateDisconnected:
377     case kStateConnecting:
378         break;
379 
380     case kStateConnectedButSessionless:
381     case kStateEstablishingSession:
382         if (IsServer())
383         {
384             break;
385         }
386 
387         OT_FALL_THROUGH;
388 
389     case kStateSessionEstablished:
390         error = SendKeepAliveMessage();
391     }
392 
393 exit:
394     return error;
395 }
396 
SendKeepAliveMessage(void)397 Error Dso::Connection::SendKeepAliveMessage(void)
398 {
399     return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
400 }
401 
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)402 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
403 {
404     // Sends a Keep Alive message of a given type. This is a common
405     // method used by both client and server. `aResponseId` is
406     // applicable and used only when the message type is
407     // `kResponseMessage`.
408 
409     Error        error   = kErrorNone;
410     Message     *message = nullptr;
411     KeepAliveTlv keepAliveTlv;
412 
413     switch (mState)
414     {
415     case kStateConnectedButSessionless:
416     case kStateEstablishingSession:
417         if (IsServer())
418         {
419             // While session is being established, server is only allowed
420             // to send a Keep Alive response to a request from client.
421             OT_ASSERT(aMessageType == kResponseMessage);
422         }
423         break;
424 
425     case kStateSessionEstablished:
426         break;
427 
428     case kStateDisconnected:
429     case kStateConnecting:
430         OT_ASSERT(false);
431     }
432 
433     // Server can send Keep Alive response (to a request from client)
434     // or a unidirectional Keep Alive message. Client can send
435     // KeepAlive request message.
436 
437     if (IsServer())
438     {
439         if (aMessageType == kResponseMessage)
440         {
441             OT_ASSERT(aResponseId != 0);
442         }
443         else
444         {
445             OT_ASSERT(aMessageType == kUnidirectionalMessage);
446         }
447     }
448     else
449     {
450         OT_ASSERT(aMessageType == kRequestMessage);
451     }
452 
453     message = NewMessage();
454     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
455 
456     keepAliveTlv.Init();
457 
458     if (IsServer())
459     {
460         keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
461         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
462     }
463     else
464     {
465         keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
466         keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
467     }
468 
469     SuccessOrExit(error = message->Append(keepAliveTlv));
470 
471     error = SendMessage(*message, aMessageType, aResponseId);
472 
473 exit:
474     FreeMessageOnError(message, error);
475     return error;
476 }
477 
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)478 Error Dso::Connection::SendMessage(Message              &aMessage,
479                                    MessageType           aMessageType,
480                                    MessageId            &aMessageId,
481                                    Dns::Header::Response aResponseCode,
482                                    uint32_t              aResponseTimeout)
483 {
484     Error       error          = kErrorNone;
485     Tlv::Type   primaryTlvType = Tlv::kReservedType;
486     Dns::Header header;
487 
488     switch (mState)
489     {
490     case kStateConnectedButSessionless:
491         // To establish session, client MUST send a request message.
492         // Server is not allowed to send any messages. Unidirectional
493         // messages are not allowed before session is established.
494         OT_ASSERT(IsClient());
495         OT_ASSERT(aMessageType == kRequestMessage);
496         break;
497 
498     case kStateEstablishingSession:
499         // During session establishment, client is allowed to send
500         // additional request messages, server is only allowed to
501         // send response.
502         if (IsClient())
503         {
504             OT_ASSERT(aMessageType == kRequestMessage);
505         }
506         else
507         {
508             OT_ASSERT(aMessageType == kResponseMessage);
509         }
510         break;
511 
512     case kStateSessionEstablished:
513         // All message types are allowed.
514         break;
515 
516     case kStateDisconnected:
517     case kStateConnecting:
518         OT_ASSERT(false);
519     }
520 
521     // A DSO request or unidirectional message MUST contain at
522     // least one TLV. The first TLV is the "Primary TLV" and
523     // determines the nature of the operation being performed.
524     // A DSO response message may contain no TLVs, or may contain
525     // one or more TLVs. Response Primary TLV(s) MUST appear first
526     // in a DSO response message.
527 
528     aMessage.SetOffset(0);
529     IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
530 
531     switch (aMessageType)
532     {
533     case kResponseMessage:
534         break;
535     case kRequestMessage:
536     case kUnidirectionalMessage:
537         OT_ASSERT(primaryTlvType != Tlv::kReservedType);
538     }
539 
540     // `header` is cleared from its constructor call so all fields
541     // start as zero.
542 
543     switch (aMessageType)
544     {
545     case kRequestMessage:
546         header.SetType(Dns::Header::kTypeQuery);
547         aMessageId = mNextMessageId;
548         break;
549 
550     case kResponseMessage:
551         header.SetType(Dns::Header::kTypeResponse);
552         break;
553 
554     case kUnidirectionalMessage:
555         header.SetType(Dns::Header::kTypeQuery);
556         aMessageId = 0;
557         break;
558     }
559 
560     header.SetMessageId(aMessageId);
561     header.SetQueryType(Dns::Header::kQueryTypeDso);
562     header.SetResponseCode(aResponseCode);
563     SuccessOrExit(error = aMessage.Prepend(header));
564 
565     SuccessOrExit(error = AppendPadding(aMessage));
566 
567     // Update `mPendingRequests` list with the new request info
568 
569     if (aMessageType == kRequestMessage)
570     {
571         SuccessOrExit(
572             error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
573 
574         if (++mNextMessageId == 0)
575         {
576             mNextMessageId = 1;
577         }
578     }
579 
580     LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
581             mPeerSockAddr.ToString().AsCString());
582 
583     switch (mState)
584     {
585     case kStateConnectedButSessionless:
586         // On client we transition from "connected" state to
587         // "establishing session" state on successfully sending a
588         // request message.
589         if (IsClient())
590         {
591             SetState(kStateEstablishingSession);
592         }
593         break;
594 
595     case kStateEstablishingSession:
596         // On server we transition from "establishing session" state
597         // to "established" on sending a response with success
598         // response code.
599         if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
600         {
601             SetState(kStateSessionEstablished);
602         }
603 
604     default:
605         break;
606     }
607 
608     ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
609 
610     otPlatDsoSend(this, &aMessage);
611 
612     // Signal any state changes. This is done at the very end when the
613     // `SendMessage()` is fully processed (all state and local
614     // variables are updated) to ensure that we do not have any
615     // reentrancy issues (e.g., if the callback signalling state
616     // change triggers another tx).
617 
618     SignalAnyStateChange();
619 
620 exit:
621     return error;
622 }
623 
AppendPadding(Message & aMessage)624 Error Dso::Connection::AppendPadding(Message &aMessage)
625 {
626     // This method appends Encryption Padding TLV to a DSO message.
627     // It uses the padding policy "Random-Block-Length Padding" from
628     // RFC 8467.
629 
630     static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
631 
632     Error                error = kErrorNone;
633     uint16_t             blockLength;
634     EncryptionPaddingTlv paddingTlv;
635 
636     // We pick a random block length. The random selection can be
637     // based on a "weak" source of randomness (so the use of
638     // `NonCrypto` is fine). We add padding to the message such
639     // that its padded length is a multiple of the chosen block
640     // length.
641 
642     blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
643 
644     paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
645 
646     SuccessOrExit(error = aMessage.Append(paddingTlv));
647 
648     for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
649     {
650         SuccessOrExit(error = aMessage.Append<uint8_t>(0));
651     }
652 
653 exit:
654     return error;
655 }
656 
void Dso::Connection::HandleReceive(Message &aMessage)
{
    // Processes a DSO message received from the platform on this connection.
    //
    // The DNS header is read and checked against the current state and role.
    // The primary TLV is then read, and the message is dispatched either as
    // a request/unidirectional message or as a response. `aMessage` is always
    // freed here.
    //
    // `error` starts as `kErrorAbort`, so every early exit through
    // `SuccessOrExit()`/`VerifyOrExit()`/`ExitNow()` is treated as peer
    // misbehavior and the connection is forcibly aborted. On success the
    // timeouts are reset.

    Error       error          = kErrorAbort;
    Tlv::Type   primaryTlvType = Tlv::kReservedType;
    Dns::Header header;

    SuccessOrExit(aMessage.Read(0, header));

    // Non-DSO message. A query is answered with "not implemented" and is not
    // treated as an error. A non-DSO response leaves `error` as abort.
    if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
    {
        if (header.GetType() == Dns::Header::kTypeQuery)
        {
            SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
            error = kErrorNone;
        }

        ExitNow();
    }

    switch (mState)
    {
    case kStateConnectedButSessionless:
        // After the connection is up, the client should start session
        // establishment by sending a request, so a client accepts nothing
        // in this state. A server accepts only a request message (query type
        // with non-zero message ID).
        VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
        break;

    case kStateEstablishingSession:
        // Unidirectional messages (message ID zero) are only allowed once
        // the session is established. While the session is being
        // established, a client accepts only responses, and a server accepts
        // requests or responses.

        VerifyOrExit(header.GetMessageId() != 0);

        if (IsClient())
        {
            VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
        }
        break;

    case kStateSessionEstablished:
        // All message types are allowed.
        break;

    case kStateDisconnected:
    case kStateConnecting:
        // Not expected in these states. Exits with `kErrorAbort`. Note that
        // `Disconnect()` below does nothing if already disconnected.
        ExitNow();
    }

    // All count fields MUST be set to zero in the header.
    VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
                 (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));

    // TLVs start right after the DNS header.
    aMessage.SetOffset(sizeof(header));

    switch (ReadPrimaryTlv(aMessage, primaryTlvType))
    {
    case kErrorNone:
        // A present primary TLV must not use the reserved type (zero).
        VerifyOrExit(primaryTlvType != Tlv::kReservedType);
        break;

    case kErrorNotFound:
        // The `primaryTlvType` is set to `Tlv::kReservedType`
        // (value zero) to indicate that there is no primary TLV.
        break;

    default:
        // Malformed TLV (`kErrorParse`).
        ExitNow();
    }

    switch (header.GetType())
    {
    case Dns::Header::kTypeQuery:
        error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
        break;

    case Dns::Header::kTypeResponse:
        error = ProcessResponseMessage(header, aMessage, primaryTlvType);
        break;
    }

exit:
    aMessage.Free();

    if (error == kErrorNone)
    {
        ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
    }
    else
    {
        Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
    }

    // We signal any state change at the very end when the received
    // message is fully processed (all state and local variables are
    // updated) to ensure that we do not have any reentrancy issues
    // (e.g., if a `Connection` method happens to be called from the
    // callback).

    SignalAnyStateChange();
}
761 
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const762 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
763 {
764     // Read and validate the primary TLV (first TLV  after the header).
765     // The `aMessage.GetOffset()` must point to the first TLV. If no
766     // TLV then `kErrorNotFound` is returned. If TLV in message is not
767     // well-formed `kErrorParse` is returned. The read TLV type is
768     // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
769     // (value zero) when `kErrorNotFound`).
770 
771     Error error = kErrorNotFound;
772     Tlv   tlv;
773 
774     aPrimaryTlvType = Tlv::kReservedType;
775 
776     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
777     VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
778     aPrimaryTlvType = tlv.GetType();
779     error           = kErrorNone;
780 
781 exit:
782     return error;
783 }
784 
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)785 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
786                                                              const Message     &aMessage,
787                                                              Tlv::Type          aPrimaryTlvType)
788 {
789     Error error = kErrorAbort;
790 
791     if (IsServer() && (mState == kStateConnectedButSessionless))
792     {
793         SetState(kStateEstablishingSession);
794     }
795 
796     // A DSO request or unidirectional message MUST contain at
797     // least one TLV which is the "Primary TLV" and determines
798     // the nature of the operation being performed.
799 
800     switch (aPrimaryTlvType)
801     {
802     case KeepAliveTlv::kType:
803         error = ProcessKeepAliveMessage(aHeader, aMessage);
804         break;
805 
806     case RetryDelayTlv::kType:
807         error = ProcessRetryDelayMessage(aHeader, aMessage);
808         break;
809 
810     case Tlv::kReservedType:
811     case EncryptionPaddingTlv::kType:
812         // Misbehavior by peer.
813         break;
814 
815     default:
816         if (aHeader.GetMessageId() == 0)
817         {
818             LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
819 
820             error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
821         }
822         else
823         {
824             MessageId messageId = aHeader.GetMessageId();
825 
826             LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
827 
828             error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
829 
830             // `kErrorNotFound` indicates that TLV type is not known.
831 
832             if (error == kErrorNotFound)
833             {
834                 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
835                 error = kErrorNone;
836             }
837         }
838         break;
839     }
840 
841     return error;
842 }
843 
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)844 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
845                                               const Message     &aMessage,
846                                               Tlv::Type          aPrimaryTlvType)
847 {
848     Error     error = kErrorAbort;
849     Tlv::Type requestPrimaryTlvType;
850 
851     // If a client or server receives a response where the message
852     // ID is zero, or is any other value that does not match the
853     // message ID of any of its outstanding operations, this is a
854     // fatal error and the recipient MUST forcibly abort the
855     // connection immediately.
856 
857     VerifyOrExit(aHeader.GetMessageId() != 0);
858     VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
859 
860     // If the response has no error and contains a primary TLV, it
861     // MUST match the request primary TLV.
862 
863     if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
864     {
865         VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
866     }
867 
868     mPendingRequests.Remove(aHeader.GetMessageId());
869 
870     switch (requestPrimaryTlvType)
871     {
872     case KeepAliveTlv::kType:
873         SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
874         break;
875 
876     default:
877         SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
878                                                                  requestPrimaryTlvType));
879         break;
880     }
881 
882     // DSO session is established when client sends a request message
883     // and receives a response from server with no error code.
884 
885     if (IsClient() && (mState == kStateEstablishingSession) &&
886         (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
887     {
888         SetState(kStateSessionEstablished);
889     }
890 
891 exit:
892     return error;
893 }
894 
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)895 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
896 {
897     Error        error  = kErrorAbort;
898     uint16_t     offset = aMessage.GetOffset();
899     Tlv          tlv;
900     KeepAliveTlv keepAliveTlv;
901 
902     if (aHeader.GetType() == Dns::Header::kTypeResponse)
903     {
904         // A Keep Alive response message is allowed on a client from a sever.
905 
906         VerifyOrExit(IsClient());
907 
908         if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
909         {
910             // We got an error response code from server for our
911             // Keep Alive request message. If this happens while
912             // establishing the DSO session, it indicates that server
913             // does not support DSO, so we close the connection. If
914             // this happens while session is already established, it
915             // is a misbehavior (fatal error) by server.
916 
917             if (mState == kStateEstablishingSession)
918             {
919                 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
920                 error = kErrorNone;
921             }
922 
923             ExitNow();
924         }
925     }
926 
927     // Parse and validate the Keep Alive Message
928 
929     SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
930     offset += keepAliveTlv.GetSize();
931 
932     VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
933 
934     // Keep Alive message MUST contain only one Keep Alive TLV.
935 
936     while (offset < aMessage.GetLength())
937     {
938         SuccessOrExit(aMessage.Read(offset, tlv));
939         offset += tlv.GetSize();
940 
941         VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
942     }
943 
944     VerifyOrExit(offset == aMessage.GetLength());
945 
946     if (aHeader.GetType() == Dns::Header::kTypeQuery)
947     {
948         if (IsServer())
949         {
950             // Received a Keep Alive message from client. It MUST
951             // be a request message (not unidirectional). We prepare
952             // and send a Keep Alive response.
953 
954             VerifyOrExit(aHeader.GetMessageId() != 0);
955 
956             LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
957 
958             IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
959             error = kErrorNone;
960             ExitNow();
961         }
962 
963         // Received a Keep Alive message on client from server. Server
964         // Keep Alive message MUST be unidirectional (message ID
965         // zero).
966 
967         VerifyOrExit(aHeader.GetMessageId() == 0);
968     }
969 
970     LogInfo("Received Keep Alive %s message from server %s",
971             (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
972 
973     // Receiving a Keep Alive interval value from server less than the
974     // minimum (ten seconds) is a fatal error and client MUST then
975     // abort the connection.
976 
977     VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
978 
979     // Update the timeout intervals on the connection from
980     // the new values we got from the server. The receive
981     // of the Keep Alive message does not itself reset the
982     // inactivity timer. So we use `AdjustInactivityTimeout`
983     // which takes into account the time elapsed since the
984     // last activity.
985 
986     AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
987     mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
988 
989     LogInfo("Timeouts Inactivity:%lu, KeepAlive:%lu", ToUlong(mInactivity.GetInterval()),
990             ToUlong(mKeepAlive.GetInterval()));
991 
992     error = kErrorNone;
993 
994 exit:
995     return error;
996 }
997 
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)998 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
999 
1000 {
1001     Error         error = kErrorAbort;
1002     RetryDelayTlv retryDelayTlv;
1003 
1004     // Retry Delay TLV can be used as the Primary TLV only in
1005     // a unidirectional message sent from server to client.
1006     // It is used by the server to instruct the client to
1007     // close the session and its underlying connection, and not
1008     // to reconnect for the indicated time interval.
1009 
1010     VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1011 
1012     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1013     VerifyOrExit(retryDelayTlv.IsValid());
1014 
1015     mRetryDelayErrorCode = aHeader.GetResponseCode();
1016     mRetryDelay          = retryDelayTlv.GetRetryDelay();
1017 
1018     LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1019     LogInfo("   RetryDelay:%lu ms, ResponseCode:%d", ToUlong(mRetryDelay), mRetryDelayErrorCode);
1020 
1021     Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1022 
1023 exit:
1024     return error;
1025 }
1026 
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1027 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1028 {
1029     Message    *response = NewMessage();
1030     Dns::Header header;
1031 
1032     VerifyOrExit(response != nullptr);
1033 
1034     header.SetMessageId(aHeader.GetMessageId());
1035     header.SetType(Dns::Header::kTypeResponse);
1036     header.SetQueryType(aHeader.GetQueryType());
1037     header.SetResponseCode(aResponseCode);
1038 
1039     SuccessOrExit(response->Prepend(header));
1040 
1041     otPlatDsoSend(this, response);
1042     response = nullptr;
1043 
1044 exit:
1045     FreeMessage(response);
1046 }
1047 
void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
{
    // This method sets the inactivity timeout interval to a new value
    // and updates the expiration time based on the new timeout value.
    //
    // On client, it is called on receiving a Keep Alive response or
    // unidirectional message from server. Note that the receive of
    // the Keep Alive message does not itself reset the inactivity
    // timer. So the time elapsed since the last activity should be
    // taken into account with the new inactivity timeout value.
    //
    // On server this method is called from `SetTimeouts()` when a new
    // inactivity timeout value is set.
    //
    // While disconnected only the interval is recorded (there is no
    // running timer to adjust). If the new value equals the current
    // interval, nothing changes.

    TimeMilli now = TimerMilli::GetNow();
    TimeMilli start;
    TimeMilli newExpiration;

    if (mState == kStateDisconnected)
    {
        mInactivity.SetInterval(aNewTimeout);
        ExitNow();
    }

    VerifyOrExit(aNewTimeout != mInactivity.GetInterval());

    // Calculate the start time (i.e., the last time inactivity timer
    // was cleared). If the previous inactivity time is set to
    // `kInfinite` value (`IsUsed()` returns `false`) then
    // `GetExpirationTime()` returns the start time. Otherwise, we
    // calculate it going back from the current expiration time with
    // the current wait interval.

    if (!mInactivity.IsUsed())
    {
        start = mInactivity.GetExpirationTime();
    }
    else if (IsClient())
    {
        start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
    }
    else
    {
        start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
    }

    // From here on `IsUsed()` and `CalculateServerInactivityWaitTime()`
    // reflect the new interval.

    mInactivity.SetInterval(aNewTimeout);

    if (!mInactivity.IsUsed())
    {
        // New interval is `kInfinite`: keep tracking the start time
        // itself as the "expiration time".
        newExpiration = start;
    }
    else if (IsClient())
    {
        newExpiration = start + aNewTimeout;

        // If the new timeout has already elapsed since `start`,
        // expire now rather than at a time in the past.
        if (newExpiration < now)
        {
            newExpiration = now;
        }
    }
    else
    {
        newExpiration = start + CalculateServerInactivityWaitTime();

        if (newExpiration < now)
        {
            // If the server abruptly reduces the inactivity timeout
            // such that current elapsed time is already more than
            // twice the new inactivity timeout, then the client is
            // immediately considered delinquent (server can forcibly
            // abort the connection). So to give the client time to
            // close the connection gracefully, the server SHOULD
            // give the client an additional grace period of either
            // five seconds or one quarter of the new inactivity
            // timeout, whichever is greater [RFC 8490 - 7.1.1].

            newExpiration = now + Max(kMinServerInactivityWaitTime, aNewTimeout / 4);
        }
    }

    mInactivity.SetExpirationTime(newExpiration);

exit:
    return;
}
1134 
CalculateServerInactivityWaitTime(void) const1135 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1136 {
1137     // A server will abort an idle session after five seconds
1138     // (`kMinServerInactivityWaitTime`) or twice the inactivity
1139     // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1140 
1141     OT_ASSERT(mInactivity.IsUsed());
1142 
1143     return Max(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1144 }
1145 
ResetTimeouts(bool aIsKeepAliveMessage)1146 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1147 {
1148     NextFireTime nextTime;
1149 
1150     // At both servers and clients, the generation or reception of any
1151     // complete DNS message resets both timers for that DSO
1152     // session, with the one exception being that a DSO Keep Alive
1153     // message resets only the keep alive timer, not the inactivity
1154     // timeout timer [RFC 8490 - 6.3]
1155 
1156     if (mKeepAlive.IsUsed())
1157     {
1158         // On client, we wait for the Keep Alive interval but on server
1159         // we wait for twice the interval before considering Keep Alive
1160         // timeout.
1161         //
1162         // Note that we limit the interval to `Timeout::kMaxInterval`
1163         // (which is ~12 days). This max limit ensures that even twice
1164         // the interval is less than max OpenThread timer duration so
1165         // that the expiration time calculations below stay within the
1166         // `TimerMilli` range.
1167 
1168         mKeepAlive.SetExpirationTime(nextTime.GetNow() + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1169     }
1170 
1171     if (!aIsKeepAliveMessage)
1172     {
1173         if (mInactivity.IsUsed())
1174         {
1175             mInactivity.SetExpirationTime(
1176                 nextTime.GetNow() + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1177         }
1178         else
1179         {
1180             // When Inactivity timeout is not used (i.e., interval is set
1181             // to the special `kInfinite` value), we still need to track
1182             // the time so that if/when later the inactivity interval
1183             // gets changed, we can adjust the remaining time correctly
1184             // from `AdjustInactivityTimeout()`. In this case, we just
1185             // track the current time as "expiration time".
1186 
1187             mInactivity.SetExpirationTime(nextTime.GetNow());
1188         }
1189     }
1190 
1191     UpdateNextFireTime(nextTime);
1192 
1193     Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1194 }
1195 
UpdateNextFireTime(NextFireTime & aNextTime) const1196 void Dso::Connection::UpdateNextFireTime(NextFireTime &aNextTime) const
1197 {
1198     switch (mState)
1199     {
1200     case kStateDisconnected:
1201         break;
1202 
1203     case kStateConnecting:
1204         // While in `kStateConnecting`, Keep Alive timer is
1205         // used for `kConnectingTimeout`.
1206         aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1207         break;
1208 
1209     case kStateConnectedButSessionless:
1210     case kStateEstablishingSession:
1211     case kStateSessionEstablished:
1212         mPendingRequests.UpdateNextFireTime(aNextTime);
1213 
1214         if (mKeepAlive.IsUsed())
1215         {
1216             aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1217         }
1218 
1219         if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1220         {
1221             // An operation being active on a DSO Session includes
1222             // a request message waiting for a response, or an
1223             // active long-lived operation.
1224 
1225             aNextTime.UpdateIfEarlier(mInactivity.GetExpirationTime());
1226         }
1227 
1228         break;
1229     }
1230 }
1231 
void Dso::Connection::HandleTimer(NextFireTime &aNextTime)
{
    // Called from `Dso::HandleTimer()` for each connection. It handles
    // this connection's expired timers as of `aNextTime.GetNow()`, in
    // priority order:
    //   - connect timeout,
    //   - pending-request response timeout,
    //   - inactivity timeout,
    //   - Keep Alive interval.
    // On exit, it folds this connection's next fire time into
    // `aNextTime` and signals any state change.

    switch (mState)
    {
    case kStateDisconnected:
        break;

    case kStateConnecting:
        // While connecting, the Keep Alive timer tracks the connect timeout.
        if (mKeepAlive.IsExpired(aNextTime.GetNow()))
        {
            Disconnect(kGracefullyClose, kReasonFailedToConnect);
        }
        break;

    case kStateConnectedButSessionless:
    case kStateEstablishingSession:
    case kStateSessionEstablished:
        if (mPendingRequests.HasAnyTimedOut(aNextTime.GetNow()))
        {
            // If server sends no response to a request, client
            // waits for 30 seconds (`kResponseTimeout`) after which
            // client MUST forcibly abort the connection.
            Disconnect(kForciblyAbort, kReasonResponseTimeout);
            ExitNow();
        }

        // The inactivity timer is kept clear, while an operation is
        // active on the session (which includes a request waiting for
        // response or an active long-lived operation).

        if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation &&
            mInactivity.IsExpired(aNextTime.GetNow()))
        {
            // On client, if the inactivity timeout is reached, the
            // connection is closed gracefully. On server, if too much
            // time (`CalculateServerInactivityWaitTime()`, i.e., five
            // seconds or twice the current inactivity timeout interval,
            // whichever is greater) elapses server MUST consider the
            // client delinquent and MUST forcibly abort the connection.

            Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
            ExitNow();
        }

        if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNextTime.GetNow()))
        {
            // On client, if the Keep Alive interval elapses without any
            // DNS messages being sent or received, the client MUST take
            // action and send a DSO Keep Alive message.
            //
            // On server, if twice the Keep Alive interval value elapses
            // without any messages being sent or received, the server
            // considers the client delinquent and aborts the connection.

            if (IsClient())
            {
                IgnoreError(SendKeepAliveMessage());
            }
            else
            {
                Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
                ExitNow();
            }
        }
        break;
    }

exit:
    // Always re-register the next deadline (even after a disconnect)
    // and report any state transition that happened above.
    UpdateNextFireTime(aNextTime);
    SignalAnyStateChange();
}
1303 
StateToString(State aState)1304 const char *Dso::Connection::StateToString(State aState)
1305 {
1306     static const char *const kStateStrings[] = {
1307         "Disconnected",            // (0) kStateDisconnected,
1308         "Connecting",              // (1) kStateConnecting,
1309         "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1310         "EstablishingSession",     // (3) kStateEstablishingSession,
1311         "SessionEstablished",      // (4) kStateSessionEstablished,
1312     };
1313 
1314     static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1315     static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1316     static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1317     static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1318     static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1319 
1320     return kStateStrings[aState];
1321 }
1322 
MessageTypeToString(MessageType aMessageType)1323 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1324 {
1325     static const char *const kMessageTypeStrings[] = {
1326         "Request",        // (0) kRequestMessage
1327         "Response",       // (1) kResponseMessage
1328         "Unidirectional", // (2) kUnidirectionalMessage
1329     };
1330 
1331     static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1332     static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1333     static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1334 
1335     return kMessageTypeStrings[aMessageType];
1336 }
1337 
DisconnectReasonToString(DisconnectReason aReason)1338 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1339 {
1340     static const char *const kDisconnectReasonStrings[] = {
1341         "FailedToConnect",         // (0) kReasonFailedToConnect
1342         "ResponseTimeout",         // (1) kReasonResponseTimeout
1343         "PeerDoesNotSupportDso",   // (2) kReasonPeerDoesNotSupportDso
1344         "PeerClosed",              // (3) kReasonPeerClosed
1345         "PeerAborted",             // (4) kReasonPeerAborted
1346         "InactivityTimeout",       // (5) kReasonInactivityTimeout
1347         "KeepAliveTimeout",        // (6) kReasonKeepAliveTimeout
1348         "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1349         "PeerMisbehavior",         // (8) kReasonPeerMisbehavior
1350         "Unknown",                 // (9) kReasonUnknown
1351     };
1352 
1353     static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1354     static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1355     static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1356     static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1357     static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1358     static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1359     static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1360     static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1361     static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1362     static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1363 
1364     return kDisconnectReasonStrings[aReason];
1365 }
1366 
1367 //---------------------------------------------------------------------------------------------------------------------
1368 // Dso::Connection::PendingRequests
1369 
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1370 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1371 {
1372     bool         contains = true;
1373     const Entry *entry    = mRequests.FindMatching(aMessageId);
1374 
1375     VerifyOrExit(entry != nullptr, contains = false);
1376     aPrimaryTlvType = entry->mPrimaryTlvType;
1377 
1378 exit:
1379     return contains;
1380 }
1381 
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1382 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1383 {
1384     Error  error = kErrorNone;
1385     Entry *entry = mRequests.PushBack();
1386 
1387     VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1388     entry->mMessageId      = aMessageId;
1389     entry->mPrimaryTlvType = aPrimaryTlvType;
1390     entry->mTimeout        = aResponseTimeout;
1391 
1392 exit:
1393     return error;
1394 }
1395 
Remove(MessageId aMessageId)1396 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); }
1397 
HasAnyTimedOut(TimeMilli aNow) const1398 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1399 {
1400     bool timedOut = false;
1401 
1402     for (const Entry &entry : mRequests)
1403     {
1404         if (entry.mTimeout <= aNow)
1405         {
1406             timedOut = true;
1407             break;
1408         }
1409     }
1410 
1411     return timedOut;
1412 }
1413 
UpdateNextFireTime(NextFireTime & aNextTime) const1414 void Dso::Connection::PendingRequests::UpdateNextFireTime(NextFireTime &aNextTime) const
1415 {
1416     for (const Entry &entry : mRequests)
1417     {
1418         aNextTime.UpdateIfEarlier(entry.mTimeout);
1419     }
1420 }
1421 
1422 //---------------------------------------------------------------------------------------------------------------------
1423 // Dso
1424 
Dso(Instance & aInstance)1425 Dso::Dso(Instance &aInstance)
1426     : InstanceLocator(aInstance)
1427     , mAcceptHandler(nullptr)
1428     , mTimer(aInstance)
1429 {
1430 }
1431 
StartListening(AcceptHandler aAcceptHandler)1432 void Dso::StartListening(AcceptHandler aAcceptHandler)
1433 {
1434     mAcceptHandler = aAcceptHandler;
1435     otPlatDsoEnableListening(&GetInstance(), true);
1436 }
1437 
StopListening(void)1438 void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); }
1439 
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1440 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1441 {
1442     return mClientConnections.FindMatching(aPeerSockAddr);
1443 }
1444 
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1445 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1446 {
1447     return mServerConnections.FindMatching(aPeerSockAddr);
1448 }
1449 
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1450 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1451 {
1452     Connection *connection = nullptr;
1453 
1454     VerifyOrExit(mAcceptHandler != nullptr);
1455     connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1456 
1457     VerifyOrExit(connection != nullptr);
1458     connection->Accept();
1459 
1460 exit:
1461     return connection;
1462 }
1463 
HandleTimer(void)1464 void Dso::HandleTimer(void)
1465 {
1466     NextFireTime nextTime;
1467     Connection  *conn;
1468     Connection  *next;
1469 
1470     for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1471     {
1472         next = conn->GetNext();
1473         conn->HandleTimer(nextTime);
1474     }
1475 
1476     for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1477     {
1478         next = conn->GetNext();
1479         conn->HandleTimer(nextTime);
1480     }
1481 
1482     mTimer.FireAtIfEarlier(nextTime);
1483 }
1484 
1485 } // namespace Dns
1486 } // namespace ot
1487 
1488 #if OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1489 
otPlatDsoEnableListening(otInstance * aInstance,bool aEnable)1490 OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable)
1491 {
1492     OT_UNUSED_VARIABLE(aInstance);
1493     OT_UNUSED_VARIABLE(aEnable);
1494 }
1495 
otPlatDsoConnect(otPlatDsoConnection * aConnection,const otSockAddr * aPeerSockAddr)1496 OT_TOOL_WEAK void otPlatDsoConnect(otPlatDsoConnection *aConnection, const otSockAddr *aPeerSockAddr)
1497 {
1498     OT_UNUSED_VARIABLE(aConnection);
1499     OT_UNUSED_VARIABLE(aPeerSockAddr);
1500 }
1501 
otPlatDsoSend(otPlatDsoConnection * aConnection,otMessage * aMessage)1502 OT_TOOL_WEAK void otPlatDsoSend(otPlatDsoConnection *aConnection, otMessage *aMessage)
1503 {
1504     OT_UNUSED_VARIABLE(aConnection);
1505     OT_UNUSED_VARIABLE(aMessage);
1506 }
1507 
otPlatDsoDisconnect(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)1508 OT_TOOL_WEAK void otPlatDsoDisconnect(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
1509 {
1510     OT_UNUSED_VARIABLE(aConnection);
1511     OT_UNUSED_VARIABLE(aMode);
1512 }
1513 
1514 #endif // OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1515 
1516 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1517