1 /*
2 * Copyright (c) 2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements the DNS-SD server.
32 */
33
34 #include "dnssd_server.hpp"
35
36 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
37
38 #include "common/code_utils.hpp"
39 #include "common/debug.hpp"
40 #include "common/instance.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/logging.hpp"
43 #include "net/srp_server.hpp"
44 #include "net/udp6.hpp"
45
46 namespace ot {
47 namespace Dns {
48 namespace ServiceDiscovery {
49
// The "_udp" protocol label. The array holds exactly four characters and no
// null terminator, so it is compared with `memcmp()` rather than `strcmp()`.
const char Server::kDnssdProtocolUdp[4] = {'_', 'u', 'd', 'p'};
// The "_tcp" protocol label (also four characters, not null-terminated).
const char Server::kDnssdProtocolTcp[4] = {'_', 't', 'c', 'p'};
// Separator found between a sub-type label and the base service name.
const char Server::kDnssdSubTypeLabel[] = "._sub.";
// Domain name handed to `NameCompressInfo` when a query is processed.
const char Server::kDefaultDomainName[] = "default.service.arpa.";
54
Server(Instance & aInstance)55 Server::Server(Instance &aInstance)
56 : InstanceLocator(aInstance)
57 , mSocket(aInstance)
58 , mQueryCallbackContext(nullptr)
59 , mQuerySubscribe(nullptr)
60 , mQueryUnsubscribe(nullptr)
61 , mTimer(aInstance, Server::HandleTimer)
62 {
63 }
64
Start(void)65 Error Server::Start(void)
66 {
67 Error error = kErrorNone;
68
69 VerifyOrExit(!IsRunning());
70
71 SuccessOrExit(error = mSocket.Open(&Server::HandleUdpReceive, this));
72 SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? OT_NETIF_UNSPECIFIED : OT_NETIF_THREAD));
73
74 exit:
75 otLogInfoDns("[server] started: %s", ErrorToString(error));
76
77 if (error != kErrorNone)
78 {
79 IgnoreError(mSocket.Close());
80 }
81
82 return error;
83 }
84
Stop(void)85 void Server::Stop(void)
86 {
87 // Abort all query transactions
88 for (QueryTransaction &query : mQueryTransactions)
89 {
90 if (query.IsValid())
91 {
92 FinalizeQuery(query, Header::kResponseServerFailure);
93 }
94 }
95
96 mTimer.Stop();
97
98 IgnoreError(mSocket.Close());
99 otLogInfoDns("[server] stopped");
100 }
101
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)102 void Server::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
103 {
104 static_cast<Server *>(aContext)->HandleUdpReceive(*static_cast<Message *>(aMessage),
105 *static_cast<const Ip6::MessageInfo *>(aMessageInfo));
106 }
107
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)108 void Server::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
109 {
110 Error error = kErrorNone;
111 Header requestHeader;
112
113 SuccessOrExit(error = aMessage.Read(aMessage.GetOffset(), requestHeader));
114 VerifyOrExit(requestHeader.GetType() == Header::kTypeQuery, error = kErrorDrop);
115
116 ProcessQuery(requestHeader, aMessage, aMessageInfo);
117 exit:
118 return;
119 }
120
ProcessQuery(const Header & aRequestHeader,Message & aRequestMessage,const Ip6::MessageInfo & aMessageInfo)121 void Server::ProcessQuery(const Header &aRequestHeader, Message &aRequestMessage, const Ip6::MessageInfo &aMessageInfo)
122 {
123 Error error = kErrorNone;
124 Message * responseMessage = nullptr;
125 Header responseHeader;
126 NameCompressInfo compressInfo(kDefaultDomainName);
127 Header::Response response = Header::kResponseSuccess;
128 bool resolveByQueryCallbacks = false;
129
130 responseMessage = mSocket.NewMessage(0);
131 VerifyOrExit(responseMessage != nullptr, error = kErrorNoBufs);
132
133 // Allocate space for DNS header
134 SuccessOrExit(error = responseMessage->SetLength(sizeof(Header)));
135
136 // Setup initial DNS response header
137 responseHeader.Clear();
138 responseHeader.SetType(Header::kTypeResponse);
139 responseHeader.SetMessageId(aRequestHeader.GetMessageId());
140
141 // Validate the query
142 VerifyOrExit(aRequestHeader.GetQueryType() == Header::kQueryTypeStandard,
143 response = Header::kResponseNotImplemented);
144 VerifyOrExit(!aRequestHeader.IsTruncationFlagSet(), response = Header::kResponseFormatError);
145 VerifyOrExit(aRequestHeader.GetQuestionCount() > 0, response = Header::kResponseFormatError);
146
147 response = AddQuestions(aRequestHeader, aRequestMessage, responseHeader, *responseMessage, compressInfo);
148 VerifyOrExit(response == Header::kResponseSuccess);
149
150 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
151 // Answer the questions
152 response = ResolveBySrp(responseHeader, *responseMessage, compressInfo);
153 #endif
154
155 // Resolve the question using query callbacks if SRP server failed to resolve the questions.
156 if (responseHeader.GetAnswerCount() == 0 &&
157 kErrorNone == ResolveByQueryCallbacks(responseHeader, *responseMessage, compressInfo, aMessageInfo))
158 {
159 resolveByQueryCallbacks = true;
160 }
161
162 exit:
163 if (error == kErrorNone && !resolveByQueryCallbacks)
164 {
165 SendResponse(responseHeader, response, *responseMessage, aMessageInfo, mSocket);
166 }
167
168 FreeMessageOnError(responseMessage, error);
169 }
170
SendResponse(Header aHeader,Header::Response aResponseCode,Message & aMessage,const Ip6::MessageInfo & aMessageInfo,Ip6::Udp::Socket & aSocket)171 void Server::SendResponse(Header aHeader,
172 Header::Response aResponseCode,
173 Message & aMessage,
174 const Ip6::MessageInfo &aMessageInfo,
175 Ip6::Udp::Socket & aSocket)
176 {
177 Error error;
178
179 if (aResponseCode == Header::kResponseServerFailure)
180 {
181 otLogWarnDns("[server] failed to handle DNS query due to server failure");
182 aHeader.SetQuestionCount(0);
183 aHeader.SetAnswerCount(0);
184 aHeader.SetAdditionalRecordCount(0);
185 IgnoreError(aMessage.SetLength(sizeof(Header)));
186 }
187
188 aHeader.SetResponseCode(aResponseCode);
189 aMessage.Write(0, aHeader);
190
191 error = aSocket.SendTo(aMessage, aMessageInfo);
192
193 FreeMessageOnError(&aMessage, error);
194
195 if (error != kErrorNone)
196 {
197 otLogWarnDns("[server] failed to send DNS-SD reply: %s", otThreadErrorToString(error));
198 }
199 else
200 {
201 otLogInfoDns("[server] send DNS-SD reply: %s, RCODE=%d", otThreadErrorToString(error), aResponseCode);
202 }
203 }
204
// Validates every question of the request and copies it into the response.
//
// Questions are read starting right after the request's DNS header. Returns:
//  - `kResponseFormatError` when a name or question cannot be read,
//  - `kResponseNotImplemented` for a type other than PTR, SRV, TXT or AAAA,
//  - `kResponseNameError` when the name is not under the configured domain or
//    does not have the shape required by the question type,
//  - `kResponseServerFailure` when appending to the response fails.
// The response header's question count is set only when all questions pass.
Header::Response Server::AddQuestions(const Header &    aRequestHeader,
                                      const Message &   aRequestMessage,
                                      Header &          aResponseHeader,
                                      Message &         aResponseMessage,
                                      NameCompressInfo &aCompressInfo)
{
    const uint16_t   questionCount = aRequestHeader.GetQuestionCount();
    Header::Response rcode         = Header::kResponseSuccess;
    uint16_t         offset        = sizeof(Header);
    Question         question;
    char             name[Name::kMaxNameSize];

    for (uint16_t index = 0; index < questionCount; index++)
    {
        NameComponentsOffsetInfo components;
        uint16_t                 qtype;
        bool                     nameShapeOk;

        if (Name::ReadName(aRequestMessage, offset, name, sizeof(name)) != kErrorNone)
        {
            ExitNow(rcode = Header::kResponseFormatError);
        }

        if (aRequestMessage.Read(offset, question) != kErrorNone)
        {
            ExitNow(rcode = Header::kResponseFormatError);
        }

        offset += sizeof(question);
        qtype = question.GetType();

        // The type check comes before any name analysis.
        switch (qtype)
        {
        case ResourceRecord::kTypePtr:
        case ResourceRecord::kTypeSrv:
        case ResourceRecord::kTypeTxt:
        case ResourceRecord::kTypeAaaa:
            break;
        default:
            ExitNow(rcode = Header::kResponseNotImplemented);
        }

        if (FindNameComponents(name, aCompressInfo.GetDomainName(), components) != kErrorNone)
        {
            ExitNow(rcode = Header::kResponseNameError);
        }

        // PTR asks about a service name, SRV and TXT about a service
        // instance name, and AAAA about a host name.
        if (qtype == ResourceRecord::kTypePtr)
        {
            nameShapeOk = components.IsServiceName();
        }
        else if (qtype == ResourceRecord::kTypeAaaa)
        {
            nameShapeOk = components.IsHostName();
        }
        else
        {
            nameShapeOk = components.IsServiceInstanceName();
        }

        VerifyOrExit(nameShapeOk, rcode = Header::kResponseNameError);

        if (AppendQuestion(name, question, aResponseMessage, aCompressInfo) != kErrorNone)
        {
            ExitNow(rcode = Header::kResponseServerFailure);
        }
    }

    aResponseHeader.SetQuestionCount(questionCount);

exit:
    return rcode;
}
265
AppendQuestion(const char * aName,const Question & aQuestion,Message & aMessage,NameCompressInfo & aCompressInfo)266 Error Server::AppendQuestion(const char * aName,
267 const Question & aQuestion,
268 Message & aMessage,
269 NameCompressInfo &aCompressInfo)
270 {
271 Error error = kErrorNone;
272
273 switch (aQuestion.GetType())
274 {
275 case ResourceRecord::kTypePtr:
276 SuccessOrExit(error = AppendServiceName(aMessage, aName, aCompressInfo));
277 break;
278 case ResourceRecord::kTypeSrv:
279 case ResourceRecord::kTypeTxt:
280 SuccessOrExit(error = AppendInstanceName(aMessage, aName, aCompressInfo));
281 break;
282 case ResourceRecord::kTypeAaaa:
283 SuccessOrExit(error = AppendHostName(aMessage, aName, aCompressInfo));
284 break;
285 default:
286 OT_ASSERT(false);
287 }
288
289 error = aMessage.Append(aQuestion);
290
291 exit:
292 return error;
293 }
294
AppendPtrRecord(Message & aMessage,const char * aServiceName,const char * aInstanceName,uint32_t aTtl,NameCompressInfo & aCompressInfo)295 Error Server::AppendPtrRecord(Message & aMessage,
296 const char * aServiceName,
297 const char * aInstanceName,
298 uint32_t aTtl,
299 NameCompressInfo &aCompressInfo)
300 {
301 Error error;
302 PtrRecord ptrRecord;
303 uint16_t recordOffset;
304
305 ptrRecord.Init();
306 ptrRecord.SetTtl(aTtl);
307
308 SuccessOrExit(error = AppendServiceName(aMessage, aServiceName, aCompressInfo));
309
310 recordOffset = aMessage.GetLength();
311 SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(ptrRecord)));
312
313 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
314
315 ptrRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
316 aMessage.Write(recordOffset, ptrRecord);
317
318 exit:
319 return error;
320 }
321
AppendSrvRecord(Message & aMessage,const char * aInstanceName,const char * aHostName,uint32_t aTtl,uint16_t aPriority,uint16_t aWeight,uint16_t aPort,NameCompressInfo & aCompressInfo)322 Error Server::AppendSrvRecord(Message & aMessage,
323 const char * aInstanceName,
324 const char * aHostName,
325 uint32_t aTtl,
326 uint16_t aPriority,
327 uint16_t aWeight,
328 uint16_t aPort,
329 NameCompressInfo &aCompressInfo)
330 {
331 SrvRecord srvRecord;
332 Error error = kErrorNone;
333 uint16_t recordOffset;
334
335 srvRecord.Init();
336 srvRecord.SetTtl(aTtl);
337 srvRecord.SetPriority(aPriority);
338 srvRecord.SetWeight(aWeight);
339 srvRecord.SetPort(aPort);
340
341 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
342
343 recordOffset = aMessage.GetLength();
344 SuccessOrExit(error = aMessage.SetLength(recordOffset + sizeof(srvRecord)));
345
346 SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
347
348 srvRecord.SetLength(aMessage.GetLength() - (recordOffset + sizeof(ResourceRecord)));
349 aMessage.Write(recordOffset, srvRecord);
350
351 exit:
352 return error;
353 }
354
AppendAaaaRecord(Message & aMessage,const char * aHostName,const Ip6::Address & aAddress,uint32_t aTtl,NameCompressInfo & aCompressInfo)355 Error Server::AppendAaaaRecord(Message & aMessage,
356 const char * aHostName,
357 const Ip6::Address &aAddress,
358 uint32_t aTtl,
359 NameCompressInfo & aCompressInfo)
360 {
361 AaaaRecord aaaaRecord;
362 Error error;
363
364 aaaaRecord.Init();
365 aaaaRecord.SetTtl(aTtl);
366 aaaaRecord.SetAddress(aAddress);
367
368 SuccessOrExit(error = AppendHostName(aMessage, aHostName, aCompressInfo));
369 error = aMessage.Append(aaaaRecord);
370
371 exit:
372 return error;
373 }
374
AppendServiceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)375 Error Server::AppendServiceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
376 {
377 Error error;
378 uint16_t serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, aName);
379 const char *serviceName;
380
381 // Check whether `aName` is a sub-type service name.
382 serviceName = StringFind(aName, kDnssdSubTypeLabel);
383
384 if (serviceName != nullptr)
385 {
386 uint8_t subTypeLabelLength = static_cast<uint8_t>(serviceName - aName) + sizeof(kDnssdSubTypeLabel) - 1;
387
388 SuccessOrExit(error = Name::AppendMultipleLabels(aName, subTypeLabelLength, aMessage));
389
390 // Skip over the "._sub." label to get to the root service name.
391 serviceName += sizeof(kDnssdSubTypeLabel) - 1;
392 }
393 else
394 {
395 serviceName = aName;
396 }
397
398 if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
399 {
400 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
401 }
402 else
403 {
404 uint8_t domainStart = static_cast<uint8_t>(StringLength(serviceName, Name::kMaxNameSize - 1) -
405 StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
406 uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
407
408 serviceCompressOffset = aMessage.GetLength();
409 aCompressInfo.SetServiceNameOffset(serviceCompressOffset);
410
411 if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
412 {
413 aCompressInfo.SetDomainNameOffset(serviceCompressOffset + domainStart);
414 error = Name::AppendName(serviceName, aMessage);
415 }
416 else
417 {
418 SuccessOrExit(error = Name::AppendMultipleLabels(serviceName, domainStart, aMessage));
419 error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
420 }
421 }
422
423 exit:
424 return error;
425 }
426
AppendInstanceName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)427 Error Server::AppendInstanceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
428 {
429 Error error;
430 uint16_t instanceCompressOffset = aCompressInfo.GetInstanceNameOffset(aMessage, aName);
431
432 if (instanceCompressOffset != NameCompressInfo::kUnknownOffset)
433 {
434 error = Name::AppendPointerLabel(instanceCompressOffset, aMessage);
435 }
436 else
437 {
438 NameComponentsOffsetInfo nameComponentsInfo;
439
440 IgnoreError(FindNameComponents(aName, aCompressInfo.GetDomainName(), nameComponentsInfo));
441 OT_ASSERT(nameComponentsInfo.IsServiceInstanceName());
442
443 aCompressInfo.SetInstanceNameOffset(aMessage.GetLength());
444
445 // Append the instance name as one label
446 SuccessOrExit(error = Name::AppendLabel(aName, nameComponentsInfo.mServiceOffset - 1, aMessage));
447
448 {
449 const char *serviceName = aName + nameComponentsInfo.mServiceOffset;
450 uint16_t serviceCompressOffset = aCompressInfo.GetServiceNameOffset(aMessage, serviceName);
451
452 if (serviceCompressOffset != NameCompressInfo::kUnknownOffset)
453 {
454 error = Name::AppendPointerLabel(serviceCompressOffset, aMessage);
455 }
456 else
457 {
458 aCompressInfo.SetServiceNameOffset(aMessage.GetLength());
459 error = Name::AppendName(serviceName, aMessage);
460 }
461 }
462 }
463
464 exit:
465 return error;
466 }
467
AppendTxtRecord(Message & aMessage,const char * aInstanceName,const void * aTxtData,uint16_t aTxtLength,uint32_t aTtl,NameCompressInfo & aCompressInfo)468 Error Server::AppendTxtRecord(Message & aMessage,
469 const char * aInstanceName,
470 const void * aTxtData,
471 uint16_t aTxtLength,
472 uint32_t aTtl,
473 NameCompressInfo &aCompressInfo)
474 {
475 Error error = kErrorNone;
476 TxtRecord txtRecord;
477
478 SuccessOrExit(error = AppendInstanceName(aMessage, aInstanceName, aCompressInfo));
479
480 txtRecord.Init();
481 txtRecord.SetTtl(aTtl);
482 txtRecord.SetLength(aTxtLength);
483
484 SuccessOrExit(error = aMessage.Append(txtRecord));
485 error = aMessage.AppendBytes(aTxtData, aTxtLength);
486
487 exit:
488 return error;
489 }
490
AppendHostName(Message & aMessage,const char * aName,NameCompressInfo & aCompressInfo)491 Error Server::AppendHostName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo)
492 {
493 Error error;
494 uint16_t hostCompressOffset = aCompressInfo.GetHostNameOffset(aMessage, aName);
495
496 if (hostCompressOffset != NameCompressInfo::kUnknownOffset)
497 {
498 error = Name::AppendPointerLabel(hostCompressOffset, aMessage);
499 }
500 else
501 {
502 uint8_t domainStart = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength) -
503 StringLength(aCompressInfo.GetDomainName(), Name::kMaxNameSize - 1));
504 uint16_t domainCompressOffset = aCompressInfo.GetDomainNameOffset();
505
506 hostCompressOffset = aMessage.GetLength();
507 aCompressInfo.SetHostNameOffset(hostCompressOffset);
508
509 if (domainCompressOffset == NameCompressInfo::kUnknownOffset)
510 {
511 aCompressInfo.SetDomainNameOffset(hostCompressOffset + domainStart);
512 error = Name::AppendName(aName, aMessage);
513 }
514 else
515 {
516 SuccessOrExit(error = Name::AppendMultipleLabels(aName, domainStart, aMessage));
517 error = Name::AppendPointerLabel(domainCompressOffset, aMessage);
518 }
519 }
520
521 exit:
522 return error;
523 }
524
IncResourceRecordCount(Header & aHeader,bool aAdditional)525 void Server::IncResourceRecordCount(Header &aHeader, bool aAdditional)
526 {
527 if (aAdditional)
528 {
529 aHeader.SetAdditionalRecordCount(aHeader.GetAdditionalRecordCount() + 1);
530 }
531 else
532 {
533 aHeader.SetAnswerCount(aHeader.GetAnswerCount() + 1);
534 }
535 }
536
FindNameComponents(const char * aName,const char * aDomain,NameComponentsOffsetInfo & aInfo)537 Error Server::FindNameComponents(const char *aName, const char *aDomain, NameComponentsOffsetInfo &aInfo)
538 {
539 uint8_t nameLen = static_cast<uint8_t>(StringLength(aName, Name::kMaxNameLength));
540 uint8_t domainLen = static_cast<uint8_t>(StringLength(aDomain, Name::kMaxNameLength));
541 Error error = kErrorNone;
542 uint8_t labelBegin, labelEnd;
543
544 VerifyOrExit(Name::IsSubDomainOf(aName, aDomain), error = kErrorInvalidArgs);
545
546 labelBegin = nameLen - domainLen;
547 aInfo.mDomainOffset = labelBegin;
548
549 while (true)
550 {
551 error = FindPreviousLabel(aName, labelBegin, labelEnd);
552
553 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
554
555 if (labelEnd == labelBegin + kProtocolLabelLength &&
556 (memcmp(&aName[labelBegin], kDnssdProtocolUdp, kProtocolLabelLength) == 0 ||
557 memcmp(&aName[labelBegin], kDnssdProtocolTcp, kProtocolLabelLength) == 0))
558 {
559 // <Protocol> label found
560 aInfo.mProtocolOffset = labelBegin;
561 break;
562 }
563 }
564
565 // Get service label <Service>
566 error = FindPreviousLabel(aName, labelBegin, labelEnd);
567 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
568
569 aInfo.mServiceOffset = labelBegin;
570
571 // Check for service subtype
572 error = FindPreviousLabel(aName, labelBegin, labelEnd);
573 VerifyOrExit(error == kErrorNone, error = (error == kErrorNotFound ? kErrorNone : error));
574
575 // Note that `kDnssdSubTypeLabel` is "._sub.". Here we get the
576 // label only so we want to compare it with "_sub".
577 if ((labelEnd == labelBegin + kSubTypeLabelLength) &&
578 (memcmp(&aName[labelBegin], kDnssdSubTypeLabel + 1, kSubTypeLabelLength) == 0))
579 {
580 SuccessOrExit(error = FindPreviousLabel(aName, labelBegin, labelEnd));
581 VerifyOrExit(labelBegin == 0, error = kErrorInvalidArgs);
582 aInfo.mSubTypeOffset = labelBegin;
583 ExitNow();
584 }
585
586 // Treat everything before <Service> as <Instance> label
587 aInfo.mInstanceOffset = 0;
588
589 exit:
590 return error;
591 }
592
FindPreviousLabel(const char * aName,uint8_t & aStart,uint8_t & aStop)593 Error Server::FindPreviousLabel(const char *aName, uint8_t &aStart, uint8_t &aStop)
594 {
595 // This method finds the previous label before the current label (whose start index is @p aStart), and updates @p
596 // aStart to the start index of the label and @p aStop to the index of the dot just after the label.
597 // @note The input value of @p aStop does not matter because it is only used to output.
598
599 Error error = kErrorNone;
600 uint8_t start = aStart;
601 uint8_t end;
602
603 VerifyOrExit(start > 0, error = kErrorNotFound);
604 VerifyOrExit(aName[--start] == Name::kLabelSeperatorChar, error = kErrorInvalidArgs);
605
606 end = start;
607 while (start > 0 && aName[start - 1] != Name::kLabelSeperatorChar)
608 {
609 start--;
610 }
611
612 VerifyOrExit(start < end, error = kErrorInvalidArgs);
613
614 aStart = start;
615 aStop = end;
616
617 exit:
618 return error;
619 }
620
621 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
ResolveBySrp(Header & aResponseHeader,Message & aResponseMessage,Server::NameCompressInfo & aCompressInfo)622 Header::Response Server::ResolveBySrp(Header & aResponseHeader,
623 Message & aResponseMessage,
624 Server::NameCompressInfo &aCompressInfo)
625 {
626 Question question;
627 uint16_t readOffset = sizeof(Header);
628 Header::Response response = Header::kResponseSuccess;
629 char name[Name::kMaxNameSize];
630
631 for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
632 {
633 IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
634 IgnoreError(aResponseMessage.Read(readOffset, question));
635 readOffset += sizeof(question);
636
637 response = ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo,
638 /* aAdditional */ false);
639
640 otLogInfoDns("[server] ANSWER: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d",
641 aResponseHeader.GetMessageId(), name, question.GetClass(), question.GetType(), response);
642 }
643
644 // Answer the questions with additional RRs if required
645 if (aResponseHeader.GetAnswerCount() > 0)
646 {
647 readOffset = sizeof(Header);
648 for (uint16_t i = 0; i < aResponseHeader.GetQuestionCount(); i++)
649 {
650 IgnoreError(Name::ReadName(aResponseMessage, readOffset, name, sizeof(name)));
651 IgnoreError(aResponseMessage.Read(readOffset, question));
652 readOffset += sizeof(question);
653
654 VerifyOrExit(Header::kResponseServerFailure != ResolveQuestionBySrp(name, question, aResponseHeader,
655 aResponseMessage, aCompressInfo,
656 /* aAdditional */ true),
657 response = Header::kResponseServerFailure);
658
659 otLogInfoDns("[server] ADDITIONAL: TRANSACTION=0x%04x, QUESTION=[%s %d %d], RCODE=%d",
660 aResponseHeader.GetMessageId(), name, question.GetClass(), question.GetType(), response);
661 }
662 }
663 exit:
664 return response;
665 }
666
ResolveQuestionBySrp(const char * aName,const Question & aQuestion,Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,bool aAdditional)667 Header::Response Server::ResolveQuestionBySrp(const char * aName,
668 const Question & aQuestion,
669 Header & aResponseHeader,
670 Message & aResponseMessage,
671 NameCompressInfo &aCompressInfo,
672 bool aAdditional)
673 {
674 Error error = kErrorNone;
675 const Srp::Server::Host *host = nullptr;
676 TimeMilli now = TimerMilli::GetNow();
677 uint16_t qtype = aQuestion.GetType();
678 Header::Response response = Header::kResponseNameError;
679
680 while ((host = GetNextSrpHost(host)) != nullptr)
681 {
682 bool needAdditionalAaaaRecord = false;
683 const char *hostName = host->GetFullName();
684
685 // Handle PTR/SRV/TXT query
686 if (qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv || qtype == ResourceRecord::kTypeTxt)
687 {
688 const Srp::Server::Service *service = nullptr;
689
690 while ((service = GetNextSrpService(*host, service)) != nullptr)
691 {
692 uint32_t instanceTtl = TimeMilli::MsecToSec(service->GetExpireTime() - TimerMilli::GetNow());
693 const char *instanceName = service->GetInstanceName();
694 bool serviceNameMatched = service->MatchesServiceName(aName);
695 bool instanceNameMatched = service->MatchesInstanceName(aName);
696 bool ptrQueryMatched = qtype == ResourceRecord::kTypePtr && serviceNameMatched;
697 bool srvQueryMatched = qtype == ResourceRecord::kTypeSrv && instanceNameMatched;
698 bool txtQueryMatched = qtype == ResourceRecord::kTypeTxt && instanceNameMatched;
699
700 if (ptrQueryMatched || srvQueryMatched)
701 {
702 needAdditionalAaaaRecord = true;
703 }
704
705 if (!aAdditional && ptrQueryMatched)
706 {
707 SuccessOrExit(
708 error = AppendPtrRecord(aResponseMessage, aName, instanceName, instanceTtl, aCompressInfo));
709 IncResourceRecordCount(aResponseHeader, aAdditional);
710 response = Header::kResponseSuccess;
711 }
712
713 if ((!aAdditional && srvQueryMatched) ||
714 (aAdditional && ptrQueryMatched &&
715 !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeSrv)))
716 {
717 SuccessOrExit(error = AppendSrvRecord(aResponseMessage, instanceName, hostName, instanceTtl,
718 service->GetPriority(), service->GetWeight(),
719 service->GetPort(), aCompressInfo));
720 IncResourceRecordCount(aResponseHeader, aAdditional);
721 response = Header::kResponseSuccess;
722 }
723
724 if ((!aAdditional && txtQueryMatched) ||
725 (aAdditional && ptrQueryMatched &&
726 !HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeTxt)))
727 {
728 SuccessOrExit(error = AppendTxtRecord(aResponseMessage, instanceName, service->GetTxtData(),
729 service->GetTxtDataLength(), instanceTtl, aCompressInfo));
730 IncResourceRecordCount(aResponseHeader, aAdditional);
731 response = Header::kResponseSuccess;
732 }
733 }
734 }
735
736 // Handle AAAA query
737 if ((!aAdditional && qtype == ResourceRecord::kTypeAaaa && host->Matches(aName)) ||
738 (aAdditional && needAdditionalAaaaRecord &&
739 !HasQuestion(aResponseHeader, aResponseMessage, hostName, ResourceRecord::kTypeAaaa)))
740 {
741 uint8_t addrNum;
742 const Ip6::Address *addrs = host->GetAddresses(addrNum);
743 uint32_t hostTtl = TimeMilli::MsecToSec(host->GetExpireTime() - now);
744
745 for (uint8_t i = 0; i < addrNum; i++)
746 {
747 SuccessOrExit(error = AppendAaaaRecord(aResponseMessage, hostName, addrs[i], hostTtl, aCompressInfo));
748 IncResourceRecordCount(aResponseHeader, aAdditional);
749 }
750
751 response = Header::kResponseSuccess;
752 }
753 }
754
755 exit:
756 return error == kErrorNone ? response : Header::kResponseServerFailure;
757 }
758
// Returns the next SRP host after `aHost` that is not marked as deleted, or
// `nullptr` when there is none. `aHost` is passed to the SRP server's
// `GetNextHost()` unchanged; callers in this file pass `nullptr` to begin.
const Srp::Server::Host *Server::GetNextSrpHost(const Srp::Server::Host *aHost)
{
    const Srp::Server::Host *candidate = aHost;

    do
    {
        candidate = Get<Srp::Server>().GetNextHost(candidate);
    } while ((candidate != nullptr) && candidate->IsDeleted());

    return candidate;
}
770
// Returns the service of `aHost` that follows `aService`, as selected by the
// `kFlagsAnyTypeActiveService` filter of `FindNextService()`.
const Srp::Server::Service *Server::GetNextSrpService(const Srp::Server::Host &   aHost,
                                                      const Srp::Server::Service *aService)
{
    const Srp::Server::Service *nextService =
        aHost.FindNextService(aService, Srp::Server::kFlagsAnyTypeActiveService);

    return nextService;
}
776 #endif // OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
777
ResolveByQueryCallbacks(Header & aResponseHeader,Message & aResponseMessage,NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)778 Error Server::ResolveByQueryCallbacks(Header & aResponseHeader,
779 Message & aResponseMessage,
780 NameCompressInfo & aCompressInfo,
781 const Ip6::MessageInfo &aMessageInfo)
782 {
783 QueryTransaction *query = nullptr;
784 DnsQueryType queryType;
785 char name[Name::kMaxNameSize];
786
787 Error error = kErrorNone;
788
789 VerifyOrExit(mQuerySubscribe != nullptr, error = kErrorFailed);
790
791 queryType = GetQueryTypeAndName(aResponseHeader, aResponseMessage, name);
792 VerifyOrExit(queryType != kDnsQueryNone, error = kErrorNotImplemented);
793
794 query = NewQuery(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
795 VerifyOrExit(query != nullptr, error = kErrorNoBufs);
796
797 mQuerySubscribe(mQueryCallbackContext, name);
798
799 exit:
800 return error;
801 }
802
NewQuery(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)803 Server::QueryTransaction *Server::NewQuery(const Header & aResponseHeader,
804 Message & aResponseMessage,
805 const NameCompressInfo &aCompressInfo,
806 const Ip6::MessageInfo &aMessageInfo)
807 {
808 QueryTransaction *newQuery = nullptr;
809
810 for (QueryTransaction &query : mQueryTransactions)
811 {
812 if (query.IsValid())
813 {
814 continue;
815 }
816
817 query.Init(aResponseHeader, aResponseMessage, aCompressInfo, aMessageInfo);
818 ExitNow(newQuery = &query);
819 }
820
821 exit:
822 if (newQuery != nullptr)
823 {
824 ResetTimer();
825 }
826
827 return newQuery;
828 }
829
CanAnswerQuery(const QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)830 bool Server::CanAnswerQuery(const QueryTransaction & aQuery,
831 const char * aServiceFullName,
832 const otDnssdServiceInstanceInfo &aInstanceInfo)
833 {
834 char name[Name::kMaxNameSize];
835 DnsQueryType sdType;
836 bool canAnswer = false;
837
838 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
839
840 switch (sdType)
841 {
842 case kDnsQueryBrowse:
843 canAnswer = (strcmp(name, aServiceFullName) == 0);
844 break;
845 case kDnsQueryResolve:
846 canAnswer = (strcmp(name, aInstanceInfo.mFullName) == 0);
847 break;
848 default:
849 break;
850 }
851
852 return canAnswer;
853 }
854
CanAnswerQuery(const Server::QueryTransaction & aQuery,const char * aHostFullName)855 bool Server::CanAnswerQuery(const Server::QueryTransaction &aQuery, const char *aHostFullName)
856 {
857 char name[Name::kMaxNameSize];
858 DnsQueryType sdType;
859
860 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
861 return (sdType == kDnsQueryResolveHost) && (strcmp(name, aHostFullName) == 0);
862 }
863
AnswerQuery(QueryTransaction & aQuery,const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)864 void Server::AnswerQuery(QueryTransaction & aQuery,
865 const char * aServiceFullName,
866 const otDnssdServiceInstanceInfo &aInstanceInfo)
867 {
868 Header & responseHeader = aQuery.GetResponseHeader();
869 Message & responseMessage = aQuery.GetResponseMessage();
870 Error error = kErrorNone;
871 NameCompressInfo &compressInfo = aQuery.GetNameCompressInfo();
872
873 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aServiceFullName,
874 ResourceRecord::kTypePtr))
875 {
876 SuccessOrExit(error = AppendPtrRecord(responseMessage, aServiceFullName, aInstanceInfo.mFullName,
877 aInstanceInfo.mTtl, compressInfo));
878 IncResourceRecordCount(responseHeader, false);
879 }
880
881 for (uint8_t additional = 0; additional <= 1; additional++)
882 {
883 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
884 ResourceRecord::kTypeSrv) == !additional)
885 {
886 SuccessOrExit(error = AppendSrvRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mHostName,
887 aInstanceInfo.mTtl, aInstanceInfo.mPriority, aInstanceInfo.mWeight,
888 aInstanceInfo.mPort, compressInfo));
889 IncResourceRecordCount(responseHeader, additional);
890 }
891
892 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mFullName,
893 ResourceRecord::kTypeTxt) == !additional)
894 {
895 SuccessOrExit(error = AppendTxtRecord(responseMessage, aInstanceInfo.mFullName, aInstanceInfo.mTxtData,
896 aInstanceInfo.mTxtLength, aInstanceInfo.mTtl, compressInfo));
897 IncResourceRecordCount(responseHeader, additional);
898 }
899
900 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aInstanceInfo.mHostName,
901 ResourceRecord::kTypeAaaa) == !additional)
902 {
903 for (uint8_t i = 0; i < aInstanceInfo.mAddressNum; i++)
904 {
905 const Ip6::Address &address = static_cast<const Ip6::Address &>(aInstanceInfo.mAddresses[i]);
906
907 OT_ASSERT(!address.IsUnspecified() && !address.IsLinkLocal() && !address.IsMulticast() &&
908 !address.IsLoopback());
909
910 SuccessOrExit(error = AppendAaaaRecord(responseMessage, aInstanceInfo.mHostName, address,
911 aInstanceInfo.mTtl, compressInfo));
912 IncResourceRecordCount(responseHeader, additional);
913 }
914 }
915 }
916
917 exit:
918 FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
919 ResetTimer();
920 }
921
AnswerQuery(QueryTransaction & aQuery,const char * aHostFullName,const otDnssdHostInfo & aHostInfo)922 void Server::AnswerQuery(QueryTransaction &aQuery, const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
923 {
924 Header & responseHeader = aQuery.GetResponseHeader();
925 Message & responseMessage = aQuery.GetResponseMessage();
926 Error error = kErrorNone;
927 NameCompressInfo &compressInfo = aQuery.GetNameCompressInfo();
928
929 if (HasQuestion(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), aHostFullName, ResourceRecord::kTypeAaaa))
930 {
931 for (uint8_t i = 0; i < aHostInfo.mAddressNum; i++)
932 {
933 const Ip6::Address &address = static_cast<const Ip6::Address &>(aHostInfo.mAddresses[i]);
934
935 OT_ASSERT(!address.IsUnspecified() && !address.IsMulticast() && !address.IsLinkLocal() &&
936 !address.IsLoopback());
937
938 SuccessOrExit(error =
939 AppendAaaaRecord(responseMessage, aHostFullName, address, aHostInfo.mTtl, compressInfo));
940 IncResourceRecordCount(responseHeader, /* aAdditional */ false);
941 }
942 }
943
944 exit:
945 FinalizeQuery(aQuery, error == kErrorNone ? Header::kResponseSuccess : Header::kResponseServerFailure);
946 ResetTimer();
947 }
948
SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,otDnssdQueryUnsubscribeCallback aUnsubscribe,void * aContext)949 void Server::SetQueryCallbacks(otDnssdQuerySubscribeCallback aSubscribe,
950 otDnssdQueryUnsubscribeCallback aUnsubscribe,
951 void * aContext)
952 {
953 OT_ASSERT((aSubscribe == nullptr) == (aUnsubscribe == nullptr));
954
955 mQuerySubscribe = aSubscribe;
956 mQueryUnsubscribe = aUnsubscribe;
957 mQueryCallbackContext = aContext;
958 }
959
HandleDiscoveredServiceInstance(const char * aServiceFullName,const otDnssdServiceInstanceInfo & aInstanceInfo)960 void Server::HandleDiscoveredServiceInstance(const char * aServiceFullName,
961 const otDnssdServiceInstanceInfo &aInstanceInfo)
962 {
963 OT_ASSERT(StringEndsWith(aServiceFullName, Name::kLabelSeperatorChar));
964 OT_ASSERT(StringEndsWith(aInstanceInfo.mFullName, Name::kLabelSeperatorChar));
965 OT_ASSERT(StringEndsWith(aInstanceInfo.mHostName, Name::kLabelSeperatorChar));
966
967 for (QueryTransaction &query : mQueryTransactions)
968 {
969 if (query.IsValid() && CanAnswerQuery(query, aServiceFullName, aInstanceInfo))
970 {
971 AnswerQuery(query, aServiceFullName, aInstanceInfo);
972 }
973 }
974 }
975
HandleDiscoveredHost(const char * aHostFullName,const otDnssdHostInfo & aHostInfo)976 void Server::HandleDiscoveredHost(const char *aHostFullName, const otDnssdHostInfo &aHostInfo)
977 {
978 OT_ASSERT(StringEndsWith(aHostFullName, Name::kLabelSeperatorChar));
979
980 for (QueryTransaction &query : mQueryTransactions)
981 {
982 if (query.IsValid() && CanAnswerQuery(query, aHostFullName))
983 {
984 AnswerQuery(query, aHostFullName, aHostInfo);
985 }
986 }
987 }
988
/**
 * Iterates over the valid query transactions.
 *
 * @param[in] aQuery  The query returned by the previous call, or `nullptr` to start from the first transaction.
 *
 * @returns The first valid transaction located after @p aQuery in `mQueryTransactions` (or the first valid one
 *          overall when @p aQuery is `nullptr`), or `nullptr` if there is none.
 */
const otDnssdQuery *Server::GetNextQuery(const otDnssdQuery *aQuery) const
{
    const QueryTransaction *const end = OT_ARRAY_END(mQueryTransactions);
    const QueryTransaction *      scan =
        (aQuery == nullptr) ? &mQueryTransactions[0] : static_cast<const QueryTransaction *>(aQuery) + 1;
    const QueryTransaction *next = nullptr;

    // Skip over unused slots.
    while ((scan < end) && !scan->IsValid())
    {
        scan++;
    }

    if (scan < end)
    {
        next = scan;
    }

    return static_cast<const otDnssdQuery *>(next);
}
1011
GetQueryTypeAndName(const otDnssdQuery * aQuery,char (& aName)[Name::kMaxNameSize])1012 Server::DnsQueryType Server::GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize])
1013 {
1014 const QueryTransaction *query = static_cast<const QueryTransaction *>(aQuery);
1015
1016 OT_ASSERT(query->IsValid());
1017 return GetQueryTypeAndName(query->GetResponseHeader(), query->GetResponseMessage(), aName);
1018 }
1019
GetQueryTypeAndName(const Header & aHeader,const Message & aMessage,char (& aName)[Name::kMaxNameSize])1020 Server::DnsQueryType Server::GetQueryTypeAndName(const Header & aHeader,
1021 const Message &aMessage,
1022 char (&aName)[Name::kMaxNameSize])
1023 {
1024 DnsQueryType sdType = kDnsQueryNone;
1025
1026 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1027 {
1028 Question question;
1029
1030 IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1031 IgnoreError(aMessage.Read(readOffset, question));
1032 readOffset += sizeof(question);
1033
1034 switch (question.GetType())
1035 {
1036 case ResourceRecord::kTypePtr:
1037 ExitNow(sdType = kDnsQueryBrowse);
1038 case ResourceRecord::kTypeSrv:
1039 case ResourceRecord::kTypeTxt:
1040 ExitNow(sdType = kDnsQueryResolve);
1041 }
1042 }
1043
1044 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1045 {
1046 Question question;
1047
1048 IgnoreError(Name::ReadName(aMessage, readOffset, aName, sizeof(aName)));
1049 IgnoreError(aMessage.Read(readOffset, question));
1050 readOffset += sizeof(question);
1051
1052 switch (question.GetType())
1053 {
1054 case ResourceRecord::kTypeAaaa:
1055 case ResourceRecord::kTypeA:
1056 ExitNow(sdType = kDnsQueryResolveHost);
1057 }
1058 }
1059
1060 exit:
1061 return sdType;
1062 }
1063
HasQuestion(const Header & aHeader,const Message & aMessage,const char * aName,uint16_t aQuestionType)1064 bool Server::HasQuestion(const Header &aHeader, const Message &aMessage, const char *aName, uint16_t aQuestionType)
1065 {
1066 bool found = false;
1067
1068 for (uint16_t i = 0, readOffset = sizeof(Header); i < aHeader.GetQuestionCount(); i++)
1069 {
1070 Question question;
1071 Error error;
1072
1073 error = Name::CompareName(aMessage, readOffset, aName);
1074 IgnoreError(aMessage.Read(readOffset, question));
1075 readOffset += sizeof(question);
1076
1077 if (error == kErrorNone && aQuestionType == question.GetType())
1078 {
1079 ExitNow(found = true);
1080 }
1081 }
1082
1083 exit:
1084 return found;
1085 }
1086
HandleTimer(Timer & aTimer)1087 void Server::HandleTimer(Timer &aTimer)
1088 {
1089 aTimer.Get<Server>().HandleTimer();
1090 }
1091
HandleTimer(void)1092 void Server::HandleTimer(void)
1093 {
1094 TimeMilli now = TimerMilli::GetNow();
1095
1096 for (QueryTransaction &query : mQueryTransactions)
1097 {
1098 TimeMilli expire;
1099
1100 if (!query.IsValid())
1101 {
1102 continue;
1103 }
1104
1105 expire = query.GetStartTime() + kQueryTimeout;
1106 if (expire <= now)
1107 {
1108 FinalizeQuery(query, Header::kResponseSuccess);
1109 }
1110 }
1111
1112 ResetTimer();
1113 }
1114
ResetTimer(void)1115 void Server::ResetTimer(void)
1116 {
1117 TimeMilli now = TimerMilli::GetNow();
1118 TimeMilli nextExpire = now.GetDistantFuture();
1119
1120 for (QueryTransaction &query : mQueryTransactions)
1121 {
1122 TimeMilli expire;
1123
1124 if (!query.IsValid())
1125 {
1126 continue;
1127 }
1128
1129 expire = query.GetStartTime() + kQueryTimeout;
1130 if (expire <= now)
1131 {
1132 nextExpire = now;
1133 }
1134 else if (expire < nextExpire)
1135 {
1136 nextExpire = expire;
1137 }
1138 }
1139
1140 if (nextExpire < now.GetDistantFuture())
1141 {
1142 mTimer.FireAt(nextExpire);
1143 }
1144 else
1145 {
1146 mTimer.Stop();
1147 }
1148 }
1149
FinalizeQuery(QueryTransaction & aQuery,Header::Response aResponseCode)1150 void Server::FinalizeQuery(QueryTransaction &aQuery, Header::Response aResponseCode)
1151 {
1152 char name[Name::kMaxNameSize];
1153 DnsQueryType sdType;
1154
1155 OT_ASSERT(mQueryUnsubscribe != nullptr);
1156
1157 sdType = GetQueryTypeAndName(aQuery.GetResponseHeader(), aQuery.GetResponseMessage(), name);
1158
1159 OT_ASSERT(sdType != kDnsQueryNone);
1160 OT_UNUSED_VARIABLE(sdType);
1161
1162 mQueryUnsubscribe(mQueryCallbackContext, name);
1163 aQuery.Finalize(aResponseCode, mSocket);
1164 }
1165
Init(const Header & aResponseHeader,Message & aResponseMessage,const NameCompressInfo & aCompressInfo,const Ip6::MessageInfo & aMessageInfo)1166 void Server::QueryTransaction::Init(const Header & aResponseHeader,
1167 Message & aResponseMessage,
1168 const NameCompressInfo &aCompressInfo,
1169 const Ip6::MessageInfo &aMessageInfo)
1170 {
1171 OT_ASSERT(mResponseMessage == nullptr);
1172
1173 mResponseHeader = aResponseHeader;
1174 mResponseMessage = &aResponseMessage;
1175 mCompressInfo = aCompressInfo;
1176 mMessageInfo = aMessageInfo;
1177 mStartTime = TimerMilli::GetNow();
1178 }
1179
Finalize(Header::Response aResponseMessage,Ip6::Udp::Socket & aSocket)1180 void Server::QueryTransaction::Finalize(Header::Response aResponseMessage, Ip6::Udp::Socket &aSocket)
1181 {
1182 OT_ASSERT(mResponseMessage != nullptr);
1183
1184 SendResponse(mResponseHeader, aResponseMessage, *mResponseMessage, mMessageInfo, aSocket);
1185 mResponseMessage = nullptr;
1186 }
1187
1188 } // namespace ServiceDiscovery
1189 } // namespace Dns
1190 } // namespace ot
1191
#endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
1193