1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/code_utils.hpp"
34 #include "common/debug.hpp"
35 #include "common/instance.hpp"
36 #include "common/locator_getters.hpp"
37 #include "common/logging.hpp"
38 #include "net/udp6.hpp"
39 #include "thread/network_data_types.hpp"
40 #include "thread/thread_netif.hpp"
41 
42 /**
43  * @file
44  *   This file implements the DNS client.
45  */
46 
47 namespace ot {
48 namespace Dns {
49 
50 //---------------------------------------------------------------------------------------------------------------------
51 // Client::QueryConfig
52 
// Default DNS server IPv6 address (in string form) taken from the build
// configuration. It is parsed by the `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
54 
QueryConfig(InitMode aMode)55 Client::QueryConfig::QueryConfig(InitMode aMode)
56 {
57     OT_UNUSED_VARIABLE(aMode);
58 
59     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
60     GetServerSockAddr().SetPort(kDefaultServerPort);
61     SetResponseTimeout(kDefaultResponseTimeout);
62     SetMaxTxAttempts(kDefaultMaxTxAttempts);
63     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
64 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
65     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
66 #endif
67 }
68 
SetFrom(const QueryConfig & aConfig,const QueryConfig & aDefaultConfig)69 void Client::QueryConfig::SetFrom(const QueryConfig &aConfig, const QueryConfig &aDefaultConfig)
70 {
71     // This method sets the config from `aConfig` replacing any
72     // unspecified fields (value zero) with the fields from
73     // `aDefaultConfig`.
74 
75     *this = aConfig;
76 
77     if (GetServerSockAddr().GetAddress().IsUnspecified())
78     {
79         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
80     }
81 
82     if (GetServerSockAddr().GetPort() == 0)
83     {
84         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
85     }
86 
87     if (GetResponseTimeout() == 0)
88     {
89         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
90     }
91 
92     if (GetMaxTxAttempts() == 0)
93     {
94         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
95     }
96 
97     if (GetRecursionFlag() == kFlagUnspecified)
98     {
99         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
100     }
101 
102 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
103     if (GetNat64Mode() == kNat64Unspecified)
104     {
105         SetNat64Mode(aDefaultConfig.GetNat64Mode());
106     }
107 #endif
108 }
109 
110 //---------------------------------------------------------------------------------------------------------------------
111 // Client::Response
112 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const113 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
114 {
115     switch (aSection)
116     {
117     case kAnswerSection:
118         aOffset    = mAnswerOffset;
119         aNumRecord = mAnswerRecordCount;
120         break;
121     case kAdditionalDataSection:
122     default:
123         aOffset    = mAdditionalOffset;
124         aNumRecord = mAdditionalRecordCount;
125         break;
126     }
127 }
128 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const129 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
130 {
131     uint16_t offset = kNameOffsetInQuery;
132 
133     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
134 }
135 
CheckForHostNameAlias(Section aSection,Name & aHostName) const136 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
137 {
138     // If the response includes a CNAME record mapping the query host
139     // name to a canonical name, we update `aHostName` to the new alias
140     // name. Otherwise `aHostName` remains as before.
141 
142     Error       error;
143     uint16_t    offset;
144     uint16_t    numRecords;
145     CnameRecord cnameRecord;
146 
147     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
148 
149     SelectSection(aSection, offset, numRecords);
150     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
151 
152     switch (error)
153     {
154     case kErrorNone:
155         // A CNAME record was found. `offset` now points to after the
156         // last read byte within the `mMessage` into the `cnameRecord`
157         // (which is the start of the new canonical name).
158         aHostName.SetFromMessage(*mMessage, offset);
159         error = Name::ParseName(*mMessage, offset);
160         break;
161 
162     case kErrorNotFound:
163         error = kErrorNone;
164         break;
165 
166     default:
167         break;
168     }
169 
170 exit:
171     return error;
172 }
173 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const174 Error Client::Response::FindHostAddress(Section       aSection,
175                                         const Name &  aHostName,
176                                         uint16_t      aIndex,
177                                         Ip6::Address &aAddress,
178                                         uint32_t &    aTtl) const
179 {
180     Error      error;
181     uint16_t   offset;
182     uint16_t   numRecords;
183     Name       name = aHostName;
184     AaaaRecord aaaaRecord;
185 
186     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
187 
188     SelectSection(aSection, offset, numRecords);
189     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
190     aAddress = aaaaRecord.GetAddress();
191     aTtl     = aaaaRecord.GetTtl();
192 
193 exit:
194     return error;
195 }
196 
197 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
198 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const199 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
200 {
201     Error    error;
202     uint16_t offset;
203     uint16_t numRecords;
204     Name     name = aHostName;
205 
206     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
207 
208     SelectSection(aSection, offset, numRecords);
209     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
210 
211 exit:
212     return error;
213 }
214 
215 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
216 
217 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
218 
FindServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const219 Error Client::Response::FindServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
220 {
221     // This method searches for SRV and TXT records in the given
222     // section matching the record name against `aName`, and updates
223     // the `aServiceInfo` accordingly. It also searches for AAAA
224     // record for host name associated with the service (from SRV
225     // record). The search for AAAA record is always performed in
226     // Additional Data section (independent of the value given in
227     // `aSection`).
228 
229     Error     error;
230     uint16_t  offset;
231     uint16_t  numRecords;
232     Name      hostName;
233     SrvRecord srvRecord;
234     TxtRecord txtRecord;
235 
236     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
237 
238     // Search for a matching SRV record
239     SelectSection(aSection, offset, numRecords);
240     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
241 
242     aServiceInfo.mTtl      = srvRecord.GetTtl();
243     aServiceInfo.mPort     = srvRecord.GetPort();
244     aServiceInfo.mPriority = srvRecord.GetPriority();
245     aServiceInfo.mWeight   = srvRecord.GetWeight();
246 
247     hostName.SetFromMessage(*mMessage, offset);
248 
249     if (aServiceInfo.mHostNameBuffer != nullptr)
250     {
251         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
252                                                            aServiceInfo.mHostNameBufferSize));
253     }
254     else
255     {
256         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
257     }
258 
259     // Search in additional section for AAAA record for the host name.
260 
261     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0,
262                             static_cast<Ip6::Address &>(aServiceInfo.mHostAddress), aServiceInfo.mHostAddressTtl);
263 
264     if (error == kErrorNotFound)
265     {
266         static_cast<Ip6::Address &>(aServiceInfo.mHostAddress).Clear();
267         aServiceInfo.mHostAddressTtl = 0;
268     }
269     else
270     {
271         SuccessOrExit(error);
272     }
273 
274     // A null `mTxtData` indicates that caller does not want to retrieve TXT data.
275     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
276 
277     // Search for a matching TXT record. If not found, indicate this by
278     // setting `aServiceInfo.mTxtDataSize` to zero.
279 
280     SelectSection(aSection, offset, numRecords);
281     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord);
282 
283     switch (error)
284     {
285     case kErrorNone:
286         SuccessOrExit(error =
287                           txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize));
288         aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
289         break;
290 
291     case kErrorNotFound:
292         aServiceInfo.mTxtDataSize = 0;
293         aServiceInfo.mTxtDataTtl  = 0;
294         break;
295 
296     default:
297         ExitNow();
298     }
299 
300 exit:
301     return error;
302 }
303 
304 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
305 
306 //---------------------------------------------------------------------------------------------------------------------
307 // Client::AddressResponse
308 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const309 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
310 {
311     Error error = kErrorNone;
312     Name  name(*mQuery, kNameOffsetInQuery);
313 
314 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
315 
316     // If the response is for an IPv4 address query or if it is an
317     // IPv6 address query response with no IPv6 address but with
318     // an IPv4 in its additional section, we read the IPv4 address
319     // and translate it to an IPv6 address.
320 
321     QueryInfo info;
322 
323     info.ReadFrom(*mQuery);
324 
325     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
326     {
327         Section     section;
328         ARecord     aRecord;
329         Ip6::Prefix nat64Prefix;
330 
331         SuccessOrExit(error = GetNat64Prefix(nat64Prefix));
332 
333         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
334         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
335 
336         aAddress.SynthesizeFromIp4Address(nat64Prefix, aRecord.GetAddress());
337         aTtl = aRecord.GetTtl();
338 
339         ExitNow();
340     }
341 
342 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
343 
344     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
345 
346 exit:
347     return error;
348 }
349 
350 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
351 
GetNat64Prefix(Ip6::Prefix & aPrefix) const352 Error Client::AddressResponse::GetNat64Prefix(Ip6::Prefix &aPrefix) const
353 {
354     Error                            error      = kErrorNotFound;
355     NetworkData::Iterator            iterator   = NetworkData::kIteratorInit;
356     signed int                       preference = NetworkData::kRoutePreferenceLow;
357     NetworkData::ExternalRouteConfig config;
358 
359     aPrefix.Clear();
360 
361     while (mInstance->Get<NetworkData::Leader>().GetNextExternalRoute(iterator, config) == kErrorNone)
362     {
363         if (!config.mNat64 || !config.GetPrefix().IsValidNat64())
364         {
365             continue;
366         }
367 
368         if ((aPrefix.GetLength() == 0) || (config.mPreference > preference))
369         {
370             aPrefix    = config.GetPrefix();
371             preference = config.mPreference;
372             error      = kErrorNone;
373         }
374     }
375 
376     return error;
377 }
378 
379 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
380 
381 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
382 
383 //---------------------------------------------------------------------------------------------------------------------
384 // Client::BrowseResponse
385 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const386 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
387 {
388     Error     error;
389     uint16_t  offset;
390     uint16_t  numRecords;
391     Name      serviceName(*mQuery, kNameOffsetInQuery);
392     PtrRecord ptrRecord;
393 
394     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
395 
396     SelectSection(kAnswerSection, offset, numRecords);
397     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
398     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
399 
400 exit:
401     return error;
402 }
403 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const404 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
405 {
406     Error error;
407     Name  instanceName;
408 
409     // Find a matching PTR record for the service instance label.
410     // Then search and read SRV, TXT and AAAA records in Additional Data section
411     // matching the same name to populate `aServiceInfo`.
412 
413     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
414     error = FindServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo);
415 
416 exit:
417     return error;
418 }
419 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const420 Error Client::BrowseResponse::GetHostAddress(const char *  aHostName,
421                                              uint16_t      aIndex,
422                                              Ip6::Address &aAddress,
423                                              uint32_t &    aTtl) const
424 {
425     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
426 }
427 
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // whose owner name is the browsed service name (read from `mQuery`)
    // and whose target name starts with the label `aInstanceLabel`. If
    // found, `aInstanceName` is set to refer to that target name within
    // `mMessage`. Returns `kErrorNotFound` if there is no response
    // message or no matching PTR record.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        // NOTE(review): any non-success result from `CompareName()` exits
        // the search. If a name mismatch is reported as an error (e.g.
        // `kErrorNotFound`), an Answer record owned by a different name
        // ends the search instead of being skipped; confirm this is intended.
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A copy of `offset` is used so that `offset`
        // still points at the start of the PTR target name.

        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record. This assumes
        // `GetSize()` is the full record size including the fixed part
        // (`sizeof(ptrRecord)`) already consumed by `ReadRecord()`.
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
483 
484 //---------------------------------------------------------------------------------------------------------------------
485 // Client::ServiceResponse
486 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const487 Error Client::ServiceResponse::GetServiceName(char *   aLabelBuffer,
488                                               uint8_t  aLabelBufferSize,
489                                               char *   aNameBuffer,
490                                               uint16_t aNameBufferSize) const
491 {
492     Error    error;
493     uint16_t offset = kNameOffsetInQuery;
494 
495     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
496 
497     VerifyOrExit(aNameBuffer != nullptr);
498     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
499 
500 exit:
501     return error;
502 }
503 
GetServiceInfo(ServiceInfo & aServiceInfo) const504 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
505 {
506     // Search and read SRV, TXT records in Answer Section
507     // matching name from query.
508 
509     return FindServiceInfo(kAnswerSection, Name(*mQuery, kNameOffsetInQuery), aServiceInfo);
510 }
511 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const512 Error Client::ServiceResponse::GetHostAddress(const char *  aHostName,
513                                               uint16_t      aIndex,
514                                               Ip6::Address &aAddress,
515                                               uint32_t &    aTtl) const
516 {
517     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
518 }
519 
520 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
521 
522 //---------------------------------------------------------------------------------------------------------------------
523 // Client
524 
// Record types requested for each query type. Each entry becomes one
// question in the outgoing query message.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query type, indexed by `QueryType`. The entry
// order (including the conditional compilation) must match the
// `QueryType` enum values; the `static_assert`s in the `Client`
// constructor check this.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ OT_ARRAY_LENGTH(kIp6AddressQueryRecordTypes), // AAAA records
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ OT_ARRAY_LENGTH(kIp4AddressQueryRecordTypes), // A records
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery  -> */ OT_ARRAY_LENGTH(kBrowseQueryRecordTypes),  // PTR records
    /* kServiceQuery -> */ OT_ARRAY_LENGTH(kServiceQueryRecordTypes), // SRV and TXT records
#endif
};

// Record-type list per query type, indexed by `QueryType` (same order as
// `kQuestionCount[]`).
const uint16_t *Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
    /* kServiceQuery -> */ kServiceQueryRecordTypes,
#endif
};
555 
Client::Client(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance)
    , mTimer(aInstance, Client::HandleTimer)
    , mDefaultConfig(QueryConfig::kInitFromDefaults)
#if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
    , mUserDidSetDefaultAddress(false)
#endif
{
    // `kQuestionCount[]` and `kQuestionRecordTypes[]` are indexed by
    // `QueryType` and list their entries in this exact order (with the
    // same conditional compilation). These checks ensure the enum values
    // line up with the table positions for every combination of the
    // NAT64 and service-discovery config options.
    static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
    static_assert(kServiceQuery == 3, "kServiceQuery value is not correct");
#endif
#elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
    static_assert(kServiceQuery == 2, "kServiceQuery value is not correct");
#endif
}
577 
Start(void)578 Error Client::Start(void)
579 {
580     Error error;
581 
582     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
583     SuccessOrExit(error = mSocket.Bind(0, OT_NETIF_UNSPECIFIED));
584 
585 exit:
586     return error;
587 }
588 
Stop(void)589 void Client::Stop(void)
590 {
591     Query *query;
592 
593     while ((query = mQueries.GetHead()) != nullptr)
594     {
595         FinalizeQuery(*query, kErrorAbort);
596     }
597 
598     IgnoreError(mSocket.Close());
599 }
600 
SetDefaultConfig(const QueryConfig & aQueryConfig)601 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
602 {
603     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
604 
605     mDefaultConfig.SetFrom(aQueryConfig, startingDefault);
606 
607 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
608     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
609     UpdateDefaultConfigAddress();
610 #endif
611 }
612 
ResetDefaultConfig(void)613 void Client::ResetDefaultConfig(void)
614 {
615     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
616 
617 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
618     mUserDidSetDefaultAddress = false;
619     UpdateDefaultConfigAddress();
620 #endif
621 }
622 
623 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)624 void Client::UpdateDefaultConfigAddress(void)
625 {
626     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
627 
628     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
629         !srpServerAddr.IsUnspecified())
630     {
631         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
632     }
633 }
634 #endif
635 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)636 Error Client::ResolveAddress(const char *       aHostName,
637                              AddressCallback    aCallback,
638                              void *             aContext,
639                              const QueryConfig *aConfig)
640 {
641     QueryInfo info;
642 
643     info.Clear();
644     info.mQueryType                 = kIp6AddressQuery;
645     info.mCallback.mAddressCallback = aCallback;
646 
647     return StartQuery(info, aConfig, nullptr, aHostName, aContext);
648 }
649 
650 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
651 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)652 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
653 {
654     QueryInfo info;
655 
656     info.Clear();
657     info.mQueryType                = kBrowseQuery;
658     info.mCallback.mBrowseCallback = aCallback;
659 
660     return StartQuery(info, aConfig, nullptr, aServiceName, aContext);
661 }
662 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)663 Error Client::ResolveService(const char *       aInstanceLabel,
664                              const char *       aServiceName,
665                              ServiceCallback    aCallback,
666                              void *             aContext,
667                              const QueryConfig *aConfig)
668 {
669     QueryInfo info;
670     Error     error;
671 
672     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
673 
674     info.Clear();
675     info.mQueryType                 = kServiceQuery;
676     info.mCallback.mServiceCallback = aCallback;
677 
678     error = StartQuery(info, aConfig, aInstanceLabel, aServiceName, aContext);
679 
680 exit:
681     return error;
682 }
683 
684 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
685 
StartQuery(QueryInfo & aInfo,const QueryConfig * aConfig,const char * aLabel,const char * aName,void * aContext)686 Error Client::StartQuery(QueryInfo &        aInfo,
687                          const QueryConfig *aConfig,
688                          const char *       aLabel,
689                          const char *       aName,
690                          void *             aContext)
691 {
692     // This method assumes that `mQueryType` and `mCallback` to be
693     // already set by caller on `aInfo`. The `aLabel` can be `nullptr`
694     // and then `aName` provides the full name, otherwise the name is
695     // appended as `{aLabel}.{aName}`.
696 
697     Error  error;
698     Query *query;
699 
700     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
701 
702     if (aConfig == nullptr)
703     {
704         aInfo.mConfig = mDefaultConfig;
705     }
706     else
707     {
708         // To form the config for this query, replace any unspecified
709         // fields (zero value) in the given `aConfig` with the fields
710         // from `mDefaultConfig`.
711 
712         aInfo.mConfig.SetFrom(*aConfig, mDefaultConfig);
713     }
714 
715     aInfo.mCallbackContext = aContext;
716 
717     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
718     mQueries.Enqueue(*query);
719 
720     SendQuery(*query, aInfo, /* aUpdateTimer */ true);
721 
722 exit:
723     return error;
724 }
725 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)726 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
727 {
728     Error error = kErrorNone;
729 
730     aQuery = Get<MessagePool>().New(Message::kTypeOther, /* aReserveHeader */ 0);
731     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
732 
733     SuccessOrExit(error = aQuery->Append(aInfo));
734 
735     if (aLabel != nullptr)
736     {
737         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
738     }
739 
740     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
741 
742 exit:
743     FreeAndNullMessageOnError(aQuery, error);
744     return error;
745 }
746 
void Client::FreeQuery(Query &aQuery)
{
    // Removes `aQuery` from the pending-query list `mQueries` (queries
    // are enqueued there by `StartQuery()`) and frees the message.
    mQueries.DequeueAndFree(aQuery);
}
751 
void Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. It updates `aInfo` and writes it back into
    // `aQuery`:
    //  - increments `mTransmissionCount`;
    //  - sets `mRetransmissionTime` to now plus the response timeout;
    //  - assigns `mMessageId` on the first transmission.
    // `aUpdateTimer` indicates whether the timer should be updated when
    // the query is sent (`false` when the caller handles the timer).
    //
    // The count, retransmission time, `UpdateQuery()` and the timer update
    // all take effect even if building or sending the message fails.
    // Presumably this lets the timer logic retry or time out the query
    // later; the timer handler is not visible here.

    Error            error   = kErrorNone;
    Message *        message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;

    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // First transmission: pick a random message ID that is non-zero
        // (zero here means "not yet assigned") and not used by any other
        // pending query.
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Retransmission: reuse the previously assigned ID.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage(0);
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section: one question per record type of
    // this query type, each repeating the query name stored in `aQuery`.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
    messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());

    SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));

exit:
    // On failure, release the message that was not sent.
    FreeMessageOnError(message, error);

    // Persist the updated `aInfo` into `aQuery` regardless of `error`.
    UpdateQuery(aQuery, aInfo);

    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }
}
821 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)822 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
823 {
824     Error    error = kErrorNone;
825     uint16_t offset;
826     uint16_t length;
827 
828     // The name is encoded and included after the `Info` in `aQuery`. We
829     // first calculate the encoded length of the name, then grow the
830     // message, and finally copy the encoded name bytes from `aQuery`
831     // into `aMessage`.
832 
833     length = aQuery.GetLength() - kNameOffsetInQuery;
834 
835     offset = aMessage.GetLength();
836     SuccessOrExit(error = aMessage.SetLength(offset + length));
837 
838     aQuery.CopyTo(/* aSourceOffset */ kNameOffsetInQuery, /* aDestOffset */ offset, length, aMessage);
839 
840 exit:
841     return error;
842 }
843 
FinalizeQuery(Query & aQuery,Error aError)844 void Client::FinalizeQuery(Query &aQuery, Error aError)
845 {
846     Response  response;
847     QueryInfo info;
848 
849     response.mInstance = &Get<Instance>();
850     response.mQuery    = &aQuery;
851     info.ReadFrom(aQuery);
852 
853     FinalizeQuery(response, info.mQueryType, aError);
854 }
855 
FinalizeQuery(Response & aResponse,QueryType aType,Error aError)856 void Client::FinalizeQuery(Response &aResponse, QueryType aType, Error aError)
857 {
858     Callback callback;
859     void *   context;
860 
861     GetCallback(*aResponse.mQuery, callback, context);
862 
863     switch (aType)
864     {
865     case kIp6AddressQuery:
866 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
867     case kIp4AddressQuery:
868 #endif
869         if (callback.mAddressCallback != nullptr)
870         {
871             callback.mAddressCallback(aError, &aResponse, context);
872         }
873         break;
874 
875 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
876     case kBrowseQuery:
877         if (callback.mBrowseCallback != nullptr)
878         {
879             callback.mBrowseCallback(aError, &aResponse, context);
880         }
881         break;
882 
883     case kServiceQuery:
884         if (callback.mServiceCallback != nullptr)
885         {
886             callback.mServiceCallback(aError, &aResponse, context);
887         }
888         break;
889 #endif
890     }
891 
892     FreeQuery(*aResponse.mQuery);
893 }
894 
GetCallback(const Query & aQuery,Callback & aCallback,void * & aContext)895 void Client::GetCallback(const Query &aQuery, Callback &aCallback, void *&aContext)
896 {
897     QueryInfo info;
898 
899     info.ReadFrom(aQuery);
900 
901     aCallback = info.mCallback;
902     aContext  = info.mCallbackContext;
903 }
904 
FindQueryById(uint16_t aMessageId)905 Client::Query *Client::FindQueryById(uint16_t aMessageId)
906 {
907     Query *   query;
908     QueryInfo info;
909 
910     for (query = mQueries.GetHead(); query != nullptr; query = query->GetNext())
911     {
912         info.ReadFrom(*query);
913 
914         if (info.mMessageId == aMessageId)
915         {
916             break;
917         }
918     }
919 
920     return query;
921 }
922 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)923 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
924 {
925     OT_UNUSED_VARIABLE(aMsgInfo);
926 
927     static_cast<Client *>(aContext)->ProcessResponse(*static_cast<Message *>(aMessage));
928 }
929 
ProcessResponse(const Message & aMessage)930 void Client::ProcessResponse(const Message &aMessage)
931 {
932     Response  response;
933     QueryType type;
934     Error     responseError;
935 
936     response.mInstance = &Get<Instance>();
937     response.mMessage  = &aMessage;
938 
939     // We intentionally parse the response in a separate method
940     // `ParseResponse()` to free all the stack allocated variables
941     // (e.g., `QueryInfo`) used during parsing of the message before
942     // finalizing the query and invoking the user's callback.
943 
944     SuccessOrExit(ParseResponse(response, type, responseError));
945     FinalizeQuery(response, type, responseError);
946 
947 exit:
948     return;
949 }
950 
ParseResponse(Response & aResponse,QueryType & aType,Error & aResponseError)951 Error Client::ParseResponse(Response &aResponse, QueryType &aType, Error &aResponseError)
952 {
953     Error          error   = kErrorNone;
954     const Message &message = *aResponse.mMessage;
955     uint16_t       offset  = message.GetOffset();
956     Header         header;
957     QueryInfo      info;
958     Name           queryName;
959 
960     SuccessOrExit(error = message.Read(offset, header));
961     offset += sizeof(Header);
962 
963     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
964                      !header.IsTruncationFlagSet(),
965                  error = kErrorDrop);
966 
967     aResponse.mQuery = FindQueryById(header.GetMessageId());
968     VerifyOrExit(aResponse.mQuery != nullptr, error = kErrorNotFound);
969 
970     info.ReadFrom(*aResponse.mQuery);
971     aType = info.mQueryType;
972 
973     queryName.SetFromMessage(*aResponse.mQuery, kNameOffsetInQuery);
974 
975     // Check the Question Section
976 
977     VerifyOrExit(header.GetQuestionCount() == kQuestionCount[aType], error = kErrorParse);
978 
979     for (uint8_t num = 0; num < kQuestionCount[aType]; num++)
980     {
981         SuccessOrExit(error = Name::CompareName(message, offset, queryName));
982         offset += sizeof(Question);
983     }
984 
985     // Check the answer, authority and additional record sections
986 
987     aResponse.mAnswerOffset = offset;
988     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAnswerCount()));
989     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAuthorityRecordCount()));
990     aResponse.mAdditionalOffset = offset;
991     SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAdditionalRecordCount()));
992 
993     aResponse.mAnswerRecordCount     = header.GetAnswerCount();
994     aResponse.mAdditionalRecordCount = header.GetAdditionalRecordCount();
995 
996     // Check the response code from server
997 
998     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
999 
1000 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1001 
1002     if (aType == kIp6AddressQuery)
1003     {
1004         Ip6::Address ip6ddress;
1005         uint32_t     ttl;
1006         ARecord      aRecord;
1007 
1008         // If the response does not contain an answer for the IPv6 address
1009         // resolution query and if NAT64 is allowed for this query, we can
1010         // perform IPv4 to IPv6 address translation.
1011 
1012         VerifyOrExit(aResponse.FindHostAddress(Response::kAnswerSection, queryName, /* aIndex */ 0, ip6ddress, ttl) !=
1013                      kErrorNone);
1014         VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1015 
1016         // First, we check if the response already contains an A record
1017         // (IPv4 address) for the query name.
1018 
1019         if (aResponse.FindARecord(Response::kAdditionalDataSection, queryName, /* aIndex */ 0, aRecord) == kErrorNone)
1020         {
1021             aResponse.mIp6QueryResponseRequiresNat64 = true;
1022             aResponseError                           = kErrorNone;
1023             ExitNow();
1024         }
1025 
1026         // Otherwise, we send a new query for IPv4 address resolution
1027         // for the same host name. We reuse the existing `query`
1028         // instance and keep all the info but clear `mTransmissionCount`
1029         // and `mMessageId` (so that a new random message ID is
1030         // selected). The new `info` will be saved in the query in
1031         // `SendQuery()`. Note that the current query is still in the
1032         // `mQueries` list when `SendQuery()` selects a new random
1033         // message ID, so the existing message ID for this query will
1034         // not be reused. Since the query is not yet resolved, we
1035         // return `kErrorPending`.
1036 
1037         info.mQueryType         = kIp4AddressQuery;
1038         info.mMessageId         = 0;
1039         info.mTransmissionCount = 0;
1040 
1041         SendQuery(*aResponse.mQuery, info, /* aUpdateTimer */ true);
1042 
1043         error = kErrorPending;
1044     }
1045 
1046 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1047 
1048 exit:
1049     if (error != kErrorNone)
1050     {
1051         otLogInfoDns("Failed to parse response %s", ErrorToString(error));
1052     }
1053 
1054     return error;
1055 }
1056 
HandleTimer(Timer & aTimer)1057 void Client::HandleTimer(Timer &aTimer)
1058 {
1059     aTimer.Get<Client>().HandleTimer();
1060 }
1061 
HandleTimer(void)1062 void Client::HandleTimer(void)
1063 {
1064     TimeMilli now      = TimerMilli::GetNow();
1065     TimeMilli nextTime = now.GetDistantFuture();
1066     Query *   nextQuery;
1067     QueryInfo info;
1068 
1069     for (Query *query = mQueries.GetHead(); query != nullptr; query = nextQuery)
1070     {
1071         nextQuery = query->GetNext();
1072 
1073         info.ReadFrom(*query);
1074 
1075         if (now >= info.mRetransmissionTime)
1076         {
1077             if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1078             {
1079                 FinalizeQuery(*query, kErrorResponseTimeout);
1080                 continue;
1081             }
1082 
1083             SendQuery(*query, info, /* aUpdateTimer */ false);
1084         }
1085 
1086         if (nextTime > info.mRetransmissionTime)
1087         {
1088             nextTime = info.mRetransmissionTime;
1089         }
1090     }
1091 
1092     if (nextTime < now.GetDistantFuture())
1093     {
1094         mTimer.FireAt(nextTime);
1095     }
1096 }
1097 
1098 } // namespace Dns
1099 } // namespace ot
1100 
1101 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1102