1 /*
2  *  Copyright (c) 2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the DNS-SD server.
32  */
33 
34 #include "dnssd_server.hpp"
35 
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37 
38 #include "common/code_utils.hpp"
39 #include "common/debug.hpp"
40 #include "common/instance.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/logging.hpp"
43 #include "net/srp_server.hpp"
44 #include "net/udp6.hpp"
45 
46 namespace ot {
47 namespace Dns {
48 namespace ServiceDiscovery {
49 
// Protocol labels recognized when splitting a service name into components.
// They are compared with `memcmp` over `kProtocolLabelLength` bytes, so they are
// deliberately not NUL-terminated.
const char Server::kDnssdProtocolUdp[4] = {'_', 'u', 'd', 'p'};
const char Server::kDnssdProtocolTcp[4] = {'_', 't', 'c', 'p'};
// Sub-type separator as it appears inside a dotted name: "<SubType>._sub.<Service>...".
const char Server::kDnssdSubTypeLabel[] = "._sub.";
// Domain used for name compression and validation; question names not under it are rejected.
const char Server::kDefaultDomainName[] = "default.service.arpa.";
54 
/**
 * Constructs the DNS-SD server.
 *
 * The UDP socket object is constructed here but is only opened and bound by `Start()`.
 * No query subscribe/unsubscribe callbacks are registered initially.
 *
 * @param[in] aInstance  The OpenThread instance.
 */
Server::Server(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance)
    , mQueryCallbackContext(nullptr)
    , mQuerySubscribe(nullptr)
    , mQueryUnsubscribe(nullptr)
    , mTimer(aInstance, Server::HandleTimer)
{
}
64 
Start(void)65 Error Server::Start(void)
66 {
67     Error error = kErrorNone;
68 
69     VerifyOrExit(!IsRunning());
70 
71     SuccessOrExit(error = mSocket.Open(&Server::HandleUdpReceive, this));
72     SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? OT_NETIF_UNSPECIFIED : OT_NETIF_THREAD));
73 
74 exit:
75     otLogInfoDns("[server] started: %s", ErrorToString(error));
76 
77     if (error != kErrorNone)
78     {
79         IgnoreError(mSocket.Close());
80     }
81 
82     return error;
83 }
84 
Stop(void)85 void Server::Stop(void)
86 {
87     // Abort all query transactions
88     for (QueryTransaction &query : mQueryTransactions)
89     {
90         if (query.IsValid())
91         {
92             FinalizeQuery(query, Header::kResponseServerFailure);
93         }
94     }
95 
96     mTimer.Stop();
97 
98     IgnoreError(mSocket.Close());
99     otLogInfoDns("[server] stopped");
100 }
101 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)102 void Server::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
103 {
104     static_cast<Server *>(aContext)->HandleUdpReceive(*static_cast<Message *>(aMessage),
105                                                       *static_cast<const Ip6::MessageInfo *>(aMessageInfo));
106 }
107 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)108 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
109 {
110     Error  error = kErrorNone;
111     Header requestHeader;
112 
113     SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), requestHeader));
114     VerifyOrExit(requestHeader.GetType() == Header::kTypeQuery, error = kErrorDrop);
115 
116     ProcessQuery(requestHeader, aMessage, aMessageInfo);
117 exit:
118     return;
119 }
120 
ProcessQuery(const Header & aRequestHeader,Message & aRequestMessage,const Ip6::MessageInfo & aMessageInfo)121 void Server::ProcessQuery(const Header &aRequestHeader, Message &aRequestMessage, const Ip6::MessageInfo &aMessageInfo)
122 {
123     Error            error           = kErrorNone;
124     Message *        responseMessage = nullptr;
125     Header           responseHeader;
126     NameCompressInfo compressInfo(kDefaultDomainName);
127     Header::Response response                = Header::kResponseSuccess;
128     bool             resolveByQueryCallbacks = false;
129 
130     responseMessage = mSocket.NewMessage(0);
131     VerifyOrExit(responseMessage != nullptr, error = kErrorNoBufs);
132 
133     // Allocate space for DNS header
134     SuccessOrExit(error = responseMessage->SetLength(sizeof(Header)));
135 
136     // Setup initial DNS response header
137     responseHeader.Clear();
138     responseHeader.SetType(Header::kTypeResponse);
139     responseHeader.SetMessageId(aRequestHeader.GetMessageId());
140 
141     // Validate the query
142     VerifyOrExit(aRequestHeader.GetQueryType() == Header::kQueryTypeStandard,
143                  response = Header::kResponseNotImplemented);
144     VerifyOrExit(!aRequestHeader.IsTruncationFlagSet(), response = Header::kResponseFormatError);
145     VerifyOrExit(aRequestHeader.GetQuestionCount() > 0, response = Header::kResponseFormatError);
146 
147     response = AddQuestions(aRequestHeader, aRequestMessage, responseHeader, *responseMessage, compressInfo);
148     VerifyOrExit(response == Header::kResponseSuccess);
149 
150 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
151     // Answer the questions
152     response = ResolveBySrp(responseHeader, *responseMessage, compressInfo);
153 #endif
154 
155     // Resolve the question using query callbacks if SRP server failed to resolve the questions.
156     if (responseHeader.GetAnswerCount() == 0 &&
157         kErrorNone == ResolveByQueryCallbacks(responseHeader, *responseMessage, compressInfo, aMessageInfo))
158     {
159         resolveByQueryCallbacks = true;
160     }
161 
162 exit:
163     if (error == kErrorNone && !resolveByQueryCallbacks)
164     {
165         SendResponse(responseHeader, response, *responseMessage, aMessageInfo, mSocket);
166     }
167 
168     FreeMessageOnError(responseMessage, error);
169 }
170 
SendResponse(Header aHeader,Header::Response aResponseCode,Message & aMessage,const Ip6::MessageInfo & aMessageInfo,Ip6::Udp::Socket & aSocket)171 void Server::SendResponse(Header                  aHeader,
172                           Header::Response        aResponseCode,
173                           Message &               aMessage,
174                           const Ip6::MessageInfo &aMessageInfo,
175                           Ip6::Udp::Socket &      aSocket)
176 {
177     Error error;
178 
179     if (aResponseCode == Header::kResponseServerFailure)
180     {
181         otLogWarnDns("[server] failed to handle DNS query due to server failure");
182         aHeader.SetQuestionCount(0);
183         aHeader.SetAnswerCount(0);
184         aHeader.SetAdditionalRecordCount(0);
185         IgnoreError(aMessage.SetLength(sizeof(Header)));
186     }
187 
188     aHeader.SetResponseCode(aResponseCode);
189     aMessage.Write(0, aHeader);
190 
191     error = aSocket.SendTo(aMessage, aMessageInfo);
192 
193     FreeMessageOnError(&aMessage, error);
194 
195     if (error != kErrorNone)
196     {
197         otLogWarnDns("[server] failed to send DNS-SD reply: %s", otThreadErrorToString(error));
198     }
199     else
200     {
201         otLogInfoDns("[server] send DNS-SD reply: %s, RCODE=%d", otThreadErrorToString(error), aResponseCode);
202     }
203 }
204 
AddQuestions(const Header & aRequestHeader,const Message & aRequestMessage,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo)205 Header::Response Server::AddQuestions(const Header &    aRequestHeader,
206                                       const Message &   aRequestMessage,
207                                       Header &          aResponseHeader,
208                                       Message &         aResponseMessage,
209                                       NameCompressInfo &aCompressInfo)
210 {
211     Question         question;
212     uint16_t         readOffset;
213     Header::Response response = Header::kResponseSuccess;
214     char             name[Name::kMaxNameSize];
215 
216     readOffset = sizeof(Header);
217 
218     // Check and append the questions
219     for (uint16_t i = 0; i < aRequestHeader.GetQuestionCount(); i++)
220     {
221         NameComponentsOffsetInfo nameComponentsOffsetInfo;
222         uint16_t                 qtype;
223 
224         VerifyOrExit(kErrorNone == Name::ReadName(aRequestMessage, readOffset, name, sizeof(name)),
225                      response = Header::kResponseFormatError);
226         VerifyOrExit(kErrorNone == aRequestMessage.Read(readOffset, question), response = Header::kResponseFormatError);
227         readOffset += sizeof(question);
228 
229         qtype = question.GetType();
230 
231         VerifyOrExit(qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv ||
232                          qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAaaa,
233                      response = Header::kResponseNotImplemented);
234 
235         VerifyOrExit(kErrorNone == FindNameComponents(name, aCompressInfo.GetDomainName(), nameComponentsOffsetInfo),
236                      response = Header::kResponseNameError);
237 
238         switch (question.GetType())
239         {
240         case ResourceRecord::kTypePtr:
241             VerifyOrExit(nameComponentsOffsetInfo.IsServiceName(), response = Header::kResponseNameError);
242             break;
243         case ResourceRecord::kTypeSrv:
244             VerifyOrExit(nameComponentsOffsetInfo.IsServiceInstanceName(), response = Header::kResponseNameError);
245             break;
246         case ResourceRecord::kTypeTxt:
247             VerifyOrExit(nameComponentsOffsetInfo.IsServiceInstanceName(), response = Header::kResponseNameError);
248             break;
249         case ResourceRecord::kTypeAaaa:
250             VerifyOrExit(nameComponentsOffsetInfo.IsHostName(), response = Header::kResponseNameError);
251             break;
252         default:
253             ExitNow(response = Header::kResponseNotImplemented);
254         }
255 
256         VerifyOrExit(AppendQuestion(name, question, aResponseMessage, aCompressInfo) == kErrorNone,
257                      response = Header::kResponseServerFailure);
258     }
259 
260     aResponseHeader.SetQuestionCount(aRequestHeader.GetQuestionCount());
261 
262 exit:
263     return response;
264 }
265 
AppendQuestion(const char * aName,const Question & aQuestion,Message & aMessage,NameCompressInfo & aCompressInfo)266 Error Server::AppendQuestion(const char *      aName,
267                              const Question &  aQuestion,
268                              Message &         aMessage,
269                              NameCompressInfo &aCompressInfo)
270 {
271     Error error = kErrorNone;
272 
273     switch (aQuestion.GetType())
274     {
275     case ResourceRecord::kTypePtr:
276         SuccessOrExit(error = AppendServiceName(aMessage, aName, aCompressInfo));
277         break;
278     case ResourceRecord::kTypeSrv:
279     case ResourceRecord::kTypeTxt:
280         SuccessOrExit(error = AppendInstanceName(aMessage, aName, aCompressInfo));
281         break;
282     case ResourceRecord::kTypeAaaa:
283         SuccessOrExit(error = AppendHostName(aMessage, aName, aCompressInfo));
284         break;
285     default:
286         OT_ASSERT(false);
287     }
288 
289     error = aMessage.Append(aQuestion);
290 
291 exit:
292     return error;
293 }
294 
AppendPtrRecord(Message & aMessage,const char * aServiceName,const char * aInstanceName,uint32_t aTtl,NameCompressInfo & aCompressInfo)295 Error Server::AppendPtrRecord(Message &         aMessage,
296                               const char *      aServiceName,
297                               const char *      aInstanceName,
298                               uint32_t          aTtl,
299                               NameCompressInfo &aCompressInfo)
300 {
301     Error     error;
302     PtrRecord ptrRecord;
303     uint16_t  recordOffset;
304 
305     ptrRecord.Init();
306     ptrRecord.SetTtl(aTtl);
307 
308     SuccessOrExit(error = AppendServiceName(aMessage, aServiceName, aCompressInfo));
309 
310     recordOffset = aMessage.GetLength();
311     SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(ptrRecord)));
312 
313     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
314 
315     ptrRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
316     aMessage.Write(recordOffset, ptrRecord);
317 
318 exit:
319     return error;
320 }
321 
AppendSrvRecord(Message & aMessage,const char * aInstanceName,const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort,NameCompressInfo & aCompressInfo)322 Error Server::AppendSrvRecord(Message &         aMessage,
323                               const char *      aInstanceName,
324                               const char *      aHostName,
325                               uint32_t          aTtl,
326                               uint16_t          aPriority,
327                               uint16_t          aWeight,
328                               uint16_t          aPort,
329                               NameCompressInfo &aCompressInfo)
330 {
331     SrvRecord srvRecord;
332     Error     error = kErrorNone;
333     uint16_t  recordOffset;
334 
335     srvRecord.Init();
336     srvRecord.SetTtl(aTtl);
337     srvRecord.SetPriority(aPriority);
338     srvRecord.SetWeight(aWeight);
339     srvRecord.SetPort(aPort);
340 
341     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
342 
343     recordOffset = aMessage.GetLength();
344     SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(srvRecord)));
345 
346     SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
347 
348     srvRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
349     aMessage.Write(recordOffset, srvRecord);
350 
351 exit:
352     return error;
353 }
354 
AppendAaaaRecord(Message & aMessage,const char * aHostName,const Ip6::Address & aAddress,uint32_t aTtl,NameCompressInfo & aCompressInfo)355 Error Server::AppendAaaaRecord(Message &           aMessage,
356                                const char *        aHostName,
357                                const Ip6::Address &aAddress,
358                                uint32_t            aTtl,
359                                NameCompressInfo &  aCompressInfo)
360 {
361     AaaaRecord aaaaRecord;
362     Error      error;
363 
364     aaaaRecord.Init();
365     aaaaRecord.SetTtl(aTtl);
366     aaaaRecord.SetAddress(aAddress);
367 
368     SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
369     error = aMessage.Append(aaaaRecord);
370 
371 exit:
372     return error;
373 }
374 
AppendServiceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)375 Error Server::AppendServiceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
376 {
377     Error       error;
378     uint16_t    serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, aName);
379     const char *serviceName;
380 
381     // Check whether `aName` is a sub-type service name.
382     serviceName = StringFind(aName, kDnssdSubTypeLabel);
383 
384     if (serviceName != nullptr)
385     {
386         uint8_t subTypeLabelLength = static_cast<uint8_t>(serviceName - aName) + sizeof(kDnssdSubTypeLabel) - 1;
387 
388         SuccessOrExit(error = Name::AppendMultipleLabels(aName, subTypeLabelLength, aMessage));
389 
390         // Skip over the "._sub." label to get to the root service name.
391         serviceName += sizeof(kDnssdSubTypeLabel) - 1;
392     }
393     else
394     {
395         serviceName = aName;
396     }
397 
398     if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
399     {
400         error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
401     }
402     else
403     {
404         uint8_t  domainStart          = static_cast<uint8_t>(StringLength(serviceName, Name::kMaxNameSize - 1) -
405                                                    StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
406         uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
407 
408         serviceCompressOffset = aMessage.GetLength();
409         aCompressInfo.SetServiceNameOffset(serviceCompressOffset);
410 
411         if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
412         {
413             aCompressInfo.SetDomainNameOffset(serviceCompressOffset + domainStart);
414             error = Name::AppendName(serviceName, aMessage);
415         }
416         else
417         {
418             SuccessOrExit(error = Name::AppendMultipleLabels(serviceName, domainStart, aMessage));
419             error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
420         }
421     }
422 
423 exit:
424     return error;
425 }
426 
AppendInstanceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)427 Error Server::AppendInstanceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
428 {
429     Error    error;
430     uint16_t instanceCompressOffset = aCompressInfo.GetInstanceNameOffset(aMessage, aName);
431 
432     if (instanceCompressOffset != NameCompressInfo::kUnknownOffset)
433     {
434         error = Name::AppendPointerLabel(instanceCompressOffset, aMessage);
435     }
436     else
437     {
438         NameComponentsOffsetInfo nameComponentsInfo;
439 
440         IgnoreError(FindNameComponents(aName, aCompressInfo.GetDomainName(), nameComponentsInfo));
441         OT_ASSERT(nameComponentsInfo.IsServiceInstanceName());
442 
443         aCompressInfo.SetInstanceNameOffset(aMessage.GetLength());
444 
445         // Append the instance name as one label
446         SuccessOrExit(error = Name::AppendLabel(aName, nameComponentsInfo.mServiceOffset - 1, aMessage));
447 
448         {
449             const char *serviceName           = aName + nameComponentsInfo.mServiceOffset;
450             uint16_t    serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, serviceName);
451 
452             if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
453             {
454                 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
455             }
456             else
457             {
458                 aCompressInfo.SetServiceNameOffset(aMessage.GetLength());
459                 error = Name::AppendName(serviceName, aMessage);
460             }
461         }
462     }
463 
464 exit:
465     return error;
466 }
467 
AppendTxtRecord(Message & aMessage,const char * aInstanceName,const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl,NameCompressInfo & aCompressInfo)468 Error Server::AppendTxtRecord(Message &         aMessage,
469                               const char *      aInstanceName,
470                               const void *      aTxtData,
471                               uint16_t          aTxtLength,
472                               uint32_t          aTtl,
473                               NameCompressInfo &aCompressInfo)
474 {
475     Error     error = kErrorNone;
476     TxtRecord txtRecord;
477 
478     SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
479 
480     txtRecord.Init();
481     txtRecord.SetTtl(aTtl);
482     txtRecord.SetLength(aTxtLength);
483 
484     SuccessOrExit(error = aMessage.Append(txtRecord));
485     error = aMessage.AppendBytes(aTxtData, aTxtLength);
486 
487 exit:
488     return error;
489 }
490 
AppendHostName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)491 Error Server::AppendHostName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
492 {
493     Error    error;
494     uint16_t hostCompressOffset = aCompressInfo.GetHostNameOffset(aMessage, aName);
495 
496     if (hostCompressOffset != NameCompressInfo::kUnknownOffset)
497     {
498         error = Name::AppendPointerLabel(hostCompressOffset, aMessage);
499     }
500     else
501     {
502         uint8_t  domainStart          = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength) -
503                                                    StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
504         uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
505 
506         hostCompressOffset = aMessage.GetLength();
507         aCompressInfo.SetHostNameOffset(hostCompressOffset);
508 
509         if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
510         {
511             aCompressInfo.SetDomainNameOffset(hostCompressOffset + domainStart);
512             error = Name::AppendName(aName, aMessage);
513         }
514         else
515         {
516             SuccessOrExit(error = Name::AppendMultipleLabels(aName, domainStart, aMessage));
517             error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
518         }
519     }
520 
521 exit:
522     return error;
523 }
524 
IncResourceRecordCount(Header & aHeader,bool aAdditional)525 void Server::IncResourceRecordCount(Header &aHeader, bool aAdditional)
526 {
527     if (aAdditional)
528     {
529         aHeader.SetAdditionalRecordCount(aHeader.GetAdditionalRecordCount() + 1);
530     }
531     else
532     {
533         aHeader.SetAnswerCount(aHeader.GetAnswerCount() + 1);
534     }
535 }
536 
/**
 * Splits a DNS-SD name into components by scanning its labels right-to-left, starting
 * just before the `aDomain` suffix.
 *
 * Recognized layouts:
 *   <Instance>.<Service>.<Protocol>.<Domain>           (<Instance> may itself contain dots)
 *   <SubType>._sub.<Service>.<Protocol>.<Domain>
 *   <Service>.<Protocol>.<Domain>
 * where <Protocol> is "_udp" or "_tcp" (case-sensitive comparison). If no protocol label is
 * found, only `mDomainOffset` is set (e.g. for host names "<Host>.<Domain>").
 *
 * Offsets written into `aInfo` are character indexes into `aName`. Only the components the
 * scan reaches are written; the others are left as they were on entry.
 *
 * @param[in]  aName    Dotted name to split.
 * @param[in]  aDomain  Domain suffix that `aName` must be under.
 * @param[out] aInfo    Receives the offsets of the components found.
 *
 * @retval kErrorNone         Scan completed (components may be partial, see above).
 * @retval kErrorInvalidArgs  `aName` is not under `aDomain`, a label is empty/malformed, or more
 *                            than one label precedes "_sub".
 * @retval kErrorNotFound     A "_sub" label is the first label (no <SubType> before it).
 */
Error Server::FindNameComponents(const char *aName, const char *aDomain, NameComponentsOffsetInfo &aInfo)
{
    uint8_t nameLen   = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength));
    uint8_t domainLen = static_cast<uint8_t>(StringLength(aDomain, Name::kMaxNameLength));
    Error   error     = kErrorNone;
    uint8_t labelBegin, labelEnd;

    VerifyOrExit(Name::IsSubDomainOf(aName, aDomain), error = kErrorInvalidArgs);

    // The domain suffix starts where the name's remaining characters end.
    labelBegin          = nameLen - domainLen;
    aInfo.mDomainOffset = labelBegin;

    // Walk labels leftwards until a <Protocol> label is found. Running out of labels
    // (kErrorNotFound) ends the scan successfully; any other error is returned.
    while (true)
    {
        error = FindPreviousLabel(aName, labelBegin, labelEnd);

        VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));

        if (labelEnd == labelBegin + kProtocolLabelLength &&
            (memcmp(&aName[labelBegin], kDnssdProtocolUdp, kProtocolLabelLength) == 0 ||
             memcmp(&aName[labelBegin], kDnssdProtocolTcp, kProtocolLabelLength) == 0))
        {
            // <Protocol> label found
            aInfo.mProtocolOffset = labelBegin;
            break;
        }
    }

    // Get service label <Service>
    error = FindPreviousLabel(aName, labelBegin, labelEnd);
    VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));

    aInfo.mServiceOffset = labelBegin;

    // Check for service subtype
    error = FindPreviousLabel(aName, labelBegin, labelEnd);
    VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));

    // Note that `kDnssdSubTypeLabel` is "._sub.". Here we get the
    // label only so we want to compare it with "_sub".
    if ((labelEnd == labelBegin + kSubTypeLabelLength) &&
        (memcmp(&aName[labelBegin], kDnssdSubTypeLabel + 1, kSubTypeLabelLength) == 0))
    {
        // Exactly one <SubType> label must precede "_sub"; here a missing label is an error.
        SuccessOrExit(error = FindPreviousLabel(aName, labelBegin, labelEnd));
        VerifyOrExit(labelBegin == 0, error = kErrorInvalidArgs);
        aInfo.mSubTypeOffset = labelBegin;
        ExitNow();
    }

    // Treat everything before <Service> as <Instance> label
    aInfo.mInstanceOffset = 0;

exit:
    return error;
}
592 
FindPreviousLabel(const char * aName,uint8_t & aStart,uint8_t & aStop)593 Error Server::FindPreviousLabel(const char *aName, uint8_t &aStart, uint8_t &aStop)
594 {
595     // This method finds the previous label before the current label (whose start index is @p aStart), and updates @p
596     // aStart to the start index of the label and @p aStop to the index of the dot just after the label.
597     // @note The input value of @p aStop does not matter because it is only used to output.
598 
599     Error   error = kErrorNone;
600     uint8_t start = aStart;
601     uint8_t end;
602 
603     VerifyOrExit(start > 0, error = kErrorNotFound);
604     VerifyOrExit(aName[--start] == Name::kLabelSeperatorChar, error = kErrorInvalidArgs);
605 
606     end = start;
607     while (start > 0 && aName[start - 1] != Name::kLabelSeperatorChar)
608     {
609         start--;
610     }
611 
612     VerifyOrExit(start < end, error = kErrorInvalidArgs);
613 
614     aStart = start;
615     aStop  = end;
616 
617 exit:
618     return error;
619 }
620 
621 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
ResolveBySrp(Header & aResponseHeader,Message & aResponseMessage,Server::NameCompressInfo & aCompressInfo)622 Header::Response Server::ResolveBySrp(Header &                  aResponseHeader,
623                                       Message &                 aResponseMessage,
624                                       Server::NameCompressInfo &aCompressInfo)
625 {
626     Question         question;
627     uint16_t         readOffset = sizeof(Header);
628     Header::Response response   = Header::kResponseSuccess;
629     char             name[Name::kMaxNameSize];
630 
631     for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
632     {
633         IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
634         IgnoreError(aResponseMessage.Read(readOffset, question));
635         readOffset += sizeof(question);
636 
637         response = ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo,
638                                         /* aAdditional */ false);
639 
640         otLogInfoDns("[server] ANSWER: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d",
641                      aResponseHeader.GetMessageId(), name, question.GetClass(), question.GetType(), response);
642     }
643 
644     // Answer the questions with additional RRs if required
645     if (aResponseHeader.GetAnswerCount() > 0)
646     {
647         readOffset = sizeof(Header);
648         for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
649         {
650             IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
651             IgnoreError(aResponseMessage.Read(readOffset, question));
652             readOffset += sizeof(question);
653 
654             VerifyOrExit(Header::kResponseServerFailure != ResolveQuestionBySrp(name, question, aResponseHeader,
655                                                                                 aResponseMessage, aCompressInfo,
656                                                                                 /* aAdditional */ true),
657                          response = Header::kResponseServerFailure);
658 
659             otLogInfoDns("[server] ADDITIONAL: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d",
660                          aResponseHeader.GetMessageId(), name, question.GetClass(), question.GetType(), response);
661         }
662     }
663 exit:
664     return response;
665 }
666 
ResolveQuestionBySrp(const char * aName,const Question & aQuestion,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,bool aAdditional)667 Header::Response Server::ResolveQuestionBySrp(const char *      aName,
668                                               const Question &  aQuestion,
669                                               Header &          aResponseHeader,
670                                               Message &         aResponseMessage,
671                                               NameCompressInfo &aCompressInfo,
672                                               bool              aAdditional)
673 {
674     Error                    error    = kErrorNone;
675     const Srp::Server::Host *host     = nullptr;
676     TimeMilli                now      = TimerMilli::GetNow();
677     uint16_t                 qtype    = aQuestion.GetType();
678     Header::Response         response = Header::kResponseNameError;
679 
680     while ((host = GetNextSrpHost(host)) != nullptr)
681     {
682         bool        needAdditionalAaaaRecord = false;
683         const char *hostName                 = host->GetFullName();
684 
685         // Handle PTR/SRV/TXT query
686         if (qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv || qtype == ResourceRecord::kTypeTxt)
687         {
688             const Srp::Server::Service *service = nullptr;
689 
690             while ((service = GetNextSrpService(*host, service)) != nullptr)
691             {
692                 uint32_t    instanceTtl         = TimeMilli::MsecToSec(service->GetExpireTime() - TimerMilli::GetNow());
693                 const char *instanceName        = service->GetInstanceName();
694                 bool        serviceNameMatched  = service->MatchesServiceName(aName);
695                 bool        instanceNameMatched = service->MatchesInstanceName(aName);
696                 bool        ptrQueryMatched     = qtype == ResourceRecord::kTypePtr && serviceNameMatched;
697                 bool        srvQueryMatched     = qtype == ResourceRecord::kTypeSrv && instanceNameMatched;
698                 bool        txtQueryMatched     = qtype == ResourceRecord::kTypeTxt && instanceNameMatched;
699 
700                 if (ptrQueryMatched || srvQueryMatched)
701                 {
702                     needAdditionalAaaaRecord = true;
703                 }
704 
705                 if (!aAdditional && ptrQueryMatched)
706                 {
707                     SuccessOrExit(
708                         error = AppendPtrRecord(aResponseMessage, aName, instanceName, instanceTtl, aCompressInfo));
709                     IncResourceRecordCount(aResponseHeader, aAdditional);
710                     response = Header::kResponseSuccess;
711                 }
712 
713                 if ((!aAdditional && srvQueryMatched) ||
714                     (aAdditional && ptrQueryMatched &&
715                      !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeSrv)))
716                 {
717                     SuccessOrExit(error = AppendSrvRecord(aResponseMessage, instanceName, hostName, instanceTtl,
718                                                           service->GetPriority(), service->GetWeight(),
719                                                           service->GetPort(), aCompressInfo));
720                     IncResourceRecordCount(aResponseHeader, aAdditional);
721                     response = Header::kResponseSuccess;
722                 }
723 
724                 if ((!aAdditional && txtQueryMatched) ||
725                     (aAdditional && ptrQueryMatched &&
726                      !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeTxt)))
727                 {
728                     SuccessOrExit(error = AppendTxtRecord(aResponseMessage, instanceName, service->GetTxtData(),
729                                                           service->GetTxtDataLength(), instanceTtl, aCompressInfo));
730                     IncResourceRecordCount(aResponseHeader, aAdditional);
731                     response = Header::kResponseSuccess;
732                 }
733             }
734         }
735 
736         // Handle AAAA query
737         if ((!aAdditional && qtype == ResourceRecord::kTypeAaaa && host->Matches(aName)) ||
738             (aAdditional && needAdditionalAaaaRecord &&
739              !HasQuestion(aResponseHeader, aResponseMessage, hostName, ResourceRecord::kTypeAaaa)))
740         {
741             uint8_t             addrNum;
742             const Ip6::Address *addrs   = host->GetAddresses(addrNum);
743             uint32_t            hostTtl = TimeMilli::MsecToSec(host->GetExpireTime() - now);
744 
745             for (uint8_t i = 0; i < addrNum; i++)
746             {
747                 SuccessOrExit(error = AppendAaaaRecord(aResponseMessage, hostName, addrs[i], hostTtl, aCompressInfo));
748                 IncResourceRecordCount(aResponseHeader, aAdditional);
749             }
750 
751             response = Header::kResponseSuccess;
752         }
753     }
754 
755 exit:
756     return error == kErrorNone ? response : Header::kResponseServerFailure;
757 }
758 
/**
 * Returns the SRP host that follows @p aHost in the SRP server's host list, skipping every host
 * that is marked as deleted.
 *
 * @param[in] aHost  The host to continue after (passed straight to `Srp::Server::GetNextHost()`).
 *
 * @returns The next non-deleted host, or `nullptr` when no such host remains.
 */
const Srp::Server::Host *Server::GetNextSrpHost(const Srp::Server::Host *aHost)
{
    Srp::Server &            srpServer = Get<Srp::Server>();
    const Srp::Server::Host *candidate = aHost;

    do
    {
        candidate = srpServer.GetNextHost(candidate);
    } while ((candidate != nullptr) && candidate->IsDeleted());

    return candidate;
}
770 
/**
 * Returns the service of @p aHost that follows @p aService, filtered with
 * `Srp::Server::kFlagsAnyTypeActiveService` (by its name: active services of any type).
 *
 * @param[in] aHost     The SRP host whose services are iterated.
 * @param[in] aService  The service to continue after.
 *
 * @returns The next matching service, or whatever `FindNextService()` returns when none remain.
 */
const Srp::Server::Service *Server::GetNextSrpService(const Srp::Server::Host &   aHost,
                                                      const Srp::Server::Service *aService)
{
    const Srp::Server::Service *next = aHost.FindNextService(aService, Srp::Server::kFlagsAnyTypeActiveService);

    return next;
}
776 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
777 
ResolveByQueryCallbacks(Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)778 Error Server::ResolveByQueryCallbacks(Header &                aResponseHeader,
779                                       Message &               aResponseMessage,
780                                       NameCompressInfo &      aCompressInfo,
781                                       const Ip6::MessageInfo &aMessageInfo)
782 {
783     QueryTransaction *query = nullptr;
784     DnsQueryType      queryType;
785     char              name[Name::kMaxNameSize];
786 
787     Error error = kErrorNone;
788 
789     VerifyOrExit(mQuerySubscribe != nullptr, error = kErrorFailed);
790 
791     queryType = GetQueryTypeAndName(aResponseHeader, aResponseMessage, name);
792     VerifyOrExit(queryType != kDnsQueryNone, error = kErrorNotImplemented);
793 
794     query = NewQuery(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
795     VerifyOrExit(query != nullptr, error = kErrorNoBufs);
796 
797     mQuerySubscribe(mQueryCallbackContext, name);
798 
799 exit:
800     return error;
801 }
802 
NewQuery(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)803 Server::QueryTransaction *Server::NewQuery(const Header &          aResponseHeader,
804                                            Message &               aResponseMessage,
805                                            const NameCompressInfo &aCompressInfo,
806                                            const Ip6::MessageInfo &aMessageInfo)
807 {
808     QueryTransaction *newQuery = nullptr;
809 
810     for (QueryTransaction &query : mQueryTransactions)
811     {
812         if (query.IsValid())
813         {
814             continue;
815         }
816 
817         query.Init(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
818         ExitNow(newQuery = &query);
819     }
820 
821 exit:
822     if (newQuery != nullptr)
823     {
824         ResetTimer();
825     }
826 
827     return newQuery;
828 }
829 
CanAnswerQuery(const QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)830 bool Server::CanAnswerQuery(const QueryTransaction &          aQuery,
831                             const char *                      aServiceFullName,
832                             const otDnssdServiceInstanceInfo &aInstanceInfo)
833 {
834     char         name[Name::kMaxNameSize];
835     DnsQueryType sdType;
836     bool         canAnswer = false;
837 
838     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
839 
840     switch (sdType)
841     {
842     case kDnsQueryBrowse:
843         canAnswer = (strcmp(name, aServiceFullName) == 0);
844         break;
845     case kDnsQueryResolve:
846         canAnswer = (strcmp(name, aInstanceInfo.mFullName) == 0);
847         break;
848     default:
849         break;
850     }
851 
852     return canAnswer;
853 }
854 
CanAnswerQuery(const Server::QueryTransaction & aQuery,const char * aHostFullName)855 bool Server::CanAnswerQuery(const Server::QueryTransaction &aQuery, const char *aHostFullName)
856 {
857     char         name[Name::kMaxNameSize];
858     DnsQueryType sdType;
859 
860     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
861     return (sdType == kDnsQueryResolveHost) && (strcmp(name, aHostFullName) == 0);
862 }
863 
AnswerQuery(QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)864 void Server::AnswerQuery(QueryTransaction &                aQuery,
865                          const char *                      aServiceFullName,
866                          const otDnssdServiceInstanceInfo &aInstanceInfo)
867 {
868     Header &          responseHeader  = aQuery.GetResponseHeader();
869     Message &         responseMessage = aQuery.GetResponseMessage();
870     Error             error           = kErrorNone;
871     NameCompressInfo &compressInfo    = aQuery.GetNameCompressInfo();
872 
873     if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aServiceFullName,
874                     ResourceRecord::kTypePtr))
875     {
876         SuccessOrExit(error = AppendPtrRecord(responseMessage, aServiceFullName, aInstanceInfo.mFullName,
877                                               aInstanceInfo.mTtl, compressInfo));
878         IncResourceRecordCount(responseHeader, false);
879     }
880 
881     for (uint8_t additional = 0; additional <= 1; additional++)
882     {
883         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
884                         ResourceRecord::kTypeSrv) == !additional)
885         {
886             SuccessOrExit(error = AppendSrvRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mHostName,
887                                                   aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
888                                                   aInstanceInfo.mPort, compressInfo));
889             IncResourceRecordCount(responseHeader, additional);
890         }
891 
892         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
893                         ResourceRecord::kTypeTxt) == !additional)
894         {
895             SuccessOrExit(error = AppendTxtRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mTxtData,
896                                                   aInstanceInfo.mTxtLength, aInstanceInfo.mTtl, compressInfo));
897             IncResourceRecordCount(responseHeader, additional);
898         }
899 
900         if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mHostName,
901                         ResourceRecord::kTypeAaaa) == !additional)
902         {
903             for (uint8_t i = 0; i < aInstanceInfo.mAddressNum; i++)
904             {
905                 const Ip6::Address &address = static_cast<const Ip6::Address &>(aInstanceInfo.mAddresses[i]);
906 
907                 OT_ASSERT(!address.IsUnspecified() && !address.IsLinkLocal() && !address.IsMulticast() &&
908                           !address.IsLoopback());
909 
910                 SuccessOrExit(error = AppendAaaaRecord(responseMessage, aInstanceInfo.mHostName, address,
911                                                        aInstanceInfo.mTtl, compressInfo));
912                 IncResourceRecordCount(responseHeader, additional);
913             }
914         }
915     }
916 
917 exit:
918     FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
919     ResetTimer();
920 }
921 
AnswerQuery(QueryTransaction & aQuery,const char * aHostFullName,const otDnssdHostInfo & aHostInfo)922 void Server::AnswerQuery(QueryTransaction &aQuery, const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
923 {
924     Header &          responseHeader  = aQuery.GetResponseHeader();
925     Message &         responseMessage = aQuery.GetResponseMessage();
926     Error             error           = kErrorNone;
927     NameCompressInfo &compressInfo    = aQuery.GetNameCompressInfo();
928 
929     if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aHostFullName, ResourceRecord::kTypeAaaa))
930     {
931         for (uint8_t i = 0; i < aHostInfo.mAddressNum; i++)
932         {
933             const Ip6::Address &address = static_cast<const Ip6::Address &>(aHostInfo.mAddresses[i]);
934 
935             OT_ASSERT(!address.IsUnspecified() && !address.IsMulticast() && !address.IsLinkLocal() &&
936                       !address.IsLoopback());
937 
938             SuccessOrExit(error =
939                               AppendAaaaRecord(responseMessage, aHostFullName, address, aHostInfo.mTtl, compressInfo));
940             IncResourceRecordCount(responseHeader, /* aAdditional */ false);
941         }
942     }
943 
944 exit:
945     FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
946     ResetTimer();
947 }
948 
SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,otDnssdQueryUnsubscribeCallback aUnsubscribe,void * aContext)949 void Server::SetQueryCallbacks(otDnssdQuerySubscribeCallback   aSubscribe,
950                                otDnssdQueryUnsubscribeCallback aUnsubscribe,
951                                void *                          aContext)
952 {
953     OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
954 
955     mQuerySubscribe       = aSubscribe;
956     mQueryUnsubscribe     = aUnsubscribe;
957     mQueryCallbackContext = aContext;
958 }
959 
HandleDiscoveredServiceInstance(const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)960 void Server::HandleDiscoveredServiceInstance(const char *                      aServiceFullName,
961                                              const otDnssdServiceInstanceInfo &aInstanceInfo)
962 {
963     OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeperatorChar));
964     OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeperatorChar));
965     OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeperatorChar));
966 
967     for (QueryTransaction &query : mQueryTransactions)
968     {
969         if (query.IsValid() && CanAnswerQuery(query, aServiceFullName, aInstanceInfo))
970         {
971             AnswerQuery(query, aServiceFullName, aInstanceInfo);
972         }
973     }
974 }
975 
HandleDiscoveredHost(const char * aHostFullName,const otDnssdHostInfo & aHostInfo)976 void Server::HandleDiscoveredHost(const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
977 {
978     OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeperatorChar));
979 
980     for (QueryTransaction &query : mQueryTransactions)
981     {
982         if (query.IsValid() && CanAnswerQuery(query, aHostFullName))
983         {
984             AnswerQuery(query, aHostFullName, aHostInfo);
985         }
986     }
987 }
988 
/**
 * Iterates over the pending (valid) query transactions.
 *
 * @param[in] aQuery  The query returned by the previous call, or `nullptr` to start from the
 *                    first slot.
 *
 * @returns The next valid query transaction after @p aQuery, or `nullptr` if there is none.
 */
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    const QueryTransaction *cursor;
    const QueryTransaction *next = nullptr;

    // Resume right after the given transaction, or start from the first slot.
    if (aQuery == nullptr)
    {
        cursor = &mQueryTransactions[0];
    }
    else
    {
        cursor = static_cast<const QueryTransaction *>(aQuery) + 1;
    }

    for (; (next == nullptr) && (cursor < OT_ARRAY_END(mQueryTransactions)); cursor++)
    {
        if (cursor->IsValid())
        {
            next = cursor;
        }
    }

    return static_cast<const otDnssdQuery *>(next);
}
1011 
GetQueryTypeAndName(const otDnssdQuery * aQuery,char (& aName)[Name::kMaxNameSize])1012 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize])
1013 {
1014     const QueryTransaction *query = static_cast<const QueryTransaction *>(aQuery);
1015 
1016     OT_ASSERT(query->IsValid());
1017     return GetQueryTypeAndName(query->GetResponseHeader(), query->GetResponseMessage(), aName);
1018 }
1019 
GetQueryTypeAndName(const Header & aHeader,const Message & aMessage,char (& aName)[Name::kMaxNameSize])1020 Server::DnsQueryType Server::GetQueryTypeAndName(const Header & aHeader,
1021                                                  const Message &aMessage,
1022                                                  char (&aName)[Name::kMaxNameSize])
1023 {
1024     DnsQueryType sdType = kDnsQueryNone;
1025 
1026     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1027     {
1028         Question question;
1029 
1030         IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1031         IgnoreError(aMessage.Read(readOffset, question));
1032         readOffset += sizeof(question);
1033 
1034         switch (question.GetType())
1035         {
1036         case ResourceRecord::kTypePtr:
1037             ExitNow(sdType = kDnsQueryBrowse);
1038         case ResourceRecord::kTypeSrv:
1039         case ResourceRecord::kTypeTxt:
1040             ExitNow(sdType = kDnsQueryResolve);
1041         }
1042     }
1043 
1044     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1045     {
1046         Question question;
1047 
1048         IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1049         IgnoreError(aMessage.Read(readOffset, question));
1050         readOffset += sizeof(question);
1051 
1052         switch (question.GetType())
1053         {
1054         case ResourceRecord::kTypeAaaa:
1055         case ResourceRecord::kTypeA:
1056             ExitNow(sdType = kDnsQueryResolveHost);
1057         }
1058     }
1059 
1060 exit:
1061     return sdType;
1062 }
1063 
HasQuestion(const Header & aHeader,const Message & aMessage,const char * aName,uint16_t aQuestionType)1064 bool Server::HasQuestion(const Header &aHeader, const Message &aMessage, const char *aName, uint16_t aQuestionType)
1065 {
1066     bool found = false;
1067 
1068     for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1069     {
1070         Question question;
1071         Error    error;
1072 
1073         error = Name::CompareName(aMessage, readOffset, aName);
1074         IgnoreError(aMessage.Read(readOffset, question));
1075         readOffset += sizeof(question);
1076 
1077         if (error == kErrorNone && aQuestionType == question.GetType())
1078         {
1079             ExitNow(found = true);
1080         }
1081     }
1082 
1083 exit:
1084     return found;
1085 }
1086 
HandleTimer(Timer & aTimer)1087 void Server::HandleTimer(Timer &aTimer)
1088 {
1089     aTimer.Get<Server>().HandleTimer();
1090 }
1091 
HandleTimer(void)1092 void Server::HandleTimer(void)
1093 {
1094     TimeMilli now = TimerMilli::GetNow();
1095 
1096     for (QueryTransaction &query : mQueryTransactions)
1097     {
1098         TimeMilli expire;
1099 
1100         if (!query.IsValid())
1101         {
1102             continue;
1103         }
1104 
1105         expire = query.GetStartTime() + kQueryTimeout;
1106         if (expire <= now)
1107         {
1108             FinalizeQuery(query, Header::kResponseSuccess);
1109         }
1110     }
1111 
1112     ResetTimer();
1113 }
1114 
ResetTimer(void)1115 void Server::ResetTimer(void)
1116 {
1117     TimeMilli now        = TimerMilli::GetNow();
1118     TimeMilli nextExpire = now.GetDistantFuture();
1119 
1120     for (QueryTransaction &query : mQueryTransactions)
1121     {
1122         TimeMilli expire;
1123 
1124         if (!query.IsValid())
1125         {
1126             continue;
1127         }
1128 
1129         expire = query.GetStartTime() + kQueryTimeout;
1130         if (expire <= now)
1131         {
1132             nextExpire = now;
1133         }
1134         else if (expire < nextExpire)
1135         {
1136             nextExpire = expire;
1137         }
1138     }
1139 
1140     if (nextExpire < now.GetDistantFuture())
1141     {
1142         mTimer.FireAt(nextExpire);
1143     }
1144     else
1145     {
1146         mTimer.Stop();
1147     }
1148 }
1149 
FinalizeQuery(QueryTransaction & aQuery,Header::Response aResponseCode)1150 void Server::FinalizeQuery(QueryTransaction &aQuery, Header::Response aResponseCode)
1151 {
1152     char         name[Name::kMaxNameSize];
1153     DnsQueryType sdType;
1154 
1155     OT_ASSERT(mQueryUnsubscribe != nullptr);
1156 
1157     sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
1158 
1159     OT_ASSERT(sdType != kDnsQueryNone);
1160     OT_UNUSED_VARIABLE(sdType);
1161 
1162     mQueryUnsubscribe(mQueryCallbackContext, name);
1163     aQuery.Finalize(aResponseCode, mSocket);
1164 }
1165 
Init(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)1166 void Server::QueryTransaction::Init(const Header &          aResponseHeader,
1167                                     Message &               aResponseMessage,
1168                                     const NameCompressInfo &aCompressInfo,
1169                                     const Ip6::MessageInfo &aMessageInfo)
1170 {
1171     OT_ASSERT(mResponseMessage == nullptr);
1172 
1173     mResponseHeader  = aResponseHeader;
1174     mResponseMessage = &aResponseMessage;
1175     mCompressInfo    = aCompressInfo;
1176     mMessageInfo     = aMessageInfo;
1177     mStartTime       = TimerMilli::GetNow();
1178 }
1179 
Finalize(Header::Response aResponseMessage,Ip6::Udp::Socket & aSocket)1180 void Server::QueryTransaction::Finalize(Header::Response aResponseMessage, Ip6::Udp::Socket &aSocket)
1181 {
1182     OT_ASSERT(mResponseMessage != nullptr);
1183 
1184     SendResponse(mResponseHeader, aResponseMessage, *mResponseMessage, mMessageInfo, aSocket);
1185     mResponseMessage = nullptr;
1186 }
1187 
1188 } // namespace ServiceDiscovery
1189 } // namespace Dns
1190 } // namespace ot
1191 
#endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
1193