1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the DNS-SD server.
32  */
33 
34 #include "dnssd_server.hpp"
35 
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37 
38 #include <openthread/platform/dns.h>
39 
40 #include "common/array.hpp"
41 #include "common/as_core_type.hpp"
42 #include "common/code_utils.hpp"
43 #include "common/debug.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/log.hpp"
46 #include "common/string.hpp"
47 #include "instance/instance.hpp"
48 #include "net/srp_server.hpp"
49 #include "net/udp6.hpp"
50 
51 namespace ot {
52 namespace Dns {
53 namespace ServiceDiscovery {
54 
// Module name under which this file logs.
RegisterLogModule("DnssdServer");

// Domain that query names are matched against: `ParseQueryName()` looks
// for this name as a suffix of the query name, and
// `ShouldForwardToUpstream()` never forwards a name that is a sub-domain
// of it.
const char Server::kDefaultDomainName[] = "default.service.arpa.";

// Label compared (case-insensitively) against the labels of a PTR query
// name in `ParseQueryName()` to detect a sub-type service name.
const char Server::kSubLabel[]          = "_sub";

#if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
// Domains for which `ShouldForwardToUpstream()` refuses to forward a
// query to the upstream DNS.
const char *Server::kBlockedDomains[] = {"ipv4only.arpa."};
#endif
63 
Server::Server(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance, *this)
#if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
    , mDiscoveryProxy(aInstance)
#endif
#if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
    , mEnableUpstreamQuery(false)
#endif
    , mTimer(aInstance)
    , mTestMode(kTestModeDisabled)
{
    // Constructing the server does not open its socket; that is done
    // by `Start()`. Upstream query forwarding (when built in) and the
    // test modes start out disabled, and the counters start cleared.

    mCounters.Clear();
}
78 
Start(void)79 Error Server::Start(void)
80 {
81     Error error = kErrorNone;
82 
83     VerifyOrExit(!IsRunning());
84 
85     SuccessOrExit(error = mSocket.Open());
86     SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? Ip6::kNetifUnspecified : Ip6::kNetifThread));
87 
88 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
89     Get<Srp::Server>().HandleDnssdServerStateChange();
90 #endif
91 
92     LogInfo("Started");
93 
94 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
95     mDiscoveryProxy.UpdateState();
96 #endif
97 
98 exit:
99     if (error != kErrorNone)
100     {
101         IgnoreError(mSocket.Close());
102     }
103 
104     return error;
105 }
106 
Stop(void)107 void Server::Stop(void)
108 {
109     for (ProxyQuery &query : mProxyQueries)
110     {
111         Finalize(query, Header::kResponseServerFailure);
112     }
113 
114 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
115     mDiscoveryProxy.Stop();
116 #endif
117 
118 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
119     for (UpstreamQueryTransaction &txn : mUpstreamQueryTransactions)
120     {
121         if (txn.IsValid())
122         {
123             ResetUpstreamQueryTransaction(txn, kErrorFailed);
124         }
125     }
126 #endif
127 
128     mTimer.Stop();
129 
130     IgnoreError(mSocket.Close());
131     LogInfo("Stopped");
132 
133 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
134     Get<Srp::Server>().HandleDnssdServerStateChange();
135 #endif
136 }
137 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)138 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
139 {
140     Request request;
141 
142 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
143     // We first let the `Srp::Server` process the received message.
144     // It returns `kErrorNone` to indicate that it successfully
145     // processed the message.
146 
147     VerifyOrExit(Get<Srp::Server>().HandleDnssdServerUdpReceive(aMessage, aMessageInfo) != kErrorNone);
148 #endif
149 
150     request.mMessage     = &aMessage;
151     request.mMessageInfo = &aMessageInfo;
152     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), request.mHeader));
153 
154     VerifyOrExit(request.mHeader.GetType() == Header::kTypeQuery);
155 
156     LogInfo("Received query from %s", aMessageInfo.GetPeerAddr().ToString().AsCString());
157 
158     ProcessQuery(request);
159 
160 exit:
161     return;
162 }
163 
ProcessQuery(Request & aRequest)164 void Server::ProcessQuery(Request &aRequest)
165 {
166     ResponseCode rcode         = Header::kResponseSuccess;
167     bool         shouldRespond = true;
168     Response     response(GetInstance());
169 
170 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
171     if (mEnableUpstreamQuery && ShouldForwardToUpstream(aRequest))
172     {
173         Error error = ResolveByUpstream(aRequest);
174 
175         if (error == kErrorNone)
176         {
177             ExitNow();
178         }
179 
180         LogWarnOnError(error, "forwarding to upstream");
181 
182         rcode = Header::kResponseServerFailure;
183 
184         // Continue to allocate and prepare the response message
185         // to send the `kResponseServerFailure` response code.
186     }
187 #endif
188 
189     SuccessOrExit(response.AllocateAndInitFrom(aRequest));
190 
191 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
192     // Forwarding the query to the upstream may have already set the
193     // response error code.
194     SuccessOrExit(rcode);
195 #endif
196 
197     SuccessOrExit(rcode = aRequest.ParseQuestions(mTestMode, shouldRespond));
198     SuccessOrExit(rcode = response.AddQuestionsFrom(aRequest));
199 
200 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
201     response.Log();
202 #endif
203 
204 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
205     switch (response.ResolveBySrp())
206     {
207     case kErrorNone:
208         mCounters.mResolvedBySrp++;
209         ExitNow();
210 
211     case kErrorNotFound:
212         rcode = Header::kResponseNameError;
213         break;
214 
215     default:
216         rcode = Header::kResponseServerFailure;
217         ExitNow();
218     }
219 #endif
220 
221     ResolveByProxy(response, *aRequest.mMessageInfo);
222 
223 exit:
224     if (rcode != Header::kResponseSuccess)
225     {
226         response.SetResponseCode(rcode);
227     }
228 
229     if (shouldRespond)
230     {
231         response.Send(*aRequest.mMessageInfo);
232     }
233 }
234 
Server::Response::Response(Instance &aInstance)
    : InstanceLocator(aInstance)
{
    // The `mHeader` constructor already clears it, so only the
    // offsets need to be cleared here. No message is allocated by the
    // constructor; see `AllocateAndInitFrom()`.

    mOffsets.Clear();
}
242 
AllocateAndInitFrom(const Request & aRequest)243 Error Server::Response::AllocateAndInitFrom(const Request &aRequest)
244 {
245     Error error = kErrorNone;
246 
247     mMessage.Reset(Get<Server>().mSocket.NewMessage());
248     VerifyOrExit(!mMessage.IsNull(), error = kErrorNoBufs);
249 
250     mHeader.SetType(Header::kTypeResponse);
251     mHeader.SetMessageId(aRequest.mHeader.GetMessageId());
252     mHeader.SetQueryType(aRequest.mHeader.GetQueryType());
253 
254     if (aRequest.mHeader.IsRecursionDesiredFlagSet())
255     {
256         mHeader.SetRecursionDesiredFlag();
257     }
258 
259     // Append the empty header to reserve room for it in the message.
260     // Header will be updated in the message before sending it.
261     error = mMessage->Append(mHeader);
262 
263 exit:
264     if (error != kErrorNone)
265     {
266         mMessage.Free();
267     }
268 
269     return error;
270 }
271 
Send(const Ip6::MessageInfo & aMessageInfo)272 void Server::Response::Send(const Ip6::MessageInfo &aMessageInfo)
273 {
274     ResponseCode rcode = mHeader.GetResponseCode();
275 
276     VerifyOrExit(!mMessage.IsNull());
277 
278     if (rcode == Header::kResponseServerFailure)
279     {
280         mHeader.SetQuestionCount(0);
281         mHeader.SetAnswerCount(0);
282         mHeader.SetAdditionalRecordCount(0);
283         IgnoreError(mMessage->SetLength(sizeof(Header)));
284     }
285 
286     mMessage->Write(0, mHeader);
287 
288     SuccessOrExit(Get<Server>().mSocket.SendTo(*mMessage, aMessageInfo));
289 
290     // When `SendTo()` returns success it takes over ownership of
291     // the given message, so we release ownership of `mMessage`.
292 
293     mMessage.Release();
294 
295     LogInfo("Send response, rcode:%u", rcode);
296 
297     Get<Server>().UpdateResponseCounters(rcode);
298 
299 exit:
300     return;
301 }
302 
ParseQuestions(uint8_t aTestMode,bool & aShouldRespond)303 Server::ResponseCode Server::Request::ParseQuestions(uint8_t aTestMode, bool &aShouldRespond)
304 {
305     // Parse header and questions from a `Request` query message and
306     // determine the `QueryType`.
307 
308     ResponseCode rcode         = Header::kResponseFormatError;
309     uint16_t     offset        = sizeof(Header);
310     uint16_t     questionCount = mHeader.GetQuestionCount();
311     Question     question;
312 
313     aShouldRespond = true;
314 
315     VerifyOrExit(mHeader.GetQueryType() == Header::kQueryTypeStandard, rcode = Header::kResponseNotImplemented);
316     VerifyOrExit(!mHeader.IsTruncationFlagSet());
317 
318     VerifyOrExit(questionCount > 0);
319 
320     SuccessOrExit(Name::ParseName(*mMessage, offset));
321     SuccessOrExit(mMessage->Read(offset, question));
322     offset += sizeof(question);
323 
324     switch (question.GetType())
325     {
326     case ResourceRecord::kTypePtr:
327         mType = kPtrQuery;
328         break;
329     case ResourceRecord::kTypeSrv:
330         mType = kSrvQuery;
331         break;
332     case ResourceRecord::kTypeTxt:
333         mType = kTxtQuery;
334         break;
335     case ResourceRecord::kTypeAaaa:
336         mType = kAaaaQuery;
337         break;
338     case ResourceRecord::kTypeA:
339         mType = kAQuery;
340         break;
341     default:
342         ExitNow(rcode = Header::kResponseNotImplemented);
343     }
344 
345     if (questionCount > 1)
346     {
347         VerifyOrExit(!(aTestMode & kTestModeRejectMultiQuestionQuery));
348         VerifyOrExit(!(aTestMode & kTestModeIgnoreMultiQuestionQuery), aShouldRespond = false);
349 
350         VerifyOrExit(questionCount == 2);
351 
352         SuccessOrExit(Name::CompareName(*mMessage, offset, *mMessage, sizeof(Header)));
353         SuccessOrExit(mMessage->Read(offset, question));
354 
355         switch (question.GetType())
356         {
357         case ResourceRecord::kTypeSrv:
358             VerifyOrExit(mType == kTxtQuery);
359             break;
360 
361         case ResourceRecord::kTypeTxt:
362             VerifyOrExit(mType == kSrvQuery);
363             break;
364 
365         default:
366             ExitNow();
367         }
368 
369         mType = kSrvTxtQuery;
370     }
371 
372     rcode = Header::kResponseSuccess;
373 
374 exit:
375     return rcode;
376 }
377 
AddQuestionsFrom(const Request & aRequest)378 Server::ResponseCode Server::Response::AddQuestionsFrom(const Request &aRequest)
379 {
380     ResponseCode rcode = Header::kResponseServerFailure;
381     uint16_t     offset;
382 
383     mType = aRequest.mType;
384 
385     // Read the name from `aRequest.mMessage` and append it as is to
386     // the response message. This ensures all name formats, including
387     // service instance names with dot characters in the instance
388     // label, are appended correctly.
389 
390     SuccessOrExit(Name(*aRequest.mMessage, sizeof(Header)).AppendTo(*mMessage));
391 
392     // Check the name to include the correct domain name and determine
393     // the domain name offset (for DNS name compression).
394 
395     VerifyOrExit(ParseQueryName() == kErrorNone, rcode = Header::kResponseNameError);
396 
397     mHeader.SetQuestionCount(aRequest.mHeader.GetQuestionCount());
398 
399     offset = sizeof(Header);
400 
401     for (uint16_t questionCount = 0; questionCount < mHeader.GetQuestionCount(); questionCount++)
402     {
403         Question question;
404 
405         // The names and questions in `aRequest` are validated already
406         // from `ParseQuestions()`, so we can `IgnoreError()`  here.
407 
408         IgnoreError(Name::ParseName(*aRequest.mMessage, offset));
409         IgnoreError(aRequest.mMessage->Read(offset, question));
410         offset += sizeof(question);
411 
412         if (questionCount != 0)
413         {
414             SuccessOrExit(AppendQueryName());
415         }
416 
417         SuccessOrExit(mMessage->Append(question));
418     }
419 
420     rcode = Header::kResponseSuccess;
421 
422 exit:
423     return rcode;
424 }
425 
ParseQueryName(void)426 Error Server::Response::ParseQueryName(void)
427 {
428     // Parses and validates the query name and updates
429     // the name compression offsets.
430 
431     Error        error = kErrorNone;
432     Name::Buffer name;
433     uint16_t     offset;
434 
435     offset = sizeof(Header);
436     SuccessOrExit(error = Name::ReadName(*mMessage, offset, name));
437 
438     switch (mType)
439     {
440     case kPtrQuery:
441         // `mOffsets.mServiceName` may be updated as we read labels and if we
442         // determine that the query name is a sub-type service.
443         mOffsets.mServiceName = sizeof(Header);
444         break;
445 
446     case kSrvQuery:
447     case kTxtQuery:
448     case kSrvTxtQuery:
449         mOffsets.mInstanceName = sizeof(Header);
450         break;
451 
452     case kAaaaQuery:
453     case kAQuery:
454         mOffsets.mHostName = sizeof(Header);
455         break;
456     }
457 
458     // Read the query name labels one by one to check if the name is
459     // service sub-type and also check that it is sub-domain of the
460     // default domain name and determine its offset
461 
462     offset = sizeof(Header);
463 
464     while (true)
465     {
466         Name::LabelBuffer label;
467         uint8_t           labelLength = sizeof(label);
468         uint16_t          comapreOffset;
469 
470         SuccessOrExit(error = Name::ReadLabel(*mMessage, offset, label, labelLength));
471 
472         if ((mType == kPtrQuery) && StringMatch(label, kSubLabel, kStringCaseInsensitiveMatch))
473         {
474             mOffsets.mServiceName = offset;
475         }
476 
477         comapreOffset = offset;
478 
479         if (Name::CompareName(*mMessage, comapreOffset, kDefaultDomainName) == kErrorNone)
480         {
481             mOffsets.mDomainName = offset;
482             ExitNow();
483         }
484     }
485 
486     error = kErrorParse;
487 
488 exit:
489     return error;
490 }
491 
ReadQueryName(Name::Buffer & aName) const492 void Server::Response::ReadQueryName(Name::Buffer &aName) const { Server::ReadQueryName(*mMessage, aName); }
493 
QueryNameMatches(const char * aName) const494 bool Server::Response::QueryNameMatches(const char *aName) const { return Server::QueryNameMatches(*mMessage, aName); }
495 
AppendQueryName(void)496 Error Server::Response::AppendQueryName(void) { return Name::AppendPointerLabel(sizeof(Header), *mMessage); }
497 
AppendPtrRecord(const char * aInstanceLabel,uint32_t aTtl)498 Error Server::Response::AppendPtrRecord(const char *aInstanceLabel, uint32_t aTtl)
499 {
500     Error     error;
501     uint16_t  recordOffset;
502     PtrRecord ptrRecord;
503 
504     ptrRecord.Init();
505     ptrRecord.SetTtl(aTtl);
506 
507     SuccessOrExit(error = AppendQueryName());
508 
509     recordOffset = mMessage->GetLength();
510     SuccessOrExit(error = mMessage->Append(ptrRecord));
511 
512     mOffsets.mInstanceName = mMessage->GetLength();
513     SuccessOrExit(error = Name::AppendLabel(aInstanceLabel, *mMessage));
514     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mServiceName, *mMessage));
515 
516     UpdateRecordLength(ptrRecord, recordOffset);
517 
518     IncResourceRecordCount();
519 
520 exit:
521     return error;
522 }
523 
524 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
AppendSrvRecord(const Srp::Server::Service & aService)525 Error Server::Response::AppendSrvRecord(const Srp::Server::Service &aService)
526 {
527     uint32_t ttl = TimeMilli::MsecToSec(aService.GetExpireTime() - TimerMilli::GetNow());
528 
529     return AppendSrvRecord(aService.GetHost().GetFullName(), ttl, aService.GetPriority(), aService.GetWeight(),
530                            aService.GetPort());
531 }
532 #endif
533 
AppendSrvRecord(const ServiceInstanceInfo & aInstanceInfo)534 Error Server::Response::AppendSrvRecord(const ServiceInstanceInfo &aInstanceInfo)
535 {
536     return AppendSrvRecord(aInstanceInfo.mHostName, aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
537                            aInstanceInfo.mPort);
538 }
539 
AppendSrvRecord(const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort)540 Error Server::Response::AppendSrvRecord(const char *aHostName,
541                                         uint32_t    aTtl,
542                                         uint16_t    aPriority,
543                                         uint16_t    aWeight,
544                                         uint16_t    aPort)
545 {
546     Error        error = kErrorNone;
547     SrvRecord    srvRecord;
548     uint16_t     recordOffset;
549     Name::Buffer hostLabels;
550 
551     SuccessOrExit(error = Name::ExtractLabels(aHostName, kDefaultDomainName, hostLabels));
552 
553     srvRecord.Init();
554     srvRecord.SetTtl(aTtl);
555     srvRecord.SetPriority(aPriority);
556     srvRecord.SetWeight(aWeight);
557     srvRecord.SetPort(aPort);
558 
559     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mInstanceName, *mMessage));
560 
561     recordOffset = mMessage->GetLength();
562     SuccessOrExit(error = mMessage->Append(srvRecord));
563 
564     mOffsets.mHostName = mMessage->GetLength();
565     SuccessOrExit(error = Name::AppendMultipleLabels(hostLabels, *mMessage));
566     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mDomainName, *mMessage));
567 
568     UpdateRecordLength(srvRecord, recordOffset);
569 
570     IncResourceRecordCount();
571 
572 exit:
573     return error;
574 }
575 
576 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
AppendHostAddresses(const Srp::Server::Host & aHost)577 Error Server::Response::AppendHostAddresses(const Srp::Server::Host &aHost)
578 {
579     const Ip6::Address *addrs;
580     uint8_t             addrsLength;
581     uint32_t            ttl;
582 
583     addrs = aHost.GetAddresses(addrsLength);
584     ttl   = TimeMilli::MsecToSec(aHost.GetExpireTime() - TimerMilli::GetNow());
585 
586     return AppendHostAddresses(kIp6AddrType, addrs, addrsLength, ttl);
587 }
588 #endif
589 
AppendHostAddresses(AddrType aAddrType,const HostInfo & aHostInfo)590 Error Server::Response::AppendHostAddresses(AddrType aAddrType, const HostInfo &aHostInfo)
591 {
592     return AppendHostAddresses(aAddrType, AsCoreTypePtr(aHostInfo.mAddresses), aHostInfo.mAddressNum, aHostInfo.mTtl);
593 }
594 
AppendHostAddresses(const ServiceInstanceInfo & aInstanceInfo)595 Error Server::Response::AppendHostAddresses(const ServiceInstanceInfo &aInstanceInfo)
596 {
597     return AppendHostAddresses(kIp6AddrType, AsCoreTypePtr(aInstanceInfo.mAddresses), aInstanceInfo.mAddressNum,
598                                aInstanceInfo.mTtl);
599 }
600 
AppendHostAddresses(AddrType aAddrType,const Ip6::Address * aAddrs,uint16_t aAddrsLength,uint32_t aTtl)601 Error Server::Response::AppendHostAddresses(AddrType            aAddrType,
602                                             const Ip6::Address *aAddrs,
603                                             uint16_t            aAddrsLength,
604                                             uint32_t            aTtl)
605 {
606     Error error = kErrorNone;
607 
608     for (uint16_t index = 0; index < aAddrsLength; index++)
609     {
610         const Ip6::Address &address = aAddrs[index];
611 
612         switch (aAddrType)
613         {
614         case kIp6AddrType:
615             SuccessOrExit(error = AppendAaaaRecord(address, aTtl));
616             break;
617 
618         case kIp4AddrType:
619             SuccessOrExit(error = AppendARecord(address, aTtl));
620             break;
621         }
622     }
623 
624 exit:
625     return error;
626 }
627 
AppendAaaaRecord(const Ip6::Address & aAddress,uint32_t aTtl)628 Error Server::Response::AppendAaaaRecord(const Ip6::Address &aAddress, uint32_t aTtl)
629 {
630     Error      error = kErrorNone;
631     AaaaRecord aaaaRecord;
632 
633     VerifyOrExit(!aAddress.IsIp4Mapped());
634 
635     aaaaRecord.Init();
636     aaaaRecord.SetTtl(aTtl);
637     aaaaRecord.SetAddress(aAddress);
638 
639     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mHostName, *mMessage));
640     SuccessOrExit(error = mMessage->Append(aaaaRecord));
641     IncResourceRecordCount();
642 
643 exit:
644     return error;
645 }
646 
AppendARecord(const Ip6::Address & aAddress,uint32_t aTtl)647 Error Server::Response::AppendARecord(const Ip6::Address &aAddress, uint32_t aTtl)
648 {
649     Error        error = kErrorNone;
650     ARecord      aRecord;
651     Ip4::Address ip4Address;
652 
653     SuccessOrExit(ip4Address.ExtractFromIp4MappedIp6Address(aAddress));
654 
655     aRecord.Init();
656     aRecord.SetTtl(aTtl);
657     aRecord.SetAddress(ip4Address);
658 
659     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mHostName, *mMessage));
660     SuccessOrExit(error = mMessage->Append(aRecord));
661     IncResourceRecordCount();
662 
663 exit:
664     return error;
665 }
666 
667 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
AppendTxtRecord(const Srp::Server::Service & aService)668 Error Server::Response::AppendTxtRecord(const Srp::Server::Service &aService)
669 {
670     return AppendTxtRecord(aService.GetTxtData(), aService.GetTxtDataLength(),
671                            TimeMilli::MsecToSec(aService.GetExpireTime() - TimerMilli::GetNow()));
672 }
673 #endif
674 
AppendTxtRecord(const ServiceInstanceInfo & aInstanceInfo)675 Error Server::Response::AppendTxtRecord(const ServiceInstanceInfo &aInstanceInfo)
676 {
677     return AppendTxtRecord(aInstanceInfo.mTxtData, aInstanceInfo.mTxtLength, aInstanceInfo.mTtl);
678 }
679 
AppendTxtRecord(const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl)680 Error Server::Response::AppendTxtRecord(const void *aTxtData, uint16_t aTxtLength, uint32_t aTtl)
681 {
682     Error     error = kErrorNone;
683     TxtRecord txtRecord;
684     uint8_t   emptyTxt = 0;
685 
686     if (aTxtLength == 0)
687     {
688         aTxtData   = &emptyTxt;
689         aTxtLength = sizeof(emptyTxt);
690     }
691 
692     txtRecord.Init();
693     txtRecord.SetTtl(aTtl);
694     txtRecord.SetLength(aTxtLength);
695 
696     SuccessOrExit(error = Name::AppendPointerLabel(mOffsets.mInstanceName, *mMessage));
697     SuccessOrExit(error = mMessage->Append(txtRecord));
698     SuccessOrExit(error = mMessage->AppendBytes(aTxtData, aTxtLength));
699 
700     IncResourceRecordCount();
701 
702 exit:
703     return error;
704 }
705 
UpdateRecordLength(ResourceRecord & aRecord,uint16_t aOffset)706 void Server::Response::UpdateRecordLength(ResourceRecord &aRecord, uint16_t aOffset)
707 {
708     // Calculates RR DATA length and updates and re-writes it in the
709     // response message. This should be called immediately
710     // after all the fields in the record are written in the message.
711     // `aOffset` gives the offset in the message to the start of the
712     // record.
713 
714     aRecord.SetLength(mMessage->GetLength() - aOffset - sizeof(Dns::ResourceRecord));
715     mMessage->Write(aOffset, aRecord);
716 }
717 
IncResourceRecordCount(void)718 void Server::Response::IncResourceRecordCount(void)
719 {
720     switch (mSection)
721     {
722     case kAnswerSection:
723         mHeader.SetAnswerCount(mHeader.GetAnswerCount() + 1);
724         break;
725     case kAdditionalDataSection:
726         mHeader.SetAdditionalRecordCount(mHeader.GetAdditionalRecordCount() + 1);
727         break;
728     }
729 }
730 
731 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
Log(void) const732 void Server::Response::Log(void) const
733 {
734     Name::Buffer name;
735 
736     ReadQueryName(name);
737     LogInfo("%s query for '%s'", QueryTypeToString(mType), name);
738 }
739 
QueryTypeToString(QueryType aType)740 const char *Server::Response::QueryTypeToString(QueryType aType)
741 {
742     static const char *const kTypeNames[] = {
743         "PTR",       // (0) kPtrQuery
744         "SRV",       // (1) kSrvQuery
745         "TXT",       // (2) kTxtQuery
746         "SRV & TXT", // (3) kSrvTxtQuery
747         "AAAA",      // (4) kAaaaQuery
748         "A",         // (5) kAQuery
749     };
750 
751     static_assert(0 == kPtrQuery, "kPtrQuery value is incorrect");
752     static_assert(1 == kSrvQuery, "kSrvQuery value is incorrect");
753     static_assert(2 == kTxtQuery, "kTxtQuery value is incorrect");
754     static_assert(3 == kSrvTxtQuery, "kSrvTxtQuery value is incorrect");
755     static_assert(4 == kAaaaQuery, "kAaaaQuery value is incorrect");
756     static_assert(5 == kAQuery, "kAQuery value is incorrect");
757 
758     return kTypeNames[aType];
759 }
760 #endif
761 
762 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
763 
Error Server::Response::ResolveBySrp(void)
{
    // Tries to answer the query from the hosts and services known to
    // the `Srp::Server`, appending the matching records to the
    // response message.
    //
    // Returns `kErrorNotFound` when nothing matches the query name,
    // `kErrorNone` when the records were appended, or the error from
    // the append call that failed.

    // Order in which the sections are filled in below.
    static const Section kSections[] = {kAnswerSection, kAdditionalDataSection};

    Error                       error          = kErrorNotFound;
    const Srp::Server::Service *matchedService = nullptr;
    bool                        found          = false;
    Section                     srvSection;
    Section                     txtSection;

    mSection = kAnswerSection;

    for (const Srp::Server::Host &host : Get<Srp::Server>().GetHosts())
    {
        if (host.IsDeleted())
        {
            continue;
        }

        if ((mType == kAaaaQuery) || (mType == kAQuery))
        {
            // Address query: the first host whose full name matches
            // settles the result. Its addresses are appended in the
            // answer section for an AAAA query and in the additional
            // data section for an A query.

            if (QueryNameMatches(host.GetFullName()))
            {
                mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection;
                error    = AppendHostAddresses(host);
                ExitNow();
            }

            continue;
        }

        // `mType` is PTR or SRV/TXT query

        for (const Srp::Server::Service &service : host.GetServices())
        {
            if (service.IsDeleted())
            {
                continue;
            }

            if (mType == kPtrQuery)
            {
                // A PTR record is appended for every matching
                // service, so the search continues over all services
                // and hosts. `matchedService` ends up pointing to the
                // last match.

                if (QueryNameMatchesService(service))
                {
                    uint32_t ttl = TimeMilli::MsecToSec(service.GetExpireTime() - TimerMilli::GetNow());

                    SuccessOrExit(error = AppendPtrRecord(service.GetInstanceLabel(), ttl));
                    matchedService = &service;
                }
            }
            else if (QueryNameMatches(service.GetInstanceName()))
            {
                // SRV/TXT query: the first matching service instance
                // ends the search.

                matchedService = &service;
                found          = true;
                break;
            }
        }

        if (found)
        {
            // Also leave the outer loop over the hosts.
            break;
        }
    }

    // Without a match `error` is still `kErrorNotFound`.
    VerifyOrExit(matchedService != nullptr);

    if (mType == kPtrQuery)
    {
        // Skip adding additional records, when answering a
        // PTR query with more than one answer. This is the
        // recommended behavior to keep the size of the
        // response small.

        VerifyOrExit(mHeader.GetAnswerCount() == 1);
    }

    // A record type that was asked for goes in the answer section;
    // otherwise it is added as additional data.

    srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
    txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;

    for (Section section : kSections)
    {
        mSection = section;

        if (mSection == kAdditionalDataSection)
        {
            // Test mode: stop here so the additional data section
            // stays empty.
            VerifyOrExit(!(Get<Server>().mTestMode & kTestModeEmptyAdditionalSection));
        }

        if (srvSection == mSection)
        {
            SuccessOrExit(error = AppendSrvRecord(*matchedService));
        }

        if (txtSection == mSection)
        {
            SuccessOrExit(error = AppendTxtRecord(*matchedService));
        }
    }

    // `mSection` is `kAdditionalDataSection` at this point, so the
    // host addresses are counted as additional data.

    SuccessOrExit(error = AppendHostAddresses(matchedService->GetHost()));

exit:
    return error;
}
868 
QueryNameMatchesService(const Srp::Server::Service & aService) const869 bool Server::Response::QueryNameMatchesService(const Srp::Server::Service &aService) const
870 {
871     // Check if the query name matches the base service name or any
872     // sub-type service names associated with `aService`.
873 
874     bool matches = QueryNameMatches(aService.GetServiceName());
875 
876     VerifyOrExit(!matches);
877 
878     for (uint16_t index = 0; index < aService.GetNumberOfSubTypes(); index++)
879     {
880         matches = QueryNameMatches(aService.GetSubTypeServiceNameAt(index));
881         VerifyOrExit(!matches);
882     }
883 
884 exit:
885     return matches;
886 }
887 
888 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
889 
890 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
ShouldForwardToUpstream(const Request & aRequest)891 bool Server::ShouldForwardToUpstream(const Request &aRequest)
892 {
893     bool         shouldForward = false;
894     uint16_t     readOffset;
895     Name::Buffer name;
896 
897     VerifyOrExit(aRequest.mHeader.IsRecursionDesiredFlagSet());
898     readOffset = sizeof(Header);
899 
900     for (uint16_t i = 0; i < aRequest.mHeader.GetQuestionCount(); i++)
901     {
902         SuccessOrExit(Name::ReadName(*aRequest.mMessage, readOffset, name));
903         readOffset += sizeof(Question);
904 
905         VerifyOrExit(!Name::IsSubDomainOf(name, kDefaultDomainName));
906 
907         for (const char *blockedDomain : kBlockedDomains)
908         {
909             VerifyOrExit(!Name::IsSameDomain(name, blockedDomain));
910         }
911     }
912 
913     shouldForward = true;
914 
915 exit:
916     return shouldForward;
917 }
918 
OnUpstreamQueryDone(UpstreamQueryTransaction & aQueryTransaction,Message * aResponseMessage)919 void Server::OnUpstreamQueryDone(UpstreamQueryTransaction &aQueryTransaction, Message *aResponseMessage)
920 {
921     Error error = kErrorNone;
922 
923     VerifyOrExit(aQueryTransaction.IsValid(), error = kErrorInvalidArgs);
924 
925     if (aResponseMessage != nullptr)
926     {
927         error = mSocket.SendTo(*aResponseMessage, aQueryTransaction.GetMessageInfo());
928     }
929     else
930     {
931         error = kErrorResponseTimeout;
932     }
933 
934     ResetUpstreamQueryTransaction(aQueryTransaction, error);
935 
936 exit:
937     FreeMessageOnError(aResponseMessage, error);
938 }
939 
AllocateUpstreamQueryTransaction(const Ip6::MessageInfo & aMessageInfo)940 Server::UpstreamQueryTransaction *Server::AllocateUpstreamQueryTransaction(const Ip6::MessageInfo &aMessageInfo)
941 {
942     UpstreamQueryTransaction *newTxn = nullptr;
943 
944     for (UpstreamQueryTransaction &txn : mUpstreamQueryTransactions)
945     {
946         if (!txn.IsValid())
947         {
948             newTxn = &txn;
949             break;
950         }
951     }
952 
953     VerifyOrExit(newTxn != nullptr, mCounters.mUpstreamDnsCounters.mFailures++);
954 
955     newTxn->Init(aMessageInfo);
956     LogInfo("Upstream query transaction %d initialized.", static_cast<int>(newTxn - mUpstreamQueryTransactions));
957     mTimer.FireAtIfEarlier(newTxn->GetExpireTime());
958 
959 exit:
960     return newTxn;
961 }
962 
ResolveByUpstream(const Request & aRequest)963 Error Server::ResolveByUpstream(const Request &aRequest)
964 {
965     Error                     error = kErrorNone;
966     UpstreamQueryTransaction *txn;
967 
968     txn = AllocateUpstreamQueryTransaction(*aRequest.mMessageInfo);
969     VerifyOrExit(txn != nullptr, error = kErrorNoBufs);
970 
971     otPlatDnsStartUpstreamQuery(&GetInstance(), txn, aRequest.mMessage);
972     mCounters.mUpstreamDnsCounters.mQueries++;
973 
974 exit:
975     return error;
976 }
977 #endif // OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
978 
ResolveByProxy(Response & aResponse,const Ip6::MessageInfo & aMessageInfo)979 void Server::ResolveByProxy(Response &aResponse, const Ip6::MessageInfo &aMessageInfo)
980 {
981     ProxyQuery    *query;
982     ProxyQueryInfo info;
983 
984 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
985     VerifyOrExit(mQuerySubscribe.IsSet() || mDiscoveryProxy.IsRunning());
986 #else
987     VerifyOrExit(mQuerySubscribe.IsSet());
988 #endif
989 
990     // We try to convert `aResponse.mMessage` to a `ProxyQuery` by
991     // appending `ProxyQueryInfo` to it.
992 
993     info.mType        = aResponse.mType;
994     info.mMessageInfo = aMessageInfo;
995     info.mExpireTime  = TimerMilli::GetNow() + kQueryTimeout;
996     info.mOffsets     = aResponse.mOffsets;
997 
998 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
999     info.mAction = kNoAction;
1000 #endif
1001 
1002     if (aResponse.mMessage->Append(info) != kErrorNone)
1003     {
1004         aResponse.SetResponseCode(Header::kResponseServerFailure);
1005         ExitNow();
1006     }
1007 
1008     // Take over the ownership of `aResponse.mMessage` and add it as a
1009     // `ProxyQuery` in `mProxyQueries` list.
1010 
1011     query = aResponse.mMessage.Release();
1012 
1013     query->Write(0, aResponse.mHeader);
1014     mProxyQueries.Enqueue(*query);
1015 
1016     mTimer.FireAtIfEarlier(info.mExpireTime);
1017 
1018 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1019     if (mQuerySubscribe.IsSet())
1020 #endif
1021     {
1022         Name::Buffer name;
1023 
1024         ReadQueryName(*query, name);
1025         mQuerySubscribe.Invoke(name);
1026     }
1027 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1028     else
1029     {
1030         mDiscoveryProxy.Resolve(*query, info);
1031     }
1032 #endif
1033 
1034 exit:
1035     return;
1036 }
1037 
ReadQueryName(const Message & aQuery,Name::Buffer & aName)1038 void Server::ReadQueryName(const Message &aQuery, Name::Buffer &aName)
1039 {
1040     uint16_t offset = sizeof(Header);
1041 
1042     IgnoreError(Name::ReadName(aQuery, offset, aName));
1043 }
1044 
QueryNameMatches(const Message & aQuery,const char * aName)1045 bool Server::QueryNameMatches(const Message &aQuery, const char *aName)
1046 {
1047     uint16_t offset = sizeof(Header);
1048 
1049     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1050 }
1051 
ReadQueryInstanceName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1052 void Server::ReadQueryInstanceName(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Name::Buffer &aName)
1053 {
1054     uint16_t offset = aInfo.mOffsets.mInstanceName;
1055 
1056     IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName)));
1057 }
1058 
ReadQueryInstanceName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::LabelBuffer & aInstanceLabel,Name::Buffer & aServiceType)1059 void Server::ReadQueryInstanceName(const ProxyQuery     &aQuery,
1060                                    const ProxyQueryInfo &aInfo,
1061                                    Name::LabelBuffer    &aInstanceLabel,
1062                                    Name::Buffer         &aServiceType)
1063 {
1064     // Reads the service instance label and service type with domain
1065     // name stripped.
1066 
1067     uint16_t offset      = aInfo.mOffsets.mInstanceName;
1068     uint8_t  labelLength = sizeof(aInstanceLabel);
1069 
1070     IgnoreError(Dns::Name::ReadLabel(aQuery, offset, aInstanceLabel, labelLength));
1071     IgnoreError(Dns::Name::ReadName(aQuery, offset, aServiceType));
1072     IgnoreError(StripDomainName(aServiceType));
1073 }
1074 
QueryInstanceNameMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,const char * aName)1075 bool Server::QueryInstanceNameMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, const char *aName)
1076 {
1077     uint16_t offset = aInfo.mOffsets.mInstanceName;
1078 
1079     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1080 }
1081 
ReadQueryHostName(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1082 void Server::ReadQueryHostName(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, Name::Buffer &aName)
1083 {
1084     uint16_t offset = aInfo.mOffsets.mHostName;
1085 
1086     IgnoreError(Name::ReadName(aQuery, offset, aName, sizeof(aName)));
1087 }
1088 
QueryHostNameMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,const char * aName)1089 bool Server::QueryHostNameMatches(const ProxyQuery &aQuery, const ProxyQueryInfo &aInfo, const char *aName)
1090 {
1091     uint16_t offset = aInfo.mOffsets.mHostName;
1092 
1093     return (Name::CompareName(aQuery, offset, aName) == kErrorNone);
1094 }
1095 
StripDomainName(Name::Buffer & aName)1096 Error Server::StripDomainName(Name::Buffer &aName)
1097 {
1098     // In-place removes the domain name from `aName`.
1099 
1100     return Name::StripName(aName, kDefaultDomainName);
1101 }
1102 
StripDomainName(const char * aFullName,Name::Buffer & aLabels)1103 Error Server::StripDomainName(const char *aFullName, Name::Buffer &aLabels)
1104 {
1105     // Remove the domain name from `aFullName` and copies
1106     // the result into `aLabels`.
1107 
1108     return Name::ExtractLabels(aFullName, kDefaultDomainName, aLabels, sizeof(aLabels));
1109 }
1110 
ConstructFullName(const char * aLabels,Name::Buffer & aFullName)1111 void Server::ConstructFullName(const char *aLabels, Name::Buffer &aFullName)
1112 {
1113     // Construct a full name by appending the default domain name
1114     // to `aLabels`.
1115 
1116     StringWriter fullName(aFullName, sizeof(aFullName));
1117 
1118     fullName.Append("%s.%s", aLabels, kDefaultDomainName);
1119 }
1120 
ConstructFullInstanceName(const char * aInstanceLabel,const char * aServiceType,Name::Buffer & aFullName)1121 void Server::ConstructFullInstanceName(const char *aInstanceLabel, const char *aServiceType, Name::Buffer &aFullName)
1122 {
1123     StringWriter fullName(aFullName, sizeof(aFullName));
1124 
1125     fullName.Append("%s.%s.%s", aInstanceLabel, aServiceType, kDefaultDomainName);
1126 }
1127 
ConstructFullServiceSubTypeName(const char * aServiceType,const char * aSubTypeLabel,Name::Buffer & aFullName)1128 void Server::ConstructFullServiceSubTypeName(const char   *aServiceType,
1129                                              const char   *aSubTypeLabel,
1130                                              Name::Buffer &aFullName)
1131 {
1132     StringWriter fullName(aFullName, sizeof(aFullName));
1133 
1134     fullName.Append("%s._sub.%s.%s", aSubTypeLabel, aServiceType, kDefaultDomainName);
1135 }
1136 
ReadFrom(const ProxyQuery & aQuery)1137 void Server::ProxyQueryInfo::ReadFrom(const ProxyQuery &aQuery)
1138 {
1139     SuccessOrAssert(aQuery.Read(aQuery.GetLength() - sizeof(ProxyQueryInfo), *this));
1140 }
1141 
RemoveFrom(ProxyQuery & aQuery) const1142 void Server::ProxyQueryInfo::RemoveFrom(ProxyQuery &aQuery) const { aQuery.RemoveFooter(sizeof(ProxyQueryInfo)); }
1143 
UpdateIn(ProxyQuery & aQuery) const1144 void Server::ProxyQueryInfo::UpdateIn(ProxyQuery &aQuery) const
1145 {
1146     aQuery.Write(aQuery.GetLength() - sizeof(ProxyQueryInfo), *this);
1147 }
1148 
ExtractServiceInstanceLabel(const char * aInstanceName,Name::LabelBuffer & aLabel)1149 Error Server::Response::ExtractServiceInstanceLabel(const char *aInstanceName, Name::LabelBuffer &aLabel)
1150 {
1151     uint16_t     offset;
1152     Name::Buffer serviceName;
1153 
1154     offset = mOffsets.mServiceName;
1155     IgnoreError(Name::ReadName(*mMessage, offset, serviceName));
1156 
1157     return Name::ExtractLabels(aInstanceName, serviceName, aLabel);
1158 }
1159 
RemoveQueryAndPrepareResponse(ProxyQuery & aQuery,ProxyQueryInfo & aInfo,Response & aResponse)1160 void Server::RemoveQueryAndPrepareResponse(ProxyQuery &aQuery, ProxyQueryInfo &aInfo, Response &aResponse)
1161 {
1162 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1163     mDiscoveryProxy.CancelAction(aQuery, aInfo);
1164 #endif
1165 
1166     mProxyQueries.Dequeue(aQuery);
1167     aInfo.RemoveFrom(aQuery);
1168 
1169     if (mQueryUnsubscribe.IsSet())
1170     {
1171         Name::Buffer name;
1172 
1173         ReadQueryName(aQuery, name);
1174         mQueryUnsubscribe.Invoke(name);
1175     }
1176 
1177     aResponse.InitFrom(aQuery, aInfo);
1178 }
1179 
InitFrom(ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1180 void Server::Response::InitFrom(ProxyQuery &aQuery, const ProxyQueryInfo &aInfo)
1181 {
1182     mMessage.Reset(&aQuery);
1183     IgnoreError(mMessage->Read(0, mHeader));
1184     mType    = aInfo.mType;
1185     mOffsets = aInfo.mOffsets;
1186 }
1187 
Answer(const ServiceInstanceInfo & aInstanceInfo,const Ip6::MessageInfo & aMessageInfo)1188 void Server::Response::Answer(const ServiceInstanceInfo &aInstanceInfo, const Ip6::MessageInfo &aMessageInfo)
1189 {
1190     static const Section kSections[] = {kAnswerSection, kAdditionalDataSection};
1191 
1192     Error   error      = kErrorNone;
1193     Section srvSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
1194     Section txtSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
1195 
1196     if (mType == kPtrQuery)
1197     {
1198         Name::LabelBuffer instanceLabel;
1199 
1200         SuccessOrExit(error = ExtractServiceInstanceLabel(aInstanceInfo.mFullName, instanceLabel));
1201         mSection = kAnswerSection;
1202         SuccessOrExit(error = AppendPtrRecord(instanceLabel, aInstanceInfo.mTtl));
1203     }
1204 
1205     for (Section section : kSections)
1206     {
1207         mSection = section;
1208 
1209         if (mSection == kAdditionalDataSection)
1210         {
1211             VerifyOrExit(!(Get<Server>().mTestMode & kTestModeEmptyAdditionalSection));
1212         }
1213 
1214         if (srvSection == mSection)
1215         {
1216             SuccessOrExit(error = AppendSrvRecord(aInstanceInfo));
1217         }
1218 
1219         if (txtSection == mSection)
1220         {
1221             SuccessOrExit(error = AppendTxtRecord(aInstanceInfo));
1222         }
1223     }
1224 
1225     error = AppendHostAddresses(aInstanceInfo);
1226 
1227 exit:
1228     if (error != kErrorNone)
1229     {
1230         SetResponseCode(Header::kResponseServerFailure);
1231     }
1232 
1233     Send(aMessageInfo);
1234 }
1235 
Answer(const HostInfo & aHostInfo,const Ip6::MessageInfo & aMessageInfo)1236 void Server::Response::Answer(const HostInfo &aHostInfo, const Ip6::MessageInfo &aMessageInfo)
1237 {
1238     // Caller already ensures that `mType` is either `kAaaaQuery` or
1239     // `kAQuery`.
1240 
1241     AddrType addrType = (mType == kAaaaQuery) ? kIp6AddrType : kIp4AddrType;
1242 
1243     mSection = kAnswerSection;
1244 
1245     if (AppendHostAddresses(addrType, aHostInfo) != kErrorNone)
1246     {
1247         SetResponseCode(Header::kResponseServerFailure);
1248     }
1249 
1250     Send(aMessageInfo);
1251 }
1252 
SetQueryCallbacks(SubscribeCallback aSubscribe,UnsubscribeCallback aUnsubscribe,void * aContext)1253 void Server::SetQueryCallbacks(SubscribeCallback aSubscribe, UnsubscribeCallback aUnsubscribe, void *aContext)
1254 {
1255     OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
1256 
1257     mQuerySubscribe.Set(aSubscribe, aContext);
1258     mQueryUnsubscribe.Set(aUnsubscribe, aContext);
1259 }
1260 
HandleDiscoveredServiceInstance(const char * aServiceFullName,const ServiceInstanceInfo & aInstanceInfo)1261 void Server::HandleDiscoveredServiceInstance(const char *aServiceFullName, const ServiceInstanceInfo &aInstanceInfo)
1262 {
1263     OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeparatorChar));
1264     OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeparatorChar));
1265     OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeparatorChar));
1266 
1267     // It is safe to remove entries from `mProxyQueries` as we iterate
1268     // over it since it is a `MessageQueue`.
1269 
1270     for (ProxyQuery &query : mProxyQueries)
1271     {
1272         bool           canAnswer = false;
1273         ProxyQueryInfo info;
1274 
1275         info.ReadFrom(query);
1276 
1277         switch (info.mType)
1278         {
1279         case kPtrQuery:
1280             canAnswer = QueryNameMatches(query, aServiceFullName);
1281             break;
1282 
1283         case kSrvQuery:
1284         case kTxtQuery:
1285         case kSrvTxtQuery:
1286             canAnswer = QueryNameMatches(query, aInstanceInfo.mFullName);
1287             break;
1288 
1289         case kAaaaQuery:
1290         case kAQuery:
1291             break;
1292         }
1293 
1294         if (canAnswer)
1295         {
1296             Response response(GetInstance());
1297 
1298             RemoveQueryAndPrepareResponse(query, info, response);
1299             response.Answer(aInstanceInfo, info.mMessageInfo);
1300         }
1301     }
1302 }
1303 
HandleDiscoveredHost(const char * aHostFullName,const HostInfo & aHostInfo)1304 void Server::HandleDiscoveredHost(const char *aHostFullName, const HostInfo &aHostInfo)
1305 {
1306     OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeparatorChar));
1307 
1308     for (ProxyQuery &query : mProxyQueries)
1309     {
1310         ProxyQueryInfo info;
1311 
1312         info.ReadFrom(query);
1313 
1314         switch (info.mType)
1315         {
1316         case kAaaaQuery:
1317         case kAQuery:
1318             if (QueryNameMatches(query, aHostFullName))
1319             {
1320                 Response response(GetInstance());
1321 
1322                 RemoveQueryAndPrepareResponse(query, info, response);
1323                 response.Answer(aHostInfo, info.mMessageInfo);
1324             }
1325 
1326             break;
1327 
1328         default:
1329             break;
1330         }
1331     }
1332 }
1333 
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    // Iterates over the pending proxy queries: a `nullptr` argument
    // yields the head of `mProxyQueries`, any other argument yields
    // the entry following it.

    const ProxyQuery *current = static_cast<const ProxyQuery *>(aQuery);

    if (current == nullptr)
    {
        return mProxyQueries.GetHead();
    }

    return current->GetNext();
}
1340 
GetQueryTypeAndName(const otDnssdQuery * aQuery,Dns::Name::Buffer & aName)1341 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, Dns::Name::Buffer &aName)
1342 {
1343     const ProxyQuery *query = static_cast<const ProxyQuery *>(aQuery);
1344     ProxyQueryInfo    info;
1345     DnsQueryType      type;
1346 
1347     ReadQueryName(*query, aName);
1348     info.ReadFrom(*query);
1349 
1350     type = kDnsQueryBrowse;
1351 
1352     switch (info.mType)
1353     {
1354     case kPtrQuery:
1355         break;
1356 
1357     case kSrvQuery:
1358     case kTxtQuery:
1359     case kSrvTxtQuery:
1360         type = kDnsQueryResolve;
1361         break;
1362 
1363     case kAaaaQuery:
1364     case kAQuery:
1365         type = kDnsQueryResolveHost;
1366         break;
1367     }
1368 
1369     return type;
1370 }
1371 
HandleTimer(void)1372 void Server::HandleTimer(void)
1373 {
1374     NextFireTime nextExpire;
1375 
1376     for (ProxyQuery &query : mProxyQueries)
1377     {
1378         ProxyQueryInfo info;
1379 
1380         info.ReadFrom(query);
1381 
1382         if (info.mExpireTime <= nextExpire.GetNow())
1383         {
1384             Finalize(query, Header::kResponseSuccess);
1385         }
1386         else
1387         {
1388             nextExpire.UpdateIfEarlier(info.mExpireTime);
1389         }
1390     }
1391 
1392 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
1393     for (UpstreamQueryTransaction &query : mUpstreamQueryTransactions)
1394     {
1395         if (!query.IsValid())
1396         {
1397             continue;
1398         }
1399 
1400         if (query.GetExpireTime() <= nextExpire.GetNow())
1401         {
1402             otPlatDnsCancelUpstreamQuery(&GetInstance(), &query);
1403         }
1404         else
1405         {
1406             nextExpire.UpdateIfEarlier(query.GetExpireTime());
1407         }
1408     }
1409 #endif
1410 
1411     mTimer.FireAtIfEarlier(nextExpire);
1412 }
1413 
Finalize(ProxyQuery & aQuery,ResponseCode aResponseCode)1414 void Server::Finalize(ProxyQuery &aQuery, ResponseCode aResponseCode)
1415 {
1416     Response       response(GetInstance());
1417     ProxyQueryInfo info;
1418 
1419     info.ReadFrom(aQuery);
1420     RemoveQueryAndPrepareResponse(aQuery, info, response);
1421 
1422     response.SetResponseCode(aResponseCode);
1423     response.Send(info.mMessageInfo);
1424 }
1425 
UpdateResponseCounters(ResponseCode aResponseCode)1426 void Server::UpdateResponseCounters(ResponseCode aResponseCode)
1427 {
1428     switch (aResponseCode)
1429     {
1430     case UpdateHeader::kResponseSuccess:
1431         ++mCounters.mSuccessResponse;
1432         break;
1433     case UpdateHeader::kResponseServerFailure:
1434         ++mCounters.mServerFailureResponse;
1435         break;
1436     case UpdateHeader::kResponseFormatError:
1437         ++mCounters.mFormatErrorResponse;
1438         break;
1439     case UpdateHeader::kResponseNameError:
1440         ++mCounters.mNameErrorResponse;
1441         break;
1442     case UpdateHeader::kResponseNotImplemented:
1443         ++mCounters.mNotImplementedResponse;
1444         break;
1445     default:
1446         ++mCounters.mOtherResponse;
1447         break;
1448     }
1449 }
1450 
1451 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
Init(const Ip6::MessageInfo & aMessageInfo)1452 void Server::UpstreamQueryTransaction::Init(const Ip6::MessageInfo &aMessageInfo)
1453 {
1454     mMessageInfo = aMessageInfo;
1455     mValid       = true;
1456     mExpireTime  = TimerMilli::GetNow() + kQueryTimeout;
1457 }
1458 
ResetUpstreamQueryTransaction(UpstreamQueryTransaction & aTxn,Error aError)1459 void Server::ResetUpstreamQueryTransaction(UpstreamQueryTransaction &aTxn, Error aError)
1460 {
1461     int index = static_cast<int>(&aTxn - mUpstreamQueryTransactions);
1462 
1463     // Avoid the warnings when info / warn logging is disabled.
1464     OT_UNUSED_VARIABLE(index);
1465     if (aError == kErrorNone)
1466     {
1467         mCounters.mUpstreamDnsCounters.mResponses++;
1468         LogInfo("Upstream query transaction %d completed.", index);
1469     }
1470     else
1471     {
1472         mCounters.mUpstreamDnsCounters.mFailures++;
1473         LogWarn("Upstream query transaction %d closed: %s.", index, ErrorToString(aError));
1474     }
1475     aTxn.Reset();
1476 }
1477 #endif
1478 
1479 #if OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
1480 
DiscoveryProxy(Instance & aInstance)1481 Server::DiscoveryProxy::DiscoveryProxy(Instance &aInstance)
1482     : InstanceLocator(aInstance)
1483     , mIsRunning(false)
1484 {
1485 }
1486 
UpdateState(void)1487 void Server::DiscoveryProxy::UpdateState(void)
1488 {
1489     if (Get<Server>().IsRunning() && Get<Dnssd>().IsReady() && Get<BorderRouter::InfraIf>().IsRunning())
1490     {
1491         Start();
1492     }
1493     else
1494     {
1495         Stop();
1496     }
1497 }
1498 
Start(void)1499 void Server::DiscoveryProxy::Start(void)
1500 {
1501     VerifyOrExit(!mIsRunning);
1502     mIsRunning = true;
1503     LogInfo("Started discovery proxy");
1504 
1505 exit:
1506     return;
1507 }
1508 
Stop(void)1509 void Server::DiscoveryProxy::Stop(void)
1510 {
1511     VerifyOrExit(mIsRunning);
1512 
1513     for (ProxyQuery &query : Get<Server>().mProxyQueries)
1514     {
1515         Get<Server>().Finalize(query, Header::kResponseSuccess);
1516     }
1517 
1518     mIsRunning = false;
1519     LogInfo("Stopped discovery proxy");
1520 
1521 exit:
1522     return;
1523 }
1524 
Resolve(ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1525 void Server::DiscoveryProxy::Resolve(ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1526 {
1527     ProxyAction action = kNoAction;
1528 
1529     switch (aInfo.mType)
1530     {
1531     case kPtrQuery:
1532         action = kBrowsing;
1533         break;
1534 
1535     case kSrvQuery:
1536     case kSrvTxtQuery:
1537         action = kResolvingSrv;
1538         break;
1539 
1540     case kTxtQuery:
1541         action = kResolvingTxt;
1542         break;
1543 
1544     case kAaaaQuery:
1545         action = kResolvingIp6Address;
1546         break;
1547     case kAQuery:
1548         action = kResolvingIp4Address;
1549         break;
1550     }
1551 
1552     Perform(action, aQuery, aInfo);
1553 }
1554 
Perform(ProxyAction aAction,ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1555 void Server::DiscoveryProxy::Perform(ProxyAction aAction, ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1556 {
1557     bool         shouldStart;
1558     Name::Buffer name;
1559 
1560     VerifyOrExit(aAction != kNoAction);
1561 
1562     // The order of the steps below is crucial. First, we read the
1563     // name associated with the action. Then we check if another
1564     // query has an active browser/resolver for the same name. This
1565     // helps us determine if a new browser/resolver is needed. Then,
1566     // we update the `ProxyQueryInfo` within `aQuery` to reflect the
1567     // `aAction` being performed. Finally, if necessary, we start the
1568     // proper browser/resolver on DNS-SD/mDNS. Placing this last
1569     // ensures correct processing even if a DNS-SD/mDNS callback is
1570     // invoked immediately.
1571 
1572     ReadNameFor(aAction, aQuery, aInfo, name);
1573 
1574     shouldStart = !HasActive(aAction, name);
1575 
1576     aInfo.mAction = aAction;
1577     aInfo.UpdateIn(aQuery);
1578 
1579     VerifyOrExit(shouldStart);
1580     UpdateProxy(kStart, aAction, aQuery, aInfo, name);
1581 
1582 exit:
1583     return;
1584 }
1585 
ReadNameFor(ProxyAction aAction,ProxyQuery & aQuery,ProxyQueryInfo & aInfo,Name::Buffer & aName) const1586 void Server::DiscoveryProxy::ReadNameFor(ProxyAction     aAction,
1587                                          ProxyQuery     &aQuery,
1588                                          ProxyQueryInfo &aInfo,
1589                                          Name::Buffer   &aName) const
1590 {
1591     // Read the name corresponding to `aAction` from `aQuery`.
1592 
1593     switch (aAction)
1594     {
1595     case kNoAction:
1596         break;
1597     case kBrowsing:
1598         ReadQueryName(aQuery, aName);
1599         break;
1600     case kResolvingSrv:
1601     case kResolvingTxt:
1602         ReadQueryInstanceName(aQuery, aInfo, aName);
1603         break;
1604     case kResolvingIp6Address:
1605     case kResolvingIp4Address:
1606         ReadQueryHostName(aQuery, aInfo, aName);
1607         break;
1608     }
1609 }
1610 
CancelAction(ProxyQuery & aQuery,ProxyQueryInfo & aInfo)1611 void Server::DiscoveryProxy::CancelAction(ProxyQuery &aQuery, ProxyQueryInfo &aInfo)
1612 {
1613     // Cancel the current action for a given `aQuery`, then
1614     // determine if we need to stop any browser/resolver
1615     // on infrastructure.
1616 
1617     ProxyAction  action = aInfo.mAction;
1618     Name::Buffer name;
1619 
1620     VerifyOrExit(mIsRunning);
1621     VerifyOrExit(action != kNoAction);
1622 
1623     // We first update the `aInfo` on `aQuery` before calling
1624     // `HasActive()`. This ensures that the current query is not
1625     // taken into account when we try to determine if any query
1626     // is waiting for same `aAction` browser/resolver.
1627 
1628     ReadNameFor(action, aQuery, aInfo, name);
1629 
1630     aInfo.mAction = kNoAction;
1631     aInfo.UpdateIn(aQuery);
1632 
1633     VerifyOrExit(!HasActive(action, name));
1634     UpdateProxy(kStop, action, aQuery, aInfo, name);
1635 
1636 exit:
1637     return;
1638 }
1639 
UpdateProxy(Command aCommand,ProxyAction aAction,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,Name::Buffer & aName)1640 void Server::DiscoveryProxy::UpdateProxy(Command               aCommand,
1641                                          ProxyAction           aAction,
1642                                          const ProxyQuery     &aQuery,
1643                                          const ProxyQueryInfo &aInfo,
1644                                          Name::Buffer         &aName)
1645 {
1646     // Start or stop browser/resolver corresponding to `aAction`.
1647     // `aName` may be changed.
1648 
1649     switch (aAction)
1650     {
1651     case kNoAction:
1652         break;
1653     case kBrowsing:
1654         StartOrStopBrowser(aCommand, aName);
1655         break;
1656     case kResolvingSrv:
1657         StartOrStopSrvResolver(aCommand, aQuery, aInfo);
1658         break;
1659     case kResolvingTxt:
1660         StartOrStopTxtResolver(aCommand, aQuery, aInfo);
1661         break;
1662     case kResolvingIp6Address:
1663         StartOrStopIp6Resolver(aCommand, aName);
1664         break;
1665     case kResolvingIp4Address:
1666         StartOrStopIp4Resolver(aCommand, aName);
1667         break;
1668     }
1669 }
1670 
StartOrStopBrowser(Command aCommand,Name::Buffer & aServiceName)1671 void Server::DiscoveryProxy::StartOrStopBrowser(Command aCommand, Name::Buffer &aServiceName)
1672 {
1673     // Start or stop a service browser for a given service type
1674     // or sub-type.
1675 
1676     static const char kFullSubLabel[] = "._sub.";
1677 
1678     Dnssd::Browser browser;
1679     char          *ptr;
1680 
1681     browser.Clear();
1682 
1683     IgnoreError(StripDomainName(aServiceName));
1684 
1685     // Check if the service name is a sub-type with name
1686     // format: "<sub-label>._sub.<service-labels>.
1687 
1688     ptr = AsNonConst(StringFind(aServiceName, kFullSubLabel, kStringCaseInsensitiveMatch));
1689 
1690     if (ptr != nullptr)
1691     {
1692         *ptr = kNullChar;
1693         ptr += sizeof(kFullSubLabel) - 1;
1694 
1695         browser.mServiceType  = ptr;
1696         browser.mSubTypeLabel = aServiceName;
1697     }
1698     else
1699     {
1700         browser.mServiceType  = aServiceName;
1701         browser.mSubTypeLabel = nullptr;
1702     }
1703 
1704     browser.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1705     browser.mCallback     = HandleBrowseResult;
1706 
1707     switch (aCommand)
1708     {
1709     case kStart:
1710         Get<Dnssd>().StartBrowser(browser);
1711         break;
1712 
1713     case kStop:
1714         Get<Dnssd>().StopBrowser(browser);
1715         break;
1716     }
1717 }
1718 
StartOrStopSrvResolver(Command aCommand,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1719 void Server::DiscoveryProxy::StartOrStopSrvResolver(Command               aCommand,
1720                                                     const ProxyQuery     &aQuery,
1721                                                     const ProxyQueryInfo &aInfo)
1722 {
1723     // Start or stop an SRV record resolver for a given query.
1724 
1725     Dnssd::SrvResolver resolver;
1726     Name::LabelBuffer  instanceLabel;
1727     Name::Buffer       serviceType;
1728 
1729     ReadQueryInstanceName(aQuery, aInfo, instanceLabel, serviceType);
1730 
1731     resolver.Clear();
1732 
1733     resolver.mServiceInstance = instanceLabel;
1734     resolver.mServiceType     = serviceType;
1735     resolver.mInfraIfIndex    = Get<BorderRouter::InfraIf>().GetIfIndex();
1736     resolver.mCallback        = HandleSrvResult;
1737 
1738     switch (aCommand)
1739     {
1740     case kStart:
1741         Get<Dnssd>().StartSrvResolver(resolver);
1742         break;
1743 
1744     case kStop:
1745         Get<Dnssd>().StopSrvResolver(resolver);
1746         break;
1747     }
1748 }
1749 
StartOrStopTxtResolver(Command aCommand,const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo)1750 void Server::DiscoveryProxy::StartOrStopTxtResolver(Command               aCommand,
1751                                                     const ProxyQuery     &aQuery,
1752                                                     const ProxyQueryInfo &aInfo)
1753 {
1754     // Start or stop a TXT record resolver for a given query.
1755 
1756     Dnssd::TxtResolver resolver;
1757     Name::LabelBuffer  instanceLabel;
1758     Name::Buffer       serviceType;
1759 
1760     ReadQueryInstanceName(aQuery, aInfo, instanceLabel, serviceType);
1761 
1762     resolver.Clear();
1763 
1764     resolver.mServiceInstance = instanceLabel;
1765     resolver.mServiceType     = serviceType;
1766     resolver.mInfraIfIndex    = Get<BorderRouter::InfraIf>().GetIfIndex();
1767     resolver.mCallback        = HandleTxtResult;
1768 
1769     switch (aCommand)
1770     {
1771     case kStart:
1772         Get<Dnssd>().StartTxtResolver(resolver);
1773         break;
1774 
1775     case kStop:
1776         Get<Dnssd>().StopTxtResolver(resolver);
1777         break;
1778     }
1779 }
1780 
StartOrStopIp6Resolver(Command aCommand,Name::Buffer & aHostName)1781 void Server::DiscoveryProxy::StartOrStopIp6Resolver(Command aCommand, Name::Buffer &aHostName)
1782 {
1783     // Start or stop an IPv6 address resolver for a given host name.
1784 
1785     Dnssd::AddressResolver resolver;
1786 
1787     IgnoreError(StripDomainName(aHostName));
1788 
1789     resolver.mHostName     = aHostName;
1790     resolver.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1791     resolver.mCallback     = HandleIp6AddressResult;
1792 
1793     switch (aCommand)
1794     {
1795     case kStart:
1796         Get<Dnssd>().StartIp6AddressResolver(resolver);
1797         break;
1798 
1799     case kStop:
1800         Get<Dnssd>().StopIp6AddressResolver(resolver);
1801         break;
1802     }
1803 }
1804 
StartOrStopIp4Resolver(Command aCommand,Name::Buffer & aHostName)1805 void Server::DiscoveryProxy::StartOrStopIp4Resolver(Command aCommand, Name::Buffer &aHostName)
1806 {
1807     // Start or stop an IPv4 address resolver for a given host name.
1808 
1809     Dnssd::AddressResolver resolver;
1810 
1811     IgnoreError(StripDomainName(aHostName));
1812 
1813     resolver.mHostName     = aHostName;
1814     resolver.mInfraIfIndex = Get<BorderRouter::InfraIf>().GetIfIndex();
1815     resolver.mCallback     = HandleIp4AddressResult;
1816 
1817     switch (aCommand)
1818     {
1819     case kStart:
1820         Get<Dnssd>().StartIp4AddressResolver(resolver);
1821         break;
1822 
1823     case kStop:
1824         Get<Dnssd>().StopIp4AddressResolver(resolver);
1825         break;
1826     }
1827 }
1828 
QueryMatches(const ProxyQuery & aQuery,const ProxyQueryInfo & aInfo,ProxyAction aAction,const Name::Buffer & aName) const1829 bool Server::DiscoveryProxy::QueryMatches(const ProxyQuery     &aQuery,
1830                                           const ProxyQueryInfo &aInfo,
1831                                           ProxyAction           aAction,
1832                                           const Name::Buffer   &aName) const
1833 {
1834     // Check whether `aQuery` is performing `aAction` and
1835     // its name matches `aName`.
1836 
1837     bool matches = false;
1838 
1839     VerifyOrExit(aInfo.mAction == aAction);
1840 
1841     switch (aAction)
1842     {
1843     case kBrowsing:
1844         VerifyOrExit(QueryNameMatches(aQuery, aName));
1845         break;
1846     case kResolvingSrv:
1847     case kResolvingTxt:
1848         VerifyOrExit(QueryInstanceNameMatches(aQuery, aInfo, aName));
1849         break;
1850     case kResolvingIp6Address:
1851     case kResolvingIp4Address:
1852         VerifyOrExit(QueryHostNameMatches(aQuery, aInfo, aName));
1853         break;
1854     case kNoAction:
1855         ExitNow();
1856     }
1857 
1858     matches = true;
1859 
1860 exit:
1861     return matches;
1862 }
1863 
HasActive(ProxyAction aAction,const Name::Buffer & aName) const1864 bool Server::DiscoveryProxy::HasActive(ProxyAction aAction, const Name::Buffer &aName) const
1865 {
1866     // Determine whether or not we have an active browser/resolver
1867     // corresponding to `aAction` for `aName`.
1868 
1869     bool has = false;
1870 
1871     for (const ProxyQuery &query : Get<Server>().mProxyQueries)
1872     {
1873         ProxyQueryInfo info;
1874 
1875         info.ReadFrom(query);
1876 
1877         if (QueryMatches(query, info, aAction, aName))
1878         {
1879             has = true;
1880             break;
1881         }
1882     }
1883 
1884     return has;
1885 }
1886 
HandleBrowseResult(otInstance * aInstance,const otPlatDnssdBrowseResult * aResult)1887 void Server::DiscoveryProxy::HandleBrowseResult(otInstance *aInstance, const otPlatDnssdBrowseResult *aResult)
1888 {
1889     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleBrowseResult(*aResult);
1890 }
1891 
HandleBrowseResult(const Dnssd::BrowseResult & aResult)1892 void Server::DiscoveryProxy::HandleBrowseResult(const Dnssd::BrowseResult &aResult)
1893 {
1894     Name::Buffer serviceName;
1895 
1896     VerifyOrExit(mIsRunning);
1897     VerifyOrExit(aResult.mTtl != 0);
1898     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1899 
1900     if (aResult.mSubTypeLabel != nullptr)
1901     {
1902         ConstructFullServiceSubTypeName(aResult.mServiceType, aResult.mSubTypeLabel, serviceName);
1903     }
1904     else
1905     {
1906         ConstructFullName(aResult.mServiceType, serviceName);
1907     }
1908 
1909     HandleResult(kBrowsing, serviceName, &Response::AppendPtrRecord, ProxyResult(aResult));
1910 
1911 exit:
1912     return;
1913 }
1914 
HandleSrvResult(otInstance * aInstance,const otPlatDnssdSrvResult * aResult)1915 void Server::DiscoveryProxy::HandleSrvResult(otInstance *aInstance, const otPlatDnssdSrvResult *aResult)
1916 {
1917     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleSrvResult(*aResult);
1918 }
1919 
HandleSrvResult(const Dnssd::SrvResult & aResult)1920 void Server::DiscoveryProxy::HandleSrvResult(const Dnssd::SrvResult &aResult)
1921 {
1922     Name::Buffer instanceName;
1923 
1924     VerifyOrExit(mIsRunning);
1925     VerifyOrExit(aResult.mTtl != 0);
1926     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1927 
1928     ConstructFullInstanceName(aResult.mServiceInstance, aResult.mServiceType, instanceName);
1929     HandleResult(kResolvingSrv, instanceName, &Response::AppendSrvRecord, ProxyResult(aResult));
1930 
1931 exit:
1932     return;
1933 }
1934 
HandleTxtResult(otInstance * aInstance,const otPlatDnssdTxtResult * aResult)1935 void Server::DiscoveryProxy::HandleTxtResult(otInstance *aInstance, const otPlatDnssdTxtResult *aResult)
1936 {
1937     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleTxtResult(*aResult);
1938 }
1939 
HandleTxtResult(const Dnssd::TxtResult & aResult)1940 void Server::DiscoveryProxy::HandleTxtResult(const Dnssd::TxtResult &aResult)
1941 {
1942     Name::Buffer instanceName;
1943 
1944     VerifyOrExit(mIsRunning);
1945     VerifyOrExit(aResult.mTtl != 0);
1946     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1947 
1948     ConstructFullInstanceName(aResult.mServiceInstance, aResult.mServiceType, instanceName);
1949     HandleResult(kResolvingTxt, instanceName, &Response::AppendTxtRecord, ProxyResult(aResult));
1950 
1951 exit:
1952     return;
1953 }
1954 
HandleIp6AddressResult(otInstance * aInstance,const otPlatDnssdAddressResult * aResult)1955 void Server::DiscoveryProxy::HandleIp6AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult)
1956 {
1957     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleIp6AddressResult(*aResult);
1958 }
1959 
HandleIp6AddressResult(const Dnssd::AddressResult & aResult)1960 void Server::DiscoveryProxy::HandleIp6AddressResult(const Dnssd::AddressResult &aResult)
1961 {
1962     bool         hasValidAddress = false;
1963     Name::Buffer fullHostName;
1964 
1965     VerifyOrExit(mIsRunning);
1966     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
1967 
1968     for (uint16_t index = 0; index < aResult.mAddressesLength; index++)
1969     {
1970         const Dnssd::AddressAndTtl &entry   = aResult.mAddresses[index];
1971         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
1972 
1973         if (entry.mTtl == 0)
1974         {
1975             continue;
1976         }
1977 
1978         if (IsProxyAddressValid(address))
1979         {
1980             hasValidAddress = true;
1981             break;
1982         }
1983     }
1984 
1985     VerifyOrExit(hasValidAddress);
1986 
1987     ConstructFullName(aResult.mHostName, fullHostName);
1988     HandleResult(kResolvingIp6Address, fullHostName, &Response::AppendHostIp6Addresses, ProxyResult(aResult));
1989 
1990 exit:
1991     return;
1992 }
1993 
HandleIp4AddressResult(otInstance * aInstance,const otPlatDnssdAddressResult * aResult)1994 void Server::DiscoveryProxy::HandleIp4AddressResult(otInstance *aInstance, const otPlatDnssdAddressResult *aResult)
1995 {
1996     AsCoreType(aInstance).Get<Server>().mDiscoveryProxy.HandleIp4AddressResult(*aResult);
1997 }
1998 
HandleIp4AddressResult(const Dnssd::AddressResult & aResult)1999 void Server::DiscoveryProxy::HandleIp4AddressResult(const Dnssd::AddressResult &aResult)
2000 {
2001     bool         hasValidAddress = false;
2002     Name::Buffer fullHostName;
2003 
2004     VerifyOrExit(mIsRunning);
2005     VerifyOrExit(aResult.mInfraIfIndex == Get<BorderRouter::InfraIf>().GetIfIndex());
2006 
2007     for (uint16_t index = 0; index < aResult.mAddressesLength; index++)
2008     {
2009         const Dnssd::AddressAndTtl &entry   = aResult.mAddresses[index];
2010         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
2011 
2012         if (entry.mTtl == 0)
2013         {
2014             continue;
2015         }
2016 
2017         if (address.IsIp4Mapped())
2018         {
2019             hasValidAddress = true;
2020             break;
2021         }
2022     }
2023 
2024     VerifyOrExit(hasValidAddress);
2025 
2026     ConstructFullName(aResult.mHostName, fullHostName);
2027     HandleResult(kResolvingIp4Address, fullHostName, &Response::AppendHostIp4Addresses, ProxyResult(aResult));
2028 
2029 exit:
2030     return;
2031 }
2032 
void Server::DiscoveryProxy::HandleResult(ProxyAction         aAction,
                                          const Name::Buffer &aName,
                                          ResponseAppender    aAppender,
                                          const ProxyResult  &aResult)
{
    // Common method that handles result from DNS-SD/mDNS. It
    // iterates over all `ProxyQuery` entries and checks if any entry
    // is waiting for the result of `aAction` for `aName`. Matching
    // queries are updated using the `aAppender` method pointer,
    // which appends the corresponding record(s) to the response. We
    // then determine the next action to be performed for the
    // `ProxyQuery` or if it can be finalized.
    //
    // The work is done in two passes: the first pass appends the
    // records and either sends the response or parks the query in
    // the local `nextActionQueries` list; the second pass moves the
    // parked queries back to `mProxyQueries` and starts their next
    // action.

    ProxyQueryList nextActionQueries;
    ProxyQueryInfo info;
    ProxyAction    nextAction;

    for (ProxyQuery &query : Get<Server>().mProxyQueries)
    {
        Response response(GetInstance());
        bool     shouldFinalize;

        info.ReadFrom(query);

        if (!QueryMatches(query, info, aAction, aName))
        {
            continue;
        }

        CancelAction(query, info);

        // Determine the follow-up action: a browse result is
        // followed by SRV resolution; an SRV result by IPv6 address
        // resolution (SRV-only query) or TXT resolution (all other
        // query types); a TXT result ends the chain for a TXT-only
        // query and is otherwise followed by IPv6 address
        // resolution. Address results have no follow-up action.

        nextAction = kNoAction;

        switch (aAction)
        {
        case kBrowsing:
            nextAction = kResolvingSrv;
            break;
        case kResolvingSrv:
            nextAction = (info.mType == kSrvQuery) ? kResolvingIp6Address : kResolvingTxt;
            break;
        case kResolvingTxt:
            nextAction = (info.mType == kTxtQuery) ? kNoAction : kResolvingIp6Address;
            break;
        case kNoAction:
        case kResolvingIp6Address:
        case kResolvingIp4Address:
            break;
        }

        shouldFinalize = (nextAction == kNoAction);

        // In the `kTestModeEmptyAdditionalSection` test mode, the
        // query is also finalized now when
        // `IsActionForAdditionalSection()` reports `nextAction` as
        // being for the Additional Data section.

        if ((Get<Server>().mTestMode & kTestModeEmptyAdditionalSection) &&
            IsActionForAdditionalSection(nextAction, info.mType))
        {
            shouldFinalize = true;
        }

        // Remove `query` from `mProxyQueries` and initialize
        // `response` from it before appending the new record(s).

        Get<Server>().mProxyQueries.Dequeue(query);
        info.RemoveFrom(query);
        response.InitFrom(query, info);

        // A failure to append the record(s) turns the response into
        // a server-failure response, which is sent right away.

        if ((response.*aAppender)(aResult) != kErrorNone)
        {
            response.SetResponseCode(Header::kResponseServerFailure);
            shouldFinalize = true;
        }

        if (shouldFinalize)
        {
            response.Send(info.mMessageInfo);
            continue;
        }

        // The `query` is not yet finished and we need to perform
        // the `nextAction` for it.

        // Reinitialize `response` as a `ProxyQuery` by updating
        // and appending `info` to it after the newly appended
        // records from `aResult` and saving the `mHeader`.

        info.mOffsets = response.mOffsets;
        info.mAction  = nextAction;
        response.mMessage->Write(0, response.mHeader);

        if (response.mMessage->Append(info) != kErrorNone)
        {
            response.SetResponseCode(Header::kResponseServerFailure);
            response.Send(info.mMessageInfo);
            continue;
        }

        // Take back ownership of `response.mMessage` as we still
        // treat it as a `ProxyQuery`.

        response.mMessage.Release();

        // We place the `query` in a separate list and add it back to
        // the main `mProxyQueries` list after we are done with the
        // current iteration. This ensures that other entries in the
        // `mProxyQueries` list are not updated or removed due to the
        // DNS-SD platform callback being invoked immediately when we
        // potentially start a browser or resolver to perform the
        // `nextAction` for `query`.

        nextActionQueries.Enqueue(query);
    }

    // Second pass: move every parked query back to `mProxyQueries`
    // and perform its saved next action. The action stored in the
    // query's info is reset to `kNoAction` before `Perform()` is
    // called with the saved `nextAction`.

    for (ProxyQuery &query : nextActionQueries)
    {
        nextActionQueries.Dequeue(query);

        info.ReadFrom(query);

        nextAction = info.mAction;

        info.mAction = kNoAction;
        info.UpdateIn(query);

        Get<Server>().mProxyQueries.Enqueue(query);
        Perform(nextAction, query, info);
    }
}
2156 
IsActionForAdditionalSection(ProxyAction aAction,QueryType aQueryType)2157 bool Server::DiscoveryProxy::IsActionForAdditionalSection(ProxyAction aAction, QueryType aQueryType)
2158 {
2159     bool isForAddnlSection = false;
2160 
2161     switch (aAction)
2162     {
2163     case kResolvingSrv:
2164         VerifyOrExit((aQueryType == kSrvQuery) || (aQueryType == kSrvTxtQuery));
2165         break;
2166     case kResolvingTxt:
2167         VerifyOrExit((aQueryType == kTxtQuery) || (aQueryType == kSrvTxtQuery));
2168         break;
2169 
2170     case kResolvingIp6Address:
2171         VerifyOrExit(aQueryType == kAaaaQuery);
2172         break;
2173 
2174     case kResolvingIp4Address:
2175         VerifyOrExit(aQueryType == kAQuery);
2176         break;
2177 
2178     case kNoAction:
2179     case kBrowsing:
2180         ExitNow();
2181     }
2182 
2183     isForAddnlSection = true;
2184 
2185 exit:
2186     return isForAddnlSection;
2187 }
2188 
AppendPtrRecord(const ProxyResult & aResult)2189 Error Server::Response::AppendPtrRecord(const ProxyResult &aResult)
2190 {
2191     const Dnssd::BrowseResult *browseResult = aResult.mBrowseResult;
2192 
2193     mSection = kAnswerSection;
2194 
2195     return AppendPtrRecord(browseResult->mServiceInstance, browseResult->mTtl);
2196 }
2197 
AppendSrvRecord(const ProxyResult & aResult)2198 Error Server::Response::AppendSrvRecord(const ProxyResult &aResult)
2199 {
2200     const Dnssd::SrvResult *srvResult = aResult.mSrvResult;
2201     Name::Buffer            fullHostName;
2202 
2203     mSection = ((mType == kSrvQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
2204 
2205     ConstructFullName(srvResult->mHostName, fullHostName);
2206 
2207     return AppendSrvRecord(fullHostName, srvResult->mTtl, srvResult->mPriority, srvResult->mWeight, srvResult->mPort);
2208 }
2209 
AppendTxtRecord(const ProxyResult & aResult)2210 Error Server::Response::AppendTxtRecord(const ProxyResult &aResult)
2211 {
2212     const Dnssd::TxtResult *txtResult = aResult.mTxtResult;
2213 
2214     mSection = ((mType == kTxtQuery) || (mType == kSrvTxtQuery)) ? kAnswerSection : kAdditionalDataSection;
2215 
2216     return AppendTxtRecord(txtResult->mTxtData, txtResult->mTxtDataLength, txtResult->mTtl);
2217 }
2218 
AppendHostIp6Addresses(const ProxyResult & aResult)2219 Error Server::Response::AppendHostIp6Addresses(const ProxyResult &aResult)
2220 {
2221     Error                       error      = kErrorNone;
2222     const Dnssd::AddressResult *addrResult = aResult.mAddressResult;
2223 
2224     mSection = (mType == kAaaaQuery) ? kAnswerSection : kAdditionalDataSection;
2225 
2226     for (uint16_t index = 0; index < addrResult->mAddressesLength; index++)
2227     {
2228         const Dnssd::AddressAndTtl &entry   = addrResult->mAddresses[index];
2229         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
2230 
2231         if (entry.mTtl == 0)
2232         {
2233             continue;
2234         }
2235 
2236         if (!IsProxyAddressValid(address))
2237         {
2238             continue;
2239         }
2240 
2241         SuccessOrExit(error = AppendAaaaRecord(address, entry.mTtl));
2242     }
2243 
2244 exit:
2245     return error;
2246 }
2247 
AppendHostIp4Addresses(const ProxyResult & aResult)2248 Error Server::Response::AppendHostIp4Addresses(const ProxyResult &aResult)
2249 {
2250     Error                       error      = kErrorNone;
2251     const Dnssd::AddressResult *addrResult = aResult.mAddressResult;
2252 
2253     mSection = (mType == kAQuery) ? kAnswerSection : kAdditionalDataSection;
2254 
2255     for (uint16_t index = 0; index < addrResult->mAddressesLength; index++)
2256     {
2257         const Dnssd::AddressAndTtl &entry   = addrResult->mAddresses[index];
2258         const Ip6::Address         &address = AsCoreType(&entry.mAddress);
2259 
2260         if (entry.mTtl == 0)
2261         {
2262             continue;
2263         }
2264 
2265         SuccessOrExit(error = AppendARecord(address, entry.mTtl));
2266     }
2267 
2268 exit:
2269     return error;
2270 }
2271 
IsProxyAddressValid(const Ip6::Address & aAddress)2272 bool Server::IsProxyAddressValid(const Ip6::Address &aAddress)
2273 {
2274     return !aAddress.IsLinkLocalUnicast() && !aAddress.IsMulticast() && !aAddress.IsUnspecified() &&
2275            !aAddress.IsLoopback();
2276 }
2277 
2278 #endif // OPENTHREAD_CONFIG_DNSSD_DISCOVERY_PROXY_ENABLE
2279 
2280 } // namespace ServiceDiscovery
2281 } // namespace Dns
2282 } // namespace ot
2283 
2284 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE && OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_MOCK_PLAT_APIS_ENABLE
otPlatDnsStartUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)2285 void otPlatDnsStartUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
2286 {
2287     OT_UNUSED_VARIABLE(aInstance);
2288     OT_UNUSED_VARIABLE(aTxn);
2289     OT_UNUSED_VARIABLE(aQuery);
2290 }
2291 
otPlatDnsCancelUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn)2292 void otPlatDnsCancelUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn)
2293 {
2294     otPlatDnsUpstreamQueryDone(aInstance, aTxn, nullptr);
2295 }
2296 #endif
2297 
#endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
2299