1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_dso.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_DSO_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "common/num_utils.hpp"
40 #include "common/random.hpp"
41 #include "instance/instance.hpp"
42
43 /**
44 * @file
45 * This file implements the DNS Stateful Operations (DSO) per RFC 8490.
46 */
47
48 namespace ot {
49 namespace Dns {
50
51 RegisterLogModule("DnsDso");
52
53 //---------------------------------------------------------------------------------------------------------------------
54 // otPlatDso transport callbacks
55
otPlatDsoGetInstance(otPlatDsoConnection * aConnection)56 extern "C" otInstance *otPlatDsoGetInstance(otPlatDsoConnection *aConnection)
57 {
58 return &AsCoreType(aConnection).GetInstance();
59 }
60
otPlatDsoAccept(otInstance * aInstance,const otSockAddr * aPeerSockAddr)61 extern "C" otPlatDsoConnection *otPlatDsoAccept(otInstance *aInstance, const otSockAddr *aPeerSockAddr)
62 {
63 return AsCoreType(aInstance).Get<Dso>().AcceptConnection(AsCoreType(aPeerSockAddr));
64 }
65
otPlatDsoHandleConnected(otPlatDsoConnection * aConnection)66 extern "C" void otPlatDsoHandleConnected(otPlatDsoConnection *aConnection)
67 {
68 AsCoreType(aConnection).HandleConnected();
69 }
70
otPlatDsoHandleReceive(otPlatDsoConnection * aConnection,otMessage * aMessage)71 extern "C" void otPlatDsoHandleReceive(otPlatDsoConnection *aConnection, otMessage *aMessage)
72 {
73 AsCoreType(aConnection).HandleReceive(AsCoreType(aMessage));
74 }
75
otPlatDsoHandleDisconnected(otPlatDsoConnection * aConnection,otPlatDsoDisconnectMode aMode)76 extern "C" void otPlatDsoHandleDisconnected(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
77 {
78 AsCoreType(aConnection).HandleDisconnected(MapEnum(aMode));
79 }
80
81 //---------------------------------------------------------------------------------------------------------------------
82 // Dso::Connection
83
Connection(Instance & aInstance,const Ip6::SockAddr & aPeerSockAddr,Callbacks & aCallbacks,uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)84 Dso::Connection::Connection(Instance &aInstance,
85 const Ip6::SockAddr &aPeerSockAddr,
86 Callbacks &aCallbacks,
87 uint32_t aInactivityTimeout,
88 uint32_t aKeepAliveInterval)
89 : InstanceLocator(aInstance)
90 , mNext(nullptr)
91 , mCallbacks(aCallbacks)
92 , mPeerSockAddr(aPeerSockAddr)
93 , mState(kStateDisconnected)
94 , mIsServer(false)
95 , mInactivity(aInactivityTimeout)
96 , mKeepAlive(aKeepAliveInterval)
97 {
98 OT_ASSERT(aKeepAliveInterval >= kMinKeepAliveInterval);
99 Init(/* aIsServer */ false);
100 }
101
Init(bool aIsServer)102 void Dso::Connection::Init(bool aIsServer)
103 {
104 mNextMessageId = 1;
105 mIsServer = aIsServer;
106 mStateDidChange = false;
107 mLongLivedOperation = false;
108 mRetryDelay = 0;
109 mRetryDelayErrorCode = Dns::Header::kResponseSuccess;
110 mDisconnectReason = kReasonUnknown;
111 }
112
SetState(State aState)113 void Dso::Connection::SetState(State aState)
114 {
115 VerifyOrExit(mState != aState);
116
117 LogInfo("State: %s -> %s on connection with %s", StateToString(mState), StateToString(aState),
118 mPeerSockAddr.ToString().AsCString());
119
120 mState = aState;
121 mStateDidChange = true;
122
123 exit:
124 return;
125 }
126
SignalAnyStateChange(void)127 void Dso::Connection::SignalAnyStateChange(void)
128 {
129 VerifyOrExit(mStateDidChange);
130 mStateDidChange = false;
131
132 switch (mState)
133 {
134 case kStateDisconnected:
135 mCallbacks.mHandleDisconnected(*this);
136 break;
137
138 case kStateConnectedButSessionless:
139 mCallbacks.mHandleConnected(*this);
140 break;
141
142 case kStateSessionEstablished:
143 mCallbacks.mHandleSessionEstablished(*this);
144 break;
145
146 case kStateConnecting:
147 case kStateEstablishingSession:
148 break;
149 };
150
151 exit:
152 return;
153 }
154
NewMessage(void)155 Message *Dso::Connection::NewMessage(void)
156 {
157 return Get<MessagePool>().Allocate(Message::kTypeOther, sizeof(Dns::Header),
158 Message::Settings(Message::kPriorityNormal));
159 }
160
Connect(void)161 void Dso::Connection::Connect(void)
162 {
163 OT_ASSERT(mState == kStateDisconnected);
164
165 Init(/* aIsServer */ false);
166 Get<Dso>().mClientConnections.Push(*this);
167 MarkAsConnecting();
168 otPlatDsoConnect(this, &mPeerSockAddr);
169 }
170
Accept(void)171 void Dso::Connection::Accept(void)
172 {
173 OT_ASSERT(mState == kStateDisconnected);
174
175 Init(/* aIsServer */ true);
176 Get<Dso>().mServerConnections.Push(*this);
177 MarkAsConnecting();
178 }
179
MarkAsConnecting(void)180 void Dso::Connection::MarkAsConnecting(void)
181 {
182 SetState(kStateConnecting);
183
184 // While in `kStateConnecting` state we use the `mKeepAlive` to
185 // track the `kConnectingTimeout` (if connection is not established
186 // within the timeout, we consider it as failure and close it).
187
188 mKeepAlive.SetExpirationTime(TimerMilli::GetNow() + kConnectingTimeout);
189 Get<Dso>().mTimer.FireAtIfEarlier(mKeepAlive.GetExpirationTime());
190
191 // Wait for `HandleConnected()` or `HandleDisconnected()` callbacks
192 // or timeout.
193 }
194
HandleConnected(void)195 void Dso::Connection::HandleConnected(void)
196 {
197 OT_ASSERT(mState == kStateConnecting);
198
199 SetState(kStateConnectedButSessionless);
200 ResetTimeouts(/* aIsKeepAliveMessage */ false);
201
202 SignalAnyStateChange();
203 }
204
Disconnect(DisconnectMode aMode,DisconnectReason aReason)205 void Dso::Connection::Disconnect(DisconnectMode aMode, DisconnectReason aReason)
206 {
207 VerifyOrExit(mState != kStateDisconnected);
208
209 mDisconnectReason = aReason;
210 MarkAsDisconnected();
211
212 otPlatDsoDisconnect(this, MapEnum(aMode));
213
214 exit:
215 return;
216 }
217
HandleDisconnected(DisconnectMode aMode)218 void Dso::Connection::HandleDisconnected(DisconnectMode aMode)
219 {
220 VerifyOrExit(mState != kStateDisconnected);
221
222 if (mState == kStateConnecting)
223 {
224 mDisconnectReason = kReasonFailedToConnect;
225 }
226 else
227 {
228 switch (aMode)
229 {
230 case kGracefullyClose:
231 mDisconnectReason = kReasonPeerClosed;
232 break;
233
234 case kForciblyAbort:
235 mDisconnectReason = kReasonPeerAborted;
236 }
237 }
238
239 MarkAsDisconnected();
240 SignalAnyStateChange();
241
242 exit:
243 return;
244 }
245
MarkAsDisconnected(void)246 void Dso::Connection::MarkAsDisconnected(void)
247 {
248 if (IsClient())
249 {
250 IgnoreError(Get<Dso>().mClientConnections.Remove(*this));
251 }
252 else
253 {
254 IgnoreError(Get<Dso>().mServerConnections.Remove(*this));
255 }
256
257 mPendingRequests.Clear();
258 SetState(kStateDisconnected);
259
260 LogInfo("Disconnect reason: %s", DisconnectReasonToString(mDisconnectReason));
261 }
262
MarkSessionEstablished(void)263 void Dso::Connection::MarkSessionEstablished(void)
264 {
265 switch (mState)
266 {
267 case kStateConnectedButSessionless:
268 case kStateEstablishingSession:
269 case kStateSessionEstablished:
270 break;
271
272 case kStateDisconnected:
273 case kStateConnecting:
274 OT_ASSERT(false);
275 }
276
277 SetState(kStateSessionEstablished);
278 }
279
SendRequestMessage(Message & aMessage,MessageId & aMessageId,uint32_t aResponseTimeout)280 Error Dso::Connection::SendRequestMessage(Message &aMessage, MessageId &aMessageId, uint32_t aResponseTimeout)
281 {
282 return SendMessage(aMessage, kRequestMessage, aMessageId, Dns::Header::kResponseSuccess, aResponseTimeout);
283 }
284
SendUnidirectionalMessage(Message & aMessage)285 Error Dso::Connection::SendUnidirectionalMessage(Message &aMessage)
286 {
287 MessageId messageId = 0;
288
289 return SendMessage(aMessage, kUnidirectionalMessage, messageId);
290 }
291
SendResponseMessage(Message & aMessage,MessageId aResponseId)292 Error Dso::Connection::SendResponseMessage(Message &aMessage, MessageId aResponseId)
293 {
294 return SendMessage(aMessage, kResponseMessage, aResponseId);
295 }
296
SetLongLivedOperation(bool aLongLivedOperation)297 void Dso::Connection::SetLongLivedOperation(bool aLongLivedOperation)
298 {
299 VerifyOrExit(mLongLivedOperation != aLongLivedOperation);
300
301 mLongLivedOperation = aLongLivedOperation;
302
303 LogInfo("Long-lived operation %s", mLongLivedOperation ? "started" : "stopped");
304
305 if (!mLongLivedOperation)
306 {
307 NextFireTime nextTime;
308
309 UpdateNextFireTime(nextTime);
310 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
311 }
312
313 exit:
314 return;
315 }
316
SendRetryDelayMessage(uint32_t aDelay,Dns::Header::Response aResponseCode)317 Error Dso::Connection::SendRetryDelayMessage(uint32_t aDelay, Dns::Header::Response aResponseCode)
318 {
319 Error error = kErrorNone;
320 Message *message = nullptr;
321 RetryDelayTlv retryDelayTlv;
322 MessageId messageId;
323
324 switch (mState)
325 {
326 case kStateSessionEstablished:
327 OT_ASSERT(IsServer());
328 break;
329
330 case kStateConnectedButSessionless:
331 case kStateEstablishingSession:
332 case kStateDisconnected:
333 case kStateConnecting:
334 OT_ASSERT(false);
335 }
336
337 message = NewMessage();
338 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
339
340 retryDelayTlv.Init();
341 retryDelayTlv.SetRetryDelay(aDelay);
342 SuccessOrExit(error = message->Append(retryDelayTlv));
343 error = SendMessage(*message, kUnidirectionalMessage, messageId, aResponseCode);
344
345 exit:
346 FreeMessageOnError(message, error);
347 return error;
348 }
349
SetTimeouts(uint32_t aInactivityTimeout,uint32_t aKeepAliveInterval)350 Error Dso::Connection::SetTimeouts(uint32_t aInactivityTimeout, uint32_t aKeepAliveInterval)
351 {
352 Error error = kErrorNone;
353
354 VerifyOrExit(aKeepAliveInterval >= kMinKeepAliveInterval, error = kErrorInvalidArgs);
355
356 // If acting as server, the timeout values are the ones we grant
357 // to a connecting clients. If acting as client, the timeout
358 // values are what to request when sending Keep Alive message.
359 // If in `kStateDisconnected` we set both (since we don't know
360 // yet whether we are going to connect as client or server).
361
362 if ((mState == kStateDisconnected) || IsServer())
363 {
364 mKeepAlive.SetInterval(aKeepAliveInterval);
365 AdjustInactivityTimeout(aInactivityTimeout);
366 }
367
368 if ((mState == kStateDisconnected) || IsClient())
369 {
370 mKeepAlive.SetRequestInterval(aKeepAliveInterval);
371 mInactivity.SetRequestInterval(aInactivityTimeout);
372 }
373
374 switch (mState)
375 {
376 case kStateDisconnected:
377 case kStateConnecting:
378 break;
379
380 case kStateConnectedButSessionless:
381 case kStateEstablishingSession:
382 if (IsServer())
383 {
384 break;
385 }
386
387 OT_FALL_THROUGH;
388
389 case kStateSessionEstablished:
390 error = SendKeepAliveMessage();
391 }
392
393 exit:
394 return error;
395 }
396
SendKeepAliveMessage(void)397 Error Dso::Connection::SendKeepAliveMessage(void)
398 {
399 return SendKeepAliveMessage(IsServer() ? kUnidirectionalMessage : kRequestMessage, 0);
400 }
401
SendKeepAliveMessage(MessageType aMessageType,MessageId aResponseId)402 Error Dso::Connection::SendKeepAliveMessage(MessageType aMessageType, MessageId aResponseId)
403 {
404 // Sends a Keep Alive message of a given type. This is a common
405 // method used by both client and server. `aResponseId` is
406 // applicable and used only when the message type is
407 // `kResponseMessage`.
408
409 Error error = kErrorNone;
410 Message *message = nullptr;
411 KeepAliveTlv keepAliveTlv;
412
413 switch (mState)
414 {
415 case kStateConnectedButSessionless:
416 case kStateEstablishingSession:
417 if (IsServer())
418 {
419 // While session is being established, server is only allowed
420 // to send a Keep Alive response to a request from client.
421 OT_ASSERT(aMessageType == kResponseMessage);
422 }
423 break;
424
425 case kStateSessionEstablished:
426 break;
427
428 case kStateDisconnected:
429 case kStateConnecting:
430 OT_ASSERT(false);
431 }
432
433 // Server can send Keep Alive response (to a request from client)
434 // or a unidirectional Keep Alive message. Client can send
435 // KeepAlive request message.
436
437 if (IsServer())
438 {
439 if (aMessageType == kResponseMessage)
440 {
441 OT_ASSERT(aResponseId != 0);
442 }
443 else
444 {
445 OT_ASSERT(aMessageType == kUnidirectionalMessage);
446 }
447 }
448 else
449 {
450 OT_ASSERT(aMessageType == kRequestMessage);
451 }
452
453 message = NewMessage();
454 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
455
456 keepAliveTlv.Init();
457
458 if (IsServer())
459 {
460 keepAliveTlv.SetInactivityTimeout(mInactivity.GetInterval());
461 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetInterval());
462 }
463 else
464 {
465 keepAliveTlv.SetInactivityTimeout(mInactivity.GetRequestInterval());
466 keepAliveTlv.SetKeepAliveInterval(mKeepAlive.GetRequestInterval());
467 }
468
469 SuccessOrExit(error = message->Append(keepAliveTlv));
470
471 error = SendMessage(*message, aMessageType, aResponseId);
472
473 exit:
474 FreeMessageOnError(message, error);
475 return error;
476 }
477
SendMessage(Message & aMessage,MessageType aMessageType,MessageId & aMessageId,Dns::Header::Response aResponseCode,uint32_t aResponseTimeout)478 Error Dso::Connection::SendMessage(Message &aMessage,
479 MessageType aMessageType,
480 MessageId &aMessageId,
481 Dns::Header::Response aResponseCode,
482 uint32_t aResponseTimeout)
483 {
484 Error error = kErrorNone;
485 Tlv::Type primaryTlvType = Tlv::kReservedType;
486 Dns::Header header;
487
488 switch (mState)
489 {
490 case kStateConnectedButSessionless:
491 // To establish session, client MUST send a request message.
492 // Server is not allowed to send any messages. Unidirectional
493 // messages are not allowed before session is established.
494 OT_ASSERT(IsClient());
495 OT_ASSERT(aMessageType == kRequestMessage);
496 break;
497
498 case kStateEstablishingSession:
499 // During session establishment, client is allowed to send
500 // additional request messages, server is only allowed to
501 // send response.
502 if (IsClient())
503 {
504 OT_ASSERT(aMessageType == kRequestMessage);
505 }
506 else
507 {
508 OT_ASSERT(aMessageType == kResponseMessage);
509 }
510 break;
511
512 case kStateSessionEstablished:
513 // All message types are allowed.
514 break;
515
516 case kStateDisconnected:
517 case kStateConnecting:
518 OT_ASSERT(false);
519 }
520
521 // A DSO request or unidirectional message MUST contain at
522 // least one TLV. The first TLV is the "Primary TLV" and
523 // determines the nature of the operation being performed.
524 // A DSO response message may contain no TLVs, or may contain
525 // one or more TLVs. Response Primary TLV(s) MUST appear first
526 // in a DSO response message.
527
528 aMessage.SetOffset(0);
529 IgnoreError(ReadPrimaryTlv(aMessage, primaryTlvType));
530
531 switch (aMessageType)
532 {
533 case kResponseMessage:
534 break;
535 case kRequestMessage:
536 case kUnidirectionalMessage:
537 OT_ASSERT(primaryTlvType != Tlv::kReservedType);
538 }
539
540 // `header` is cleared from its constructor call so all fields
541 // start as zero.
542
543 switch (aMessageType)
544 {
545 case kRequestMessage:
546 header.SetType(Dns::Header::kTypeQuery);
547 aMessageId = mNextMessageId;
548 break;
549
550 case kResponseMessage:
551 header.SetType(Dns::Header::kTypeResponse);
552 break;
553
554 case kUnidirectionalMessage:
555 header.SetType(Dns::Header::kTypeQuery);
556 aMessageId = 0;
557 break;
558 }
559
560 header.SetMessageId(aMessageId);
561 header.SetQueryType(Dns::Header::kQueryTypeDso);
562 header.SetResponseCode(aResponseCode);
563 SuccessOrExit(error = aMessage.Prepend(header));
564
565 SuccessOrExit(error = AppendPadding(aMessage));
566
567 // Update `mPendingRequests` list with the new request info
568
569 if (aMessageType == kRequestMessage)
570 {
571 SuccessOrExit(
572 error = mPendingRequests.Add(mNextMessageId, primaryTlvType, TimerMilli::GetNow() + aResponseTimeout));
573
574 if (++mNextMessageId == 0)
575 {
576 mNextMessageId = 1;
577 }
578 }
579
580 LogInfo("Sending %s message with id %u to %s", MessageTypeToString(aMessageType), aMessageId,
581 mPeerSockAddr.ToString().AsCString());
582
583 switch (mState)
584 {
585 case kStateConnectedButSessionless:
586 // On client we transition from "connected" state to
587 // "establishing session" state on successfully sending a
588 // request message.
589 if (IsClient())
590 {
591 SetState(kStateEstablishingSession);
592 }
593 break;
594
595 case kStateEstablishingSession:
596 // On server we transition from "establishing session" state
597 // to "established" on sending a response with success
598 // response code.
599 if (IsServer() && (aResponseCode == Dns::Header::kResponseSuccess))
600 {
601 SetState(kStateSessionEstablished);
602 }
603
604 default:
605 break;
606 }
607
608 ResetTimeouts(/* aIsKeepAliveMessage*/ (primaryTlvType == KeepAliveTlv::kType));
609
610 otPlatDsoSend(this, &aMessage);
611
612 // Signal any state changes. This is done at the very end when the
613 // `SendMessage()` is fully processed (all state and local
614 // variables are updated) to ensure that we do not have any
615 // reentrancy issues (e.g., if the callback signalling state
616 // change triggers another tx).
617
618 SignalAnyStateChange();
619
620 exit:
621 return error;
622 }
623
AppendPadding(Message & aMessage)624 Error Dso::Connection::AppendPadding(Message &aMessage)
625 {
626 // This method appends Encryption Padding TLV to a DSO message.
627 // It uses the padding policy "Random-Block-Length Padding" from
628 // RFC 8467.
629
630 static const uint16_t kBlockLengths[] = {8, 11, 17, 21};
631
632 Error error = kErrorNone;
633 uint16_t blockLength;
634 EncryptionPaddingTlv paddingTlv;
635
636 // We pick a random block length. The random selection can be
637 // based on a "weak" source of randomness (so the use of
638 // `NonCrypto` is fine). We add padding to the message such
639 // that its padded length is a multiple of the chosen block
640 // length.
641
642 blockLength = kBlockLengths[Random::NonCrypto::GetUint8InRange(0, GetArrayLength(kBlockLengths))];
643
644 paddingTlv.Init((blockLength - ((aMessage.GetLength() + sizeof(Tlv)) % blockLength)) % blockLength);
645
646 SuccessOrExit(error = aMessage.Append(paddingTlv));
647
648 for (uint16_t len = paddingTlv.GetLength(); len > 0; len--)
649 {
650 SuccessOrExit(error = aMessage.Append<uint8_t>(0));
651 }
652
653 exit:
654 return error;
655 }
656
HandleReceive(Message & aMessage)657 void Dso::Connection::HandleReceive(Message &aMessage)
658 {
659 Error error = kErrorAbort;
660 Tlv::Type primaryTlvType = Tlv::kReservedType;
661 Dns::Header header;
662
663 SuccessOrExit(aMessage.Read(0, header));
664
665 if (header.GetQueryType() != Dns::Header::kQueryTypeDso)
666 {
667 if (header.GetType() == Dns::Header::kTypeQuery)
668 {
669 SendErrorResponse(header, Dns::Header::kResponseNotImplemented);
670 error = kErrorNone;
671 }
672
673 ExitNow();
674 }
675
676 switch (mState)
677 {
678 case kStateConnectedButSessionless:
679 // After connection is established, client should initiate
680 // establishing session (by sending a request). So no rx is
681 // allowed before this. On server, we allow rx of a request
682 // message only.
683 VerifyOrExit(IsServer() && (header.GetType() == Dns::Header::kTypeQuery) && (header.GetMessageId() != 0));
684 break;
685
686 case kStateEstablishingSession:
687 // Unidirectional message are allowed after session is
688 // established. While session is being established, on client,
689 // we allow rx on response message. On server we can rx
690 // request or response.
691
692 VerifyOrExit(header.GetMessageId() != 0);
693
694 if (IsClient())
695 {
696 VerifyOrExit(header.GetType() == Dns::Header::kTypeResponse);
697 }
698 break;
699
700 case kStateSessionEstablished:
701 // All message types are allowed.
702 break;
703
704 case kStateDisconnected:
705 case kStateConnecting:
706 ExitNow();
707 }
708
709 // All count fields MUST be set to zero in the header.
710 VerifyOrExit((header.GetQuestionCount() == 0) && (header.GetAnswerCount() == 0) &&
711 (header.GetAuthorityRecordCount() == 0) && (header.GetAdditionalRecordCount() == 0));
712
713 aMessage.SetOffset(sizeof(header));
714
715 switch (ReadPrimaryTlv(aMessage, primaryTlvType))
716 {
717 case kErrorNone:
718 VerifyOrExit(primaryTlvType != Tlv::kReservedType);
719 break;
720
721 case kErrorNotFound:
722 // The `primaryTlvType` is set to `Tlv::kReservedType`
723 // (value zero) to indicate that there is no primary TLV.
724 break;
725
726 default:
727 ExitNow();
728 }
729
730 switch (header.GetType())
731 {
732 case Dns::Header::kTypeQuery:
733 error = ProcessRequestOrUnidirectionalMessage(header, aMessage, primaryTlvType);
734 break;
735
736 case Dns::Header::kTypeResponse:
737 error = ProcessResponseMessage(header, aMessage, primaryTlvType);
738 break;
739 }
740
741 exit:
742 aMessage.Free();
743
744 if (error == kErrorNone)
745 {
746 ResetTimeouts(/* aIsKeepAliveMessage */ (primaryTlvType == KeepAliveTlv::kType));
747 }
748 else
749 {
750 Disconnect(kForciblyAbort, kReasonPeerMisbehavior);
751 }
752
753 // We signal any state change at the very end when the received
754 // message is fully processed (all state and local variables are
755 // updated) to ensure that we do not have any reentrancy issues
756 // (e.g., if a `Connection` method happens to be called from the
757 // callback).
758
759 SignalAnyStateChange();
760 }
761
ReadPrimaryTlv(const Message & aMessage,Tlv::Type & aPrimaryTlvType) const762 Error Dso::Connection::ReadPrimaryTlv(const Message &aMessage, Tlv::Type &aPrimaryTlvType) const
763 {
764 // Read and validate the primary TLV (first TLV after the header).
765 // The `aMessage.GetOffset()` must point to the first TLV. If no
766 // TLV then `kErrorNotFound` is returned. If TLV in message is not
767 // well-formed `kErrorParse` is returned. The read TLV type is
768 // returned in `aPrimaryTlvType` (set to `Tlv::kReservedType`
769 // (value zero) when `kErrorNotFound`).
770
771 Error error = kErrorNotFound;
772 Tlv tlv;
773
774 aPrimaryTlvType = Tlv::kReservedType;
775
776 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), tlv));
777 VerifyOrExit(aMessage.GetOffset() + tlv.GetSize() <= aMessage.GetLength(), error = kErrorParse);
778 aPrimaryTlvType = tlv.GetType();
779 error = kErrorNone;
780
781 exit:
782 return error;
783 }
784
ProcessRequestOrUnidirectionalMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)785 Error Dso::Connection::ProcessRequestOrUnidirectionalMessage(const Dns::Header &aHeader,
786 const Message &aMessage,
787 Tlv::Type aPrimaryTlvType)
788 {
789 Error error = kErrorAbort;
790
791 if (IsServer() && (mState == kStateConnectedButSessionless))
792 {
793 SetState(kStateEstablishingSession);
794 }
795
796 // A DSO request or unidirectional message MUST contain at
797 // least one TLV which is the "Primary TLV" and determines
798 // the nature of the operation being performed.
799
800 switch (aPrimaryTlvType)
801 {
802 case KeepAliveTlv::kType:
803 error = ProcessKeepAliveMessage(aHeader, aMessage);
804 break;
805
806 case RetryDelayTlv::kType:
807 error = ProcessRetryDelayMessage(aHeader, aMessage);
808 break;
809
810 case Tlv::kReservedType:
811 case EncryptionPaddingTlv::kType:
812 // Misbehavior by peer.
813 break;
814
815 default:
816 if (aHeader.GetMessageId() == 0)
817 {
818 LogInfo("Received unidirectional message from %s", mPeerSockAddr.ToString().AsCString());
819
820 error = mCallbacks.mProcessUnidirectionalMessage(*this, aMessage, aPrimaryTlvType);
821 }
822 else
823 {
824 MessageId messageId = aHeader.GetMessageId();
825
826 LogInfo("Received request message with id %u from %s", messageId, mPeerSockAddr.ToString().AsCString());
827
828 error = mCallbacks.mProcessRequestMessage(*this, messageId, aMessage, aPrimaryTlvType);
829
830 // `kErrorNotFound` indicates that TLV type is not known.
831
832 if (error == kErrorNotFound)
833 {
834 SendErrorResponse(aHeader, Dns::Header::kDsoTypeNotImplemented);
835 error = kErrorNone;
836 }
837 }
838 break;
839 }
840
841 return error;
842 }
843
ProcessResponseMessage(const Dns::Header & aHeader,const Message & aMessage,Tlv::Type aPrimaryTlvType)844 Error Dso::Connection::ProcessResponseMessage(const Dns::Header &aHeader,
845 const Message &aMessage,
846 Tlv::Type aPrimaryTlvType)
847 {
848 Error error = kErrorAbort;
849 Tlv::Type requestPrimaryTlvType;
850
851 // If a client or server receives a response where the message
852 // ID is zero, or is any other value that does not match the
853 // message ID of any of its outstanding operations, this is a
854 // fatal error and the recipient MUST forcibly abort the
855 // connection immediately.
856
857 VerifyOrExit(aHeader.GetMessageId() != 0);
858 VerifyOrExit(mPendingRequests.Contains(aHeader.GetMessageId(), requestPrimaryTlvType));
859
860 // If the response has no error and contains a primary TLV, it
861 // MUST match the request primary TLV.
862
863 if ((aHeader.GetResponseCode() == Dns::Header::kResponseSuccess) && (aPrimaryTlvType != Tlv::kReservedType))
864 {
865 VerifyOrExit(aPrimaryTlvType == requestPrimaryTlvType);
866 }
867
868 mPendingRequests.Remove(aHeader.GetMessageId());
869
870 switch (requestPrimaryTlvType)
871 {
872 case KeepAliveTlv::kType:
873 SuccessOrExit(error = ProcessKeepAliveMessage(aHeader, aMessage));
874 break;
875
876 default:
877 SuccessOrExit(error = mCallbacks.mProcessResponseMessage(*this, aHeader, aMessage, aPrimaryTlvType,
878 requestPrimaryTlvType));
879 break;
880 }
881
882 // DSO session is established when client sends a request message
883 // and receives a response from server with no error code.
884
885 if (IsClient() && (mState == kStateEstablishingSession) &&
886 (aHeader.GetResponseCode() == Dns::Header::kResponseSuccess))
887 {
888 SetState(kStateSessionEstablished);
889 }
890
891 exit:
892 return error;
893 }
894
ProcessKeepAliveMessage(const Dns::Header & aHeader,const Message & aMessage)895 Error Dso::Connection::ProcessKeepAliveMessage(const Dns::Header &aHeader, const Message &aMessage)
896 {
897 Error error = kErrorAbort;
898 uint16_t offset = aMessage.GetOffset();
899 Tlv tlv;
900 KeepAliveTlv keepAliveTlv;
901
902 if (aHeader.GetType() == Dns::Header::kTypeResponse)
903 {
904 // A Keep Alive response message is allowed on a client from a sever.
905
906 VerifyOrExit(IsClient());
907
908 if (aHeader.GetResponseCode() != Dns::Header::kResponseSuccess)
909 {
910 // We got an error response code from server for our
911 // Keep Alive request message. If this happens while
912 // establishing the DSO session, it indicates that server
913 // does not support DSO, so we close the connection. If
914 // this happens while session is already established, it
915 // is a misbehavior (fatal error) by server.
916
917 if (mState == kStateEstablishingSession)
918 {
919 Disconnect(kGracefullyClose, kReasonPeerDoesNotSupportDso);
920 error = kErrorNone;
921 }
922
923 ExitNow();
924 }
925 }
926
927 // Parse and validate the Keep Alive Message
928
929 SuccessOrExit(aMessage.Read(offset, keepAliveTlv));
930 offset += keepAliveTlv.GetSize();
931
932 VerifyOrExit((keepAliveTlv.GetType() == KeepAliveTlv::kType) && keepAliveTlv.IsValid());
933
934 // Keep Alive message MUST contain only one Keep Alive TLV.
935
936 while (offset < aMessage.GetLength())
937 {
938 SuccessOrExit(aMessage.Read(offset, tlv));
939 offset += tlv.GetSize();
940
941 VerifyOrExit((tlv.GetType() != KeepAliveTlv::kType) && (tlv.GetType() != RetryDelayTlv::kType));
942 }
943
944 VerifyOrExit(offset == aMessage.GetLength());
945
946 if (aHeader.GetType() == Dns::Header::kTypeQuery)
947 {
948 if (IsServer())
949 {
950 // Received a Keep Alive message from client. It MUST
951 // be a request message (not unidirectional). We prepare
952 // and send a Keep Alive response.
953
954 VerifyOrExit(aHeader.GetMessageId() != 0);
955
956 LogInfo("Received KeepAlive request message from client %s", mPeerSockAddr.ToString().AsCString());
957
958 IgnoreError(SendKeepAliveMessage(kResponseMessage, aHeader.GetMessageId()));
959 error = kErrorNone;
960 ExitNow();
961 }
962
963 // Received a Keep Alive message on client from server. Server
964 // Keep Alive message MUST be unidirectional (message ID
965 // zero).
966
967 VerifyOrExit(aHeader.GetMessageId() == 0);
968 }
969
970 LogInfo("Received Keep Alive %s message from server %s",
971 (aHeader.GetMessageId() == 0) ? "unidirectional" : "response", mPeerSockAddr.ToString().AsCString());
972
973 // Receiving a Keep Alive interval value from server less than the
974 // minimum (ten seconds) is a fatal error and client MUST then
975 // abort the connection.
976
977 VerifyOrExit(keepAliveTlv.GetKeepAliveInterval() >= kMinKeepAliveInterval);
978
979 // Update the timeout intervals on the connection from
980 // the new values we got from the server. The receive
981 // of the Keep Alive message does not itself reset the
982 // inactivity timer. So we use `AdjustInactivityTimeout`
983 // which takes into account the time elapsed since the
984 // last activity.
985
986 AdjustInactivityTimeout(keepAliveTlv.GetInactivityTimeout());
987 mKeepAlive.SetInterval(keepAliveTlv.GetKeepAliveInterval());
988
989 LogInfo("Timeouts Inactivity:%lu, KeepAlive:%lu", ToUlong(mInactivity.GetInterval()),
990 ToUlong(mKeepAlive.GetInterval()));
991
992 error = kErrorNone;
993
994 exit:
995 return error;
996 }
997
ProcessRetryDelayMessage(const Dns::Header & aHeader,const Message & aMessage)998 Error Dso::Connection::ProcessRetryDelayMessage(const Dns::Header &aHeader, const Message &aMessage)
999
1000 {
1001 Error error = kErrorAbort;
1002 RetryDelayTlv retryDelayTlv;
1003
1004 // Retry Delay TLV can be used as the Primary TLV only in
1005 // a unidirectional message sent from server to client.
1006 // It is used by the server to instruct the client to
1007 // close the session and its underlying connection, and not
1008 // to reconnect for the indicated time interval.
1009
1010 VerifyOrExit(IsClient() && (aHeader.GetMessageId() == 0));
1011
1012 SuccessOrExit(aMessage.Read(aMessage.GetOffset(), retryDelayTlv));
1013 VerifyOrExit(retryDelayTlv.IsValid());
1014
1015 mRetryDelayErrorCode = aHeader.GetResponseCode();
1016 mRetryDelay = retryDelayTlv.GetRetryDelay();
1017
1018 LogInfo("Received Retry Delay message from server %s", mPeerSockAddr.ToString().AsCString());
1019 LogInfo(" RetryDelay:%lu ms, ResponseCode:%d", ToUlong(mRetryDelay), mRetryDelayErrorCode);
1020
1021 Disconnect(kGracefullyClose, kReasonServerRetryDelayRequest);
1022
1023 exit:
1024 return error;
1025 }
1026
SendErrorResponse(const Dns::Header & aHeader,Dns::Header::Response aResponseCode)1027 void Dso::Connection::SendErrorResponse(const Dns::Header &aHeader, Dns::Header::Response aResponseCode)
1028 {
1029 Message *response = NewMessage();
1030 Dns::Header header;
1031
1032 VerifyOrExit(response != nullptr);
1033
1034 header.SetMessageId(aHeader.GetMessageId());
1035 header.SetType(Dns::Header::kTypeResponse);
1036 header.SetQueryType(aHeader.GetQueryType());
1037 header.SetResponseCode(aResponseCode);
1038
1039 SuccessOrExit(response->Prepend(header));
1040
1041 otPlatDsoSend(this, response);
1042 response = nullptr;
1043
1044 exit:
1045 FreeMessage(response);
1046 }
1047
AdjustInactivityTimeout(uint32_t aNewTimeout)1048 void Dso::Connection::AdjustInactivityTimeout(uint32_t aNewTimeout)
1049 {
1050 // This method sets the inactivity timeout interval to a new value
1051 // and updates the expiration time based on the new timeout value.
1052 //
1053 // On client, it is called on receiving a Keep Alive response or
1054 // unidirectional message from server. Note that the receive of
1055 // the Keep Alive message does not itself reset the inactivity
1056 // timer. So the time elapsed since the last activity should be
1057 // taken into account with the new inactivity timeout value.
1058 //
1059 // On server this method is called from `SetTimeouts()` when a new
1060 // inactivity timeout value is set.
1061
1062 TimeMilli now = TimerMilli::GetNow();
1063 TimeMilli start;
1064 TimeMilli newExpiration;
1065
1066 if (mState == kStateDisconnected)
1067 {
1068 mInactivity.SetInterval(aNewTimeout);
1069 ExitNow();
1070 }
1071
1072 VerifyOrExit(aNewTimeout != mInactivity.GetInterval());
1073
1074 // Calculate the start time (i.e., the last time inactivity timer
1075 // was cleared). If the previous inactivity time is set to
1076 // `kInfinite` value (`IsUsed()` returns `false`) then
1077 // `GetExpirationTime()` returns the start time. Otherwise, we
1078 // calculate it going back from the current expiration time with
1079 // the current wait interval.
1080
1081 if (!mInactivity.IsUsed())
1082 {
1083 start = mInactivity.GetExpirationTime();
1084 }
1085 else if (IsClient())
1086 {
1087 start = mInactivity.GetExpirationTime() - mInactivity.GetInterval();
1088 }
1089 else
1090 {
1091 start = mInactivity.GetExpirationTime() - CalculateServerInactivityWaitTime();
1092 }
1093
1094 mInactivity.SetInterval(aNewTimeout);
1095
1096 if (!mInactivity.IsUsed())
1097 {
1098 newExpiration = start;
1099 }
1100 else if (IsClient())
1101 {
1102 newExpiration = start + aNewTimeout;
1103
1104 if (newExpiration < now)
1105 {
1106 newExpiration = now;
1107 }
1108 }
1109 else
1110 {
1111 newExpiration = start + CalculateServerInactivityWaitTime();
1112
1113 if (newExpiration < now)
1114 {
1115 // If the server abruptly reduces the inactivity timeout
1116 // such that current elapsed time is already more than
1117 // twice the new inactivity timeout, then the client is
1118 // immediately considered delinquent (server can forcibly
1119 // abort the connection). So to give the client time to
1120 // close the connection gracefully, the server SHOULD
1121 // give the client an additional grace period of either
1122 // five seconds or one quarter of the new inactivity
1123 // timeout, whichever is greater [RFC 8490 - 7.1.1].
1124
1125 newExpiration = now + Max(kMinServerInactivityWaitTime, aNewTimeout / 4);
1126 }
1127 }
1128
1129 mInactivity.SetExpirationTime(newExpiration);
1130
1131 exit:
1132 return;
1133 }
1134
CalculateServerInactivityWaitTime(void) const1135 uint32_t Dso::Connection::CalculateServerInactivityWaitTime(void) const
1136 {
1137 // A server will abort an idle session after five seconds
1138 // (`kMinServerInactivityWaitTime`) or twice the inactivity
1139 // timeout value, whichever is greater [RFC 8490 - 6.4.1].
1140
1141 OT_ASSERT(mInactivity.IsUsed());
1142
1143 return Max(mInactivity.GetInterval() * 2, kMinServerInactivityWaitTime);
1144 }
1145
ResetTimeouts(bool aIsKeepAliveMessage)1146 void Dso::Connection::ResetTimeouts(bool aIsKeepAliveMessage)
1147 {
1148 NextFireTime nextTime;
1149
1150 // At both servers and clients, the generation or reception of any
1151 // complete DNS message resets both timers for that DSO
1152 // session, with the one exception being that a DSO Keep Alive
1153 // message resets only the keep alive timer, not the inactivity
1154 // timeout timer [RFC 8490 - 6.3]
1155
1156 if (mKeepAlive.IsUsed())
1157 {
1158 // On client, we wait for the Keep Alive interval but on server
1159 // we wait for twice the interval before considering Keep Alive
1160 // timeout.
1161 //
1162 // Note that we limit the interval to `Timeout::kMaxInterval`
1163 // (which is ~12 days). This max limit ensures that even twice
1164 // the interval is less than max OpenThread timer duration so
1165 // that the expiration time calculations below stay within the
1166 // `TimerMilli` range.
1167
1168 mKeepAlive.SetExpirationTime(nextTime.GetNow() + mKeepAlive.GetInterval() * (IsServer() ? 2 : 1));
1169 }
1170
1171 if (!aIsKeepAliveMessage)
1172 {
1173 if (mInactivity.IsUsed())
1174 {
1175 mInactivity.SetExpirationTime(
1176 nextTime.GetNow() + (IsServer() ? CalculateServerInactivityWaitTime() : mInactivity.GetInterval()));
1177 }
1178 else
1179 {
1180 // When Inactivity timeout is not used (i.e., interval is set
1181 // to the special `kInfinite` value), we still need to track
1182 // the time so that if/when later the inactivity interval
1183 // gets changed, we can adjust the remaining time correctly
1184 // from `AdjustInactivityTimeout()`. In this case, we just
1185 // track the current time as "expiration time".
1186
1187 mInactivity.SetExpirationTime(nextTime.GetNow());
1188 }
1189 }
1190
1191 UpdateNextFireTime(nextTime);
1192
1193 Get<Dso>().mTimer.FireAtIfEarlier(nextTime);
1194 }
1195
UpdateNextFireTime(NextFireTime & aNextTime) const1196 void Dso::Connection::UpdateNextFireTime(NextFireTime &aNextTime) const
1197 {
1198 switch (mState)
1199 {
1200 case kStateDisconnected:
1201 break;
1202
1203 case kStateConnecting:
1204 // While in `kStateConnecting`, Keep Alive timer is
1205 // used for `kConnectingTimeout`.
1206 aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1207 break;
1208
1209 case kStateConnectedButSessionless:
1210 case kStateEstablishingSession:
1211 case kStateSessionEstablished:
1212 mPendingRequests.UpdateNextFireTime(aNextTime);
1213
1214 if (mKeepAlive.IsUsed())
1215 {
1216 aNextTime.UpdateIfEarlier(mKeepAlive.GetExpirationTime());
1217 }
1218
1219 if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation)
1220 {
1221 // An operation being active on a DSO Session includes
1222 // a request message waiting for a response, or an
1223 // active long-lived operation.
1224
1225 aNextTime.UpdateIfEarlier(mInactivity.GetExpirationTime());
1226 }
1227
1228 break;
1229 }
1230 }
1231
void Dso::Connection::HandleTimer(NextFireTime &aNextTime)
{
    // Processes this connection's expired timers at `aNextTime.GetNow()`.
    // Depending on state and role, this may disconnect the connection or
    // (on client) send a Keep Alive message. On every path it then reports
    // the connection's next fire time into `aNextTime` via
    // `UpdateNextFireTime()` and calls `SignalAnyStateChange()`.
    //
    // In the connected states the checks run in a fixed order: response
    // timeout, then inactivity timeout, then Keep Alive. Any check that
    // disconnects jumps directly to `exit`.

    switch (mState)
    {
    case kStateDisconnected:
        break;

    case kStateConnecting:
        // While in `kStateConnecting`, the Keep Alive timer is used for
        // the connecting timeout.
        if (mKeepAlive.IsExpired(aNextTime.GetNow()))
        {
            Disconnect(kGracefullyClose, kReasonFailedToConnect);
        }
        break;

    case kStateConnectedButSessionless:
    case kStateEstablishingSession:
    case kStateSessionEstablished:
        if (mPendingRequests.HasAnyTimedOut(aNextTime.GetNow()))
        {
            // If server sends no response to a request, client
            // waits for 30 seconds (`kResponseTimeout`) after which
            // client MUST forcibly abort the connection.
            Disconnect(kForciblyAbort, kReasonResponseTimeout);
            ExitNow();
        }

        // The inactivity timer is kept clear, while an operation is
        // active on the session (which includes a request waiting for
        // response or an active long-lived operation).

        if (mInactivity.IsUsed() && mPendingRequests.IsEmpty() && !mLongLivedOperation &&
            mInactivity.IsExpired(aNextTime.GetNow()))
        {
            // On client, if the inactivity timeout is reached, the
            // connection is closed gracefully. On server, if too much
            // time (`CalculateServerInactivityWaitTime()`, i.e., five
            // seconds or twice the current inactivity timeout interval,
            // whichever is greater) elapses server MUST consider the
            // client delinquent and MUST forcibly abort the connection.

            Disconnect(IsClient() ? kGracefullyClose : kForciblyAbort, kReasonInactivityTimeout);
            ExitNow();
        }

        if (mKeepAlive.IsUsed() && mKeepAlive.IsExpired(aNextTime.GetNow()))
        {
            // On client, if the Keep Alive interval elapses without any
            // DNS messages being sent or received, the client MUST take
            // action and send a DSO Keep Alive message.
            //
            // On server, if twice the Keep Alive interval value elapses
            // without any messages being sent or received, the server
            // considers the client delinquent and aborts the connection.
            //
            // NOTE(review): the result of `SendKeepAliveMessage()` is
            // ignored. If sending fails and the Keep Alive expiration time
            // is not refreshed elsewhere, the expiration presumably stays
            // in the past and the timer would re-fire immediately. Confirm.

            if (IsClient())
            {
                IgnoreError(SendKeepAliveMessage());
            }
            else
            {
                Disconnect(kForciblyAbort, kReasonKeepAliveTimeout);
                ExitNow();
            }
        }
        break;
    }

exit:
    // Always re-report the next deadline (the state may have changed
    // above) and signal any state change.
    UpdateNextFireTime(aNextTime);
    SignalAnyStateChange();
}
1303
StateToString(State aState)1304 const char *Dso::Connection::StateToString(State aState)
1305 {
1306 static const char *const kStateStrings[] = {
1307 "Disconnected", // (0) kStateDisconnected,
1308 "Connecting", // (1) kStateConnecting,
1309 "ConnectedButSessionless", // (2) kStateConnectedButSessionless,
1310 "EstablishingSession", // (3) kStateEstablishingSession,
1311 "SessionEstablished", // (4) kStateSessionEstablished,
1312 };
1313
1314 static_assert(0 == kStateDisconnected, "kStateDisconnected value is incorrect");
1315 static_assert(1 == kStateConnecting, "kStateConnecting value is incorrect");
1316 static_assert(2 == kStateConnectedButSessionless, "kStateConnectedButSessionless value is incorrect");
1317 static_assert(3 == kStateEstablishingSession, "kStateEstablishingSession value is incorrect");
1318 static_assert(4 == kStateSessionEstablished, "kStateSessionEstablished value is incorrect");
1319
1320 return kStateStrings[aState];
1321 }
1322
MessageTypeToString(MessageType aMessageType)1323 const char *Dso::Connection::MessageTypeToString(MessageType aMessageType)
1324 {
1325 static const char *const kMessageTypeStrings[] = {
1326 "Request", // (0) kRequestMessage
1327 "Response", // (1) kResponseMessage
1328 "Unidirectional", // (2) kUnidirectionalMessage
1329 };
1330
1331 static_assert(0 == kRequestMessage, "kRequestMessage value is incorrect");
1332 static_assert(1 == kResponseMessage, "kResponseMessage value is incorrect");
1333 static_assert(2 == kUnidirectionalMessage, "kUnidirectionalMessage value is incorrect");
1334
1335 return kMessageTypeStrings[aMessageType];
1336 }
1337
DisconnectReasonToString(DisconnectReason aReason)1338 const char *Dso::Connection::DisconnectReasonToString(DisconnectReason aReason)
1339 {
1340 static const char *const kDisconnectReasonStrings[] = {
1341 "FailedToConnect", // (0) kReasonFailedToConnect
1342 "ResponseTimeout", // (1) kReasonResponseTimeout
1343 "PeerDoesNotSupportDso", // (2) kReasonPeerDoesNotSupportDso
1344 "PeerClosed", // (3) kReasonPeerClosed
1345 "PeerAborted", // (4) kReasonPeerAborted
1346 "InactivityTimeout", // (5) kReasonInactivityTimeout
1347 "KeepAliveTimeout", // (6) kReasonKeepAliveTimeout
1348 "ServerRetryDelayRequest", // (7) kReasonServerRetryDelayRequest
1349 "PeerMisbehavior", // (8) kReasonPeerMisbehavior
1350 "Unknown", // (9) kReasonUnknown
1351 };
1352
1353 static_assert(0 == kReasonFailedToConnect, "kReasonFailedToConnect value is incorrect");
1354 static_assert(1 == kReasonResponseTimeout, "kReasonResponseTimeout value is incorrect");
1355 static_assert(2 == kReasonPeerDoesNotSupportDso, "kReasonPeerDoesNotSupportDso value is incorrect");
1356 static_assert(3 == kReasonPeerClosed, "kReasonPeerClosed value is incorrect");
1357 static_assert(4 == kReasonPeerAborted, "kReasonPeerAborted value is incorrect");
1358 static_assert(5 == kReasonInactivityTimeout, "kReasonInactivityTimeout value is incorrect");
1359 static_assert(6 == kReasonKeepAliveTimeout, "kReasonKeepAliveTimeout value is incorrect");
1360 static_assert(7 == kReasonServerRetryDelayRequest, "kReasonServerRetryDelayRequest value is incorrect");
1361 static_assert(8 == kReasonPeerMisbehavior, "kReasonPeerMisbehavior value is incorrect");
1362 static_assert(9 == kReasonUnknown, "kReasonUnknown value is incorrect");
1363
1364 return kDisconnectReasonStrings[aReason];
1365 }
1366
1367 //---------------------------------------------------------------------------------------------------------------------
1368 // Dso::Connection::PendingRequests
1369
Contains(MessageId aMessageId,Tlv::Type & aPrimaryTlvType) const1370 bool Dso::Connection::PendingRequests::Contains(MessageId aMessageId, Tlv::Type &aPrimaryTlvType) const
1371 {
1372 bool contains = true;
1373 const Entry *entry = mRequests.FindMatching(aMessageId);
1374
1375 VerifyOrExit(entry != nullptr, contains = false);
1376 aPrimaryTlvType = entry->mPrimaryTlvType;
1377
1378 exit:
1379 return contains;
1380 }
1381
Add(MessageId aMessageId,Tlv::Type aPrimaryTlvType,TimeMilli aResponseTimeout)1382 Error Dso::Connection::PendingRequests::Add(MessageId aMessageId, Tlv::Type aPrimaryTlvType, TimeMilli aResponseTimeout)
1383 {
1384 Error error = kErrorNone;
1385 Entry *entry = mRequests.PushBack();
1386
1387 VerifyOrExit(entry != nullptr, error = kErrorNoBufs);
1388 entry->mMessageId = aMessageId;
1389 entry->mPrimaryTlvType = aPrimaryTlvType;
1390 entry->mTimeout = aResponseTimeout;
1391
1392 exit:
1393 return error;
1394 }
1395
Remove(MessageId aMessageId)1396 void Dso::Connection::PendingRequests::Remove(MessageId aMessageId) { mRequests.RemoveMatching(aMessageId); }
1397
HasAnyTimedOut(TimeMilli aNow) const1398 bool Dso::Connection::PendingRequests::HasAnyTimedOut(TimeMilli aNow) const
1399 {
1400 bool timedOut = false;
1401
1402 for (const Entry &entry : mRequests)
1403 {
1404 if (entry.mTimeout <= aNow)
1405 {
1406 timedOut = true;
1407 break;
1408 }
1409 }
1410
1411 return timedOut;
1412 }
1413
UpdateNextFireTime(NextFireTime & aNextTime) const1414 void Dso::Connection::PendingRequests::UpdateNextFireTime(NextFireTime &aNextTime) const
1415 {
1416 for (const Entry &entry : mRequests)
1417 {
1418 aNextTime.UpdateIfEarlier(entry.mTimeout);
1419 }
1420 }
1421
1422 //---------------------------------------------------------------------------------------------------------------------
1423 // Dso
1424
Dso(Instance & aInstance)1425 Dso::Dso(Instance &aInstance)
1426 : InstanceLocator(aInstance)
1427 , mAcceptHandler(nullptr)
1428 , mTimer(aInstance)
1429 {
1430 }
1431
StartListening(AcceptHandler aAcceptHandler)1432 void Dso::StartListening(AcceptHandler aAcceptHandler)
1433 {
1434 mAcceptHandler = aAcceptHandler;
1435 otPlatDsoEnableListening(&GetInstance(), true);
1436 }
1437
StopListening(void)1438 void Dso::StopListening(void) { otPlatDsoEnableListening(&GetInstance(), false); }
1439
FindClientConnection(const Ip6::SockAddr & aPeerSockAddr)1440 Dso::Connection *Dso::FindClientConnection(const Ip6::SockAddr &aPeerSockAddr)
1441 {
1442 return mClientConnections.FindMatching(aPeerSockAddr);
1443 }
1444
FindServerConnection(const Ip6::SockAddr & aPeerSockAddr)1445 Dso::Connection *Dso::FindServerConnection(const Ip6::SockAddr &aPeerSockAddr)
1446 {
1447 return mServerConnections.FindMatching(aPeerSockAddr);
1448 }
1449
AcceptConnection(const Ip6::SockAddr & aPeerSockAddr)1450 Dso::Connection *Dso::AcceptConnection(const Ip6::SockAddr &aPeerSockAddr)
1451 {
1452 Connection *connection = nullptr;
1453
1454 VerifyOrExit(mAcceptHandler != nullptr);
1455 connection = mAcceptHandler(GetInstance(), aPeerSockAddr);
1456
1457 VerifyOrExit(connection != nullptr);
1458 connection->Accept();
1459
1460 exit:
1461 return connection;
1462 }
1463
HandleTimer(void)1464 void Dso::HandleTimer(void)
1465 {
1466 NextFireTime nextTime;
1467 Connection *conn;
1468 Connection *next;
1469
1470 for (conn = mClientConnections.GetHead(); conn != nullptr; conn = next)
1471 {
1472 next = conn->GetNext();
1473 conn->HandleTimer(nextTime);
1474 }
1475
1476 for (conn = mServerConnections.GetHead(); conn != nullptr; conn = next)
1477 {
1478 next = conn->GetNext();
1479 conn->HandleTimer(nextTime);
1480 }
1481
1482 mTimer.FireAtIfEarlier(nextTime);
1483 }
1484
1485 } // namespace Dns
1486 } // namespace ot
1487
1488 #if OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1489
// Mock no-op `otPlatDsoEnableListening()`, compiled only when
// `OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE` is set.
// `OT_TOOL_WEAK` presumably marks it as a weak symbol so that a real
// platform implementation can override it.
OT_TOOL_WEAK void otPlatDsoEnableListening(otInstance *aInstance, bool aEnable)
{
    OT_UNUSED_VARIABLE(aInstance);
    OT_UNUSED_VARIABLE(aEnable);
}
1495
// Mock no-op `otPlatDsoConnect()`, compiled only when
// `OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE` is set.
// `OT_TOOL_WEAK` presumably marks it as a weak symbol so that a real
// platform implementation can override it.
OT_TOOL_WEAK void otPlatDsoConnect(otPlatDsoConnection *aConnection, const otSockAddr *aPeerSockAddr)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aPeerSockAddr);
}
1501
// Mock no-op `otPlatDsoSend()`, compiled only when
// `OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE` is set.
// `OT_TOOL_WEAK` presumably marks it as a weak symbol so that a real
// platform implementation can override it.
// NOTE(review): callers hand over `aMessage` and do not free it
// afterwards; this stub does not free it either. Confirm whether the
// mock is expected to release the message.
OT_TOOL_WEAK void otPlatDsoSend(otPlatDsoConnection *aConnection, otMessage *aMessage)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aMessage);
}
1507
// Mock no-op `otPlatDsoDisconnect()`, compiled only when
// `OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE` is set.
// `OT_TOOL_WEAK` presumably marks it as a weak symbol so that a real
// platform implementation can override it.
OT_TOOL_WEAK void otPlatDsoDisconnect(otPlatDsoConnection *aConnection, otPlatDsoDisconnectMode aMode)
{
    OT_UNUSED_VARIABLE(aConnection);
    OT_UNUSED_VARIABLE(aMode);
}
1513
1514 #endif // OPENTHREAD_CONFIG_DNS_DSO_MOCK_PLAT_APIS_ENABLE
1515
1516 #endif // OPENTHREAD_CONFIG_DNS_DSO_ENABLE
1517