1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/code_utils.hpp"
34 #include "common/debug.hpp"
35 #include "common/instance.hpp"
36 #include "common/locator_getters.hpp"
37 #include "common/logging.hpp"
38 #include "net/udp6.hpp"
39 #include "thread/network_data_types.hpp"
40 #include "thread/thread_netif.hpp"
41
42 /**
43 * @file
44 * This file implements the DNS client.
45 */
46
47 namespace ot {
48 namespace Dns {
49
50 //---------------------------------------------------------------------------------------------------------------------
51 // Client::QueryConfig
52
53 const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
54
QueryConfig(InitMode aMode)55 Client::QueryConfig::QueryConfig(InitMode aMode)
56 {
57 OT_UNUSED_VARIABLE(aMode);
58
59 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
60 GetServerSockAddr().SetPort(kDefaultServerPort);
61 SetResponseTimeout(kDefaultResponseTimeout);
62 SetMaxTxAttempts(kDefaultMaxTxAttempts);
63 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
64 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
65 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
66 #endif
67 }
68
SetFrom(const QueryConfig & aConfig,const QueryConfig & aDefaultConfig)69 void Client::QueryConfig::SetFrom(const QueryConfig &aConfig, const QueryConfig &aDefaultConfig)
70 {
71 // This method sets the config from `aConfig` replacing any
72 // unspecified fields (value zero) with the fields from
73 // `aDefaultConfig`.
74
75 *this = aConfig;
76
77 if (GetServerSockAddr().GetAddress().IsUnspecified())
78 {
79 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
80 }
81
82 if (GetServerSockAddr().GetPort() == 0)
83 {
84 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
85 }
86
87 if (GetResponseTimeout() == 0)
88 {
89 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
90 }
91
92 if (GetMaxTxAttempts() == 0)
93 {
94 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
95 }
96
97 if (GetRecursionFlag() == kFlagUnspecified)
98 {
99 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
100 }
101
102 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
103 if (GetNat64Mode() == kNat64Unspecified)
104 {
105 SetNat64Mode(aDefaultConfig.GetNat64Mode());
106 }
107 #endif
108 }
109
110 //---------------------------------------------------------------------------------------------------------------------
111 // Client::Response
112
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const113 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
114 {
115 switch (aSection)
116 {
117 case kAnswerSection:
118 aOffset = mAnswerOffset;
119 aNumRecord = mAnswerRecordCount;
120 break;
121 case kAdditionalDataSection:
122 default:
123 aOffset = mAdditionalOffset;
124 aNumRecord = mAdditionalRecordCount;
125 break;
126 }
127 }
128
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const129 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
130 {
131 uint16_t offset = kNameOffsetInQuery;
132
133 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
134 }
135
CheckForHostNameAlias(Section aSection,Name & aHostName) const136 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
137 {
138 // If the response includes a CNAME record mapping the query host
139 // name to a canonical name, we update `aHostName` to the new alias
140 // name. Otherwise `aHostName` remains as before.
141
142 Error error;
143 uint16_t offset;
144 uint16_t numRecords;
145 CnameRecord cnameRecord;
146
147 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
148
149 SelectSection(aSection, offset, numRecords);
150 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
151
152 switch (error)
153 {
154 case kErrorNone:
155 // A CNAME record was found. `offset` now points to after the
156 // last read byte within the `mMessage` into the `cnameRecord`
157 // (which is the start of the new canonical name).
158 aHostName.SetFromMessage(*mMessage, offset);
159 error = Name::ParseName(*mMessage, offset);
160 break;
161
162 case kErrorNotFound:
163 error = kErrorNone;
164 break;
165
166 default:
167 break;
168 }
169
170 exit:
171 return error;
172 }
173
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const174 Error Client::Response::FindHostAddress(Section aSection,
175 const Name & aHostName,
176 uint16_t aIndex,
177 Ip6::Address &aAddress,
178 uint32_t & aTtl) const
179 {
180 Error error;
181 uint16_t offset;
182 uint16_t numRecords;
183 Name name = aHostName;
184 AaaaRecord aaaaRecord;
185
186 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
187
188 SelectSection(aSection, offset, numRecords);
189 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
190 aAddress = aaaaRecord.GetAddress();
191 aTtl = aaaaRecord.GetTtl();
192
193 exit:
194 return error;
195 }
196
197 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
198
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const199 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
200 {
201 Error error;
202 uint16_t offset;
203 uint16_t numRecords;
204 Name name = aHostName;
205
206 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
207
208 SelectSection(aSection, offset, numRecords);
209 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
210
211 exit:
212 return error;
213 }
214
215 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
216
217 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
218
FindServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const219 Error Client::Response::FindServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
220 {
221 // This method searches for SRV and TXT records in the given
222 // section matching the record name against `aName`, and updates
223 // the `aServiceInfo` accordingly. It also searches for AAAA
224 // record for host name associated with the service (from SRV
225 // record). The search for AAAA record is always performed in
226 // Additional Data section (independent of the value given in
227 // `aSection`).
228
229 Error error;
230 uint16_t offset;
231 uint16_t numRecords;
232 Name hostName;
233 SrvRecord srvRecord;
234 TxtRecord txtRecord;
235
236 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
237
238 // Search for a matching SRV record
239 SelectSection(aSection, offset, numRecords);
240 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
241
242 aServiceInfo.mTtl = srvRecord.GetTtl();
243 aServiceInfo.mPort = srvRecord.GetPort();
244 aServiceInfo.mPriority = srvRecord.GetPriority();
245 aServiceInfo.mWeight = srvRecord.GetWeight();
246
247 hostName.SetFromMessage(*mMessage, offset);
248
249 if (aServiceInfo.mHostNameBuffer != nullptr)
250 {
251 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
252 aServiceInfo.mHostNameBufferSize));
253 }
254 else
255 {
256 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
257 }
258
259 // Search in additional section for AAAA record for the host name.
260
261 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0,
262 static_cast<Ip6::Address &>(aServiceInfo.mHostAddress), aServiceInfo.mHostAddressTtl);
263
264 if (error == kErrorNotFound)
265 {
266 static_cast<Ip6::Address &>(aServiceInfo.mHostAddress).Clear();
267 aServiceInfo.mHostAddressTtl = 0;
268 }
269 else
270 {
271 SuccessOrExit(error);
272 }
273
274 // A null `mTxtData` indicates that caller does not want to retrieve TXT data.
275 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
276
277 // Search for a matching TXT record. If not found, indicate this by
278 // setting `aServiceInfo.mTxtDataSize` to zero.
279
280 SelectSection(aSection, offset, numRecords);
281 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord);
282
283 switch (error)
284 {
285 case kErrorNone:
286 SuccessOrExit(error =
287 txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize));
288 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
289 break;
290
291 case kErrorNotFound:
292 aServiceInfo.mTxtDataSize = 0;
293 aServiceInfo.mTxtDataTtl = 0;
294 break;
295
296 default:
297 ExitNow();
298 }
299
300 exit:
301 return error;
302 }
303
304 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
305
306 //---------------------------------------------------------------------------------------------------------------------
307 // Client::AddressResponse
308
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const309 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
310 {
311 Error error = kErrorNone;
312 Name name(*mQuery, kNameOffsetInQuery);
313
314 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
315
316 // If the response is for an IPv4 address query or if it is an
317 // IPv6 address query response with no IPv6 address but with
318 // an IPv4 in its additional section, we read the IPv4 address
319 // and translate it to an IPv6 address.
320
321 QueryInfo info;
322
323 info.ReadFrom(*mQuery);
324
325 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
326 {
327 Section section;
328 ARecord aRecord;
329 Ip6::Prefix nat64Prefix;
330
331 SuccessOrExit(error = GetNat64Prefix(nat64Prefix));
332
333 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
334 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
335
336 aAddress.SynthesizeFromIp4Address(nat64Prefix, aRecord.GetAddress());
337 aTtl = aRecord.GetTtl();
338
339 ExitNow();
340 }
341
342 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
343
344 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
345
346 exit:
347 return error;
348 }
349
350 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
351
GetNat64Prefix(Ip6::Prefix & aPrefix) const352 Error Client::AddressResponse::GetNat64Prefix(Ip6::Prefix &aPrefix) const
353 {
354 Error error = kErrorNotFound;
355 NetworkData::Iterator iterator = NetworkData::kIteratorInit;
356 signed int preference = NetworkData::kRoutePreferenceLow;
357 NetworkData::ExternalRouteConfig config;
358
359 aPrefix.Clear();
360
361 while (mInstance->Get<NetworkData::Leader>().GetNextExternalRoute(iterator, config) == kErrorNone)
362 {
363 if (!config.mNat64 || !config.GetPrefix().IsValidNat64())
364 {
365 continue;
366 }
367
368 if ((aPrefix.GetLength() == 0) || (config.mPreference > preference))
369 {
370 aPrefix = config.GetPrefix();
371 preference = config.mPreference;
372 error = kErrorNone;
373 }
374 }
375
376 return error;
377 }
378
379 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
380
381 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
382
383 //---------------------------------------------------------------------------------------------------------------------
384 // Client::BrowseResponse
385
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const386 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
387 {
388 Error error;
389 uint16_t offset;
390 uint16_t numRecords;
391 Name serviceName(*mQuery, kNameOffsetInQuery);
392 PtrRecord ptrRecord;
393
394 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
395
396 SelectSection(kAnswerSection, offset, numRecords);
397 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
398 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
399
400 exit:
401 return error;
402 }
403
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const404 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
405 {
406 Error error;
407 Name instanceName;
408
409 // Find a matching PTR record for the service instance label.
410 // Then search and read SRV, TXT and AAAA records in Additional Data section
411 // matching the same name to populate `aServiceInfo`.
412
413 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
414 error = FindServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo);
415
416 exit:
417 return error;
418 }
419
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const420 Error Client::BrowseResponse::GetHostAddress(const char * aHostName,
421 uint16_t aIndex,
422 Ip6::Address &aAddress,
423 uint32_t & aTtl) const
424 {
425 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
426 }
427
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const428 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
429 {
430 // This method searches within the Answer Section for a PTR record
431 // matching a given instance label @aInstanceLabel. If found, the
432 // `aName` is updated to return the name in the message.
433
434 Error error;
435 uint16_t offset;
436 Name serviceName(*mQuery, kNameOffsetInQuery);
437 uint16_t numRecords;
438 uint16_t labelOffset;
439 PtrRecord ptrRecord;
440
441 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
442
443 SelectSection(kAnswerSection, offset, numRecords);
444
445 for (; numRecords > 0; numRecords--)
446 {
447 SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
448
449 error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
450
451 if (error == kErrorNotFound)
452 {
453 // `ReadRecord()` updates `offset` to skip over a
454 // non-matching record.
455 continue;
456 }
457
458 SuccessOrExit(error);
459
460 // It is a PTR record. Check the first label to match the
461 // instance label.
462
463 labelOffset = offset;
464 error = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
465
466 if (error == kErrorNone)
467 {
468 aInstanceName.SetFromMessage(*mMessage, offset);
469 ExitNow();
470 }
471
472 VerifyOrExit(error == kErrorNotFound);
473
474 // Update offset to skip over the PTR record.
475 offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
476 }
477
478 error = kErrorNotFound;
479
480 exit:
481 return error;
482 }
483
484 //---------------------------------------------------------------------------------------------------------------------
485 // Client::ServiceResponse
486
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const487 Error Client::ServiceResponse::GetServiceName(char * aLabelBuffer,
488 uint8_t aLabelBufferSize,
489 char * aNameBuffer,
490 uint16_t aNameBufferSize) const
491 {
492 Error error;
493 uint16_t offset = kNameOffsetInQuery;
494
495 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
496
497 VerifyOrExit(aNameBuffer != nullptr);
498 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
499
500 exit:
501 return error;
502 }
503
GetServiceInfo(ServiceInfo & aServiceInfo) const504 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
505 {
506 // Search and read SRV, TXT records in Answer Section
507 // matching name from query.
508
509 return FindServiceInfo(kAnswerSection, Name(*mQuery, kNameOffsetInQuery), aServiceInfo);
510 }
511
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const512 Error Client::ServiceResponse::GetHostAddress(const char * aHostName,
513 uint16_t aIndex,
514 Ip6::Address &aAddress,
515 uint32_t & aTtl) const
516 {
517 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
518 }
519
520 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
521
522 //---------------------------------------------------------------------------------------------------------------------
523 // Client
524
525 const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
526 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
527 const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
528 #endif
529 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
530 const uint16_t Client::kBrowseQueryRecordTypes[] = {ResourceRecord::kTypePtr};
531 const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
532 #endif
533
534 const uint8_t Client::kQuestionCount[] = {
535 /* kIp6AddressQuery -> */ OT_ARRAY_LENGTH(kIp6AddressQueryRecordTypes), // AAAA records
536 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
537 /* kIp4AddressQuery -> */ OT_ARRAY_LENGTH(kIp4AddressQueryRecordTypes), // A records
538 #endif
539 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
540 /* kBrowseQuery -> */ OT_ARRAY_LENGTH(kBrowseQueryRecordTypes), // PTR records
541 /* kServiceQuery -> */ OT_ARRAY_LENGTH(kServiceQueryRecordTypes), // SRV and TXT records
542 #endif
543 };
544
545 const uint16_t *Client::kQuestionRecordTypes[] = {
546 /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
547 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
548 /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
549 #endif
550 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
551 /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
552 /* kServiceQuery -> */ kServiceQueryRecordTypes,
553 #endif
554 };
555
Client(Instance & aInstance)556 Client::Client(Instance &aInstance)
557 : InstanceLocator(aInstance)
558 , mSocket(aInstance)
559 , mTimer(aInstance, Client::HandleTimer)
560 , mDefaultConfig(QueryConfig::kInitFromDefaults)
561 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
562 , mUserDidSetDefaultAddress(false)
563 #endif
564 {
565 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
566 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
567 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
568 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
569 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
570 static_assert(kServiceQuery == 3, "kServiceQuery value is not correct");
571 #endif
572 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
573 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
574 static_assert(kServiceQuery == 2, "kServiceQuery value is not correct");
575 #endif
576 }
577
Start(void)578 Error Client::Start(void)
579 {
580 Error error;
581
582 SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
583 SuccessOrExit(error = mSocket.Bind(0, OT_NETIF_UNSPECIFIED));
584
585 exit:
586 return error;
587 }
588
Stop(void)589 void Client::Stop(void)
590 {
591 Query *query;
592
593 while ((query = mQueries.GetHead()) != nullptr)
594 {
595 FinalizeQuery(*query, kErrorAbort);
596 }
597
598 IgnoreError(mSocket.Close());
599 }
600
SetDefaultConfig(const QueryConfig & aQueryConfig)601 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
602 {
603 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
604
605 mDefaultConfig.SetFrom(aQueryConfig, startingDefault);
606
607 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
608 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
609 UpdateDefaultConfigAddress();
610 #endif
611 }
612
ResetDefaultConfig(void)613 void Client::ResetDefaultConfig(void)
614 {
615 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
616
617 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
618 mUserDidSetDefaultAddress = false;
619 UpdateDefaultConfigAddress();
620 #endif
621 }
622
623 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)624 void Client::UpdateDefaultConfigAddress(void)
625 {
626 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
627
628 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
629 !srpServerAddr.IsUnspecified())
630 {
631 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
632 }
633 }
634 #endif
635
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)636 Error Client::ResolveAddress(const char * aHostName,
637 AddressCallback aCallback,
638 void * aContext,
639 const QueryConfig *aConfig)
640 {
641 QueryInfo info;
642
643 info.Clear();
644 info.mQueryType = kIp6AddressQuery;
645 info.mCallback.mAddressCallback = aCallback;
646
647 return StartQuery(info, aConfig, nullptr, aHostName, aContext);
648 }
649
650 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
651
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)652 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
653 {
654 QueryInfo info;
655
656 info.Clear();
657 info.mQueryType = kBrowseQuery;
658 info.mCallback.mBrowseCallback = aCallback;
659
660 return StartQuery(info, aConfig, nullptr, aServiceName, aContext);
661 }
662
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)663 Error Client::ResolveService(const char * aInstanceLabel,
664 const char * aServiceName,
665 ServiceCallback aCallback,
666 void * aContext,
667 const QueryConfig *aConfig)
668 {
669 QueryInfo info;
670 Error error;
671
672 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
673
674 info.Clear();
675 info.mQueryType = kServiceQuery;
676 info.mCallback.mServiceCallback = aCallback;
677
678 error = StartQuery(info, aConfig, aInstanceLabel, aServiceName, aContext);
679
680 exit:
681 return error;
682 }
683
684 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
685
StartQuery(QueryInfo & aInfo,const QueryConfig * aConfig,const char * aLabel,const char * aName,void * aContext)686 Error Client::StartQuery(QueryInfo & aInfo,
687 const QueryConfig *aConfig,
688 const char * aLabel,
689 const char * aName,
690 void * aContext)
691 {
692 // This method assumes that `mQueryType` and `mCallback` to be
693 // already set by caller on `aInfo`. The `aLabel` can be `nullptr`
694 // and then `aName` provides the full name, otherwise the name is
695 // appended as `{aLabel}.{aName}`.
696
697 Error error;
698 Query *query;
699
700 VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
701
702 if (aConfig == nullptr)
703 {
704 aInfo.mConfig = mDefaultConfig;
705 }
706 else
707 {
708 // To form the config for this query, replace any unspecified
709 // fields (zero value) in the given `aConfig` with the fields
710 // from `mDefaultConfig`.
711
712 aInfo.mConfig.SetFrom(*aConfig, mDefaultConfig);
713 }
714
715 aInfo.mCallbackContext = aContext;
716
717 SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
718 mQueries.Enqueue(*query);
719
720 SendQuery(*query, aInfo, /* aUpdateTimer */ true);
721
722 exit:
723 return error;
724 }
725
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)726 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
727 {
728 Error error = kErrorNone;
729
730 aQuery = Get<MessagePool>().New(Message::kTypeOther, /* aReserveHeader */ 0);
731 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
732
733 SuccessOrExit(error = aQuery->Append(aInfo));
734
735 if (aLabel != nullptr)
736 {
737 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
738 }
739
740 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
741
742 exit:
743 FreeAndNullMessageOnError(aQuery, error);
744 return error;
745 }
746
FreeQuery(Query & aQuery)747 void Client::FreeQuery(Query &aQuery)
748 {
749 mQueries.DequeueAndFree(aQuery);
750 }
751
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)752 void Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
753 {
754 // This method prepares and sends a query message represented by
755 // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
756 // the new `mRetransmissionTime`) and updates it in `aQuery` as
757 // well. `aUpdateTimer` indicates whether the timer should be
758 // updated when query is sent or not (used in the case where timer
759 // is handled by caller).
760
761 Error error = kErrorNone;
762 Message * message = nullptr;
763 Header header;
764 Ip6::MessageInfo messageInfo;
765
766 aInfo.mTransmissionCount++;
767 aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
768
769 if (aInfo.mMessageId == 0)
770 {
771 do
772 {
773 SuccessOrExit(error = header.SetRandomMessageId());
774 } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
775
776 aInfo.mMessageId = header.GetMessageId();
777 }
778 else
779 {
780 header.SetMessageId(aInfo.mMessageId);
781 }
782
783 header.SetType(Header::kTypeQuery);
784 header.SetQueryType(Header::kQueryTypeStandard);
785
786 if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
787 {
788 header.SetRecursionDesiredFlag();
789 }
790
791 header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
792
793 message = mSocket.NewMessage(0);
794 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
795
796 SuccessOrExit(error = message->Append(header));
797
798 // Prepare the question section.
799
800 for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
801 {
802 SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
803 SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
804 }
805
806 messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
807 messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
808
809 SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
810
811 exit:
812 FreeMessageOnError(message, error);
813
814 UpdateQuery(aQuery, aInfo);
815
816 if (aUpdateTimer)
817 {
818 mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
819 }
820 }
821
AppendNameFromQuery(const Query & aQuery,Message & aMessage)822 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
823 {
824 Error error = kErrorNone;
825 uint16_t offset;
826 uint16_t length;
827
828 // The name is encoded and included after the `Info` in `aQuery`. We
829 // first calculate the encoded length of the name, then grow the
830 // message, and finally copy the encoded name bytes from `aQuery`
831 // into `aMessage`.
832
833 length = aQuery.GetLength() - kNameOffsetInQuery;
834
835 offset = aMessage.GetLength();
836 SuccessOrExit(error = aMessage.SetLength(offset + length));
837
838 aQuery.CopyTo(/* aSourceOffset */ kNameOffsetInQuery, /* aDestOffset */ offset, length, aMessage);
839
840 exit:
841 return error;
842 }
843
FinalizeQuery(Query & aQuery,Error aError)844 void Client::FinalizeQuery(Query &aQuery, Error aError)
845 {
846 Response response;
847 QueryInfo info;
848
849 response.mInstance = &Get<Instance>();
850 response.mQuery = &aQuery;
851 info.ReadFrom(aQuery);
852
853 FinalizeQuery(response, info.mQueryType, aError);
854 }
855
FinalizeQuery(Response & aResponse,QueryType aType,Error aError)856 void Client::FinalizeQuery(Response &aResponse, QueryType aType, Error aError)
857 {
858 Callback callback;
859 void * context;
860
861 GetCallback(*aResponse.mQuery, callback, context);
862
863 switch (aType)
864 {
865 case kIp6AddressQuery:
866 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
867 case kIp4AddressQuery:
868 #endif
869 if (callback.mAddressCallback != nullptr)
870 {
871 callback.mAddressCallback(aError, &aResponse, context);
872 }
873 break;
874
875 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
876 case kBrowseQuery:
877 if (callback.mBrowseCallback != nullptr)
878 {
879 callback.mBrowseCallback(aError, &aResponse, context);
880 }
881 break;
882
883 case kServiceQuery:
884 if (callback.mServiceCallback != nullptr)
885 {
886 callback.mServiceCallback(aError, &aResponse, context);
887 }
888 break;
889 #endif
890 }
891
892 FreeQuery(*aResponse.mQuery);
893 }
894
GetCallback(const Query & aQuery,Callback & aCallback,void * & aContext)895 void Client::GetCallback(const Query &aQuery, Callback &aCallback, void *&aContext)
896 {
897 QueryInfo info;
898
899 info.ReadFrom(aQuery);
900
901 aCallback = info.mCallback;
902 aContext = info.mCallbackContext;
903 }
904
FindQueryById(uint16_t aMessageId)905 Client::Query *Client::FindQueryById(uint16_t aMessageId)
906 {
907 Query * query;
908 QueryInfo info;
909
910 for (query = mQueries.GetHead(); query != nullptr; query = query->GetNext())
911 {
912 info.ReadFrom(*query);
913
914 if (info.mMessageId == aMessageId)
915 {
916 break;
917 }
918 }
919
920 return query;
921 }
922
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)923 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
924 {
925 OT_UNUSED_VARIABLE(aMsgInfo);
926
927 static_cast<Client *>(aContext)->ProcessResponse(*static_cast<Message *>(aMessage));
928 }
929
ProcessResponse(const Message & aMessage)930 void Client::ProcessResponse(const Message &aMessage)
931 {
932 Response response;
933 QueryType type;
934 Error responseError;
935
936 response.mInstance = &Get<Instance>();
937 response.mMessage = &aMessage;
938
939 // We intentionally parse the response in a separate method
940 // `ParseResponse()` to free all the stack allocated variables
941 // (e.g., `QueryInfo`) used during parsing of the message before
942 // finalizing the query and invoking the user's callback.
943
944 SuccessOrExit(ParseResponse(response, type, responseError));
945 FinalizeQuery(response, type, responseError);
946
947 exit:
948 return;
949 }
950
ParseResponse(Response & aResponse,QueryType & aType,Error & aResponseError)951 Error Client::ParseResponse(Response &aResponse, QueryType &aType, Error &aResponseError)
952 {
953 Error error = kErrorNone;
954 const Message &message = *aResponse.mMessage;
955 uint16_t offset = message.GetOffset();
956 Header header;
957 QueryInfo info;
958 Name queryName;
959
960 SuccessOrExit(error = message.Read(offset, header));
961 offset += sizeof(Header);
962
963 VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
964 !header.IsTruncationFlagSet(),
965 error = kErrorDrop);
966
967 aResponse.mQuery = FindQueryById(header.GetMessageId());
968 VerifyOrExit(aResponse.mQuery != nullptr, error = kErrorNotFound);
969
970 info.ReadFrom(*aResponse.mQuery);
971 aType = info.mQueryType;
972
973 queryName.SetFromMessage(*aResponse.mQuery, kNameOffsetInQuery);
974
975 // Check the Question Section
976
977 VerifyOrExit(header.GetQuestionCount() == kQuestionCount[aType], error = kErrorParse);
978
979 for (uint8_t num = 0; num < kQuestionCount[aType]; num++)
980 {
981 SuccessOrExit(error = Name::CompareName(message, offset, queryName));
982 offset += sizeof(Question);
983 }
984
985 // Check the answer, authority and additional record sections
986
987 aResponse.mAnswerOffset = offset;
988 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAnswerCount()));
989 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAuthorityRecordCount()));
990 aResponse.mAdditionalOffset = offset;
991 SuccessOrExit(error = ResourceRecord::ParseRecords(message, offset, header.GetAdditionalRecordCount()));
992
993 aResponse.mAnswerRecordCount = header.GetAnswerCount();
994 aResponse.mAdditionalRecordCount = header.GetAdditionalRecordCount();
995
996 // Check the response code from server
997
998 aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
999
1000 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1001
1002 if (aType == kIp6AddressQuery)
1003 {
1004 Ip6::Address ip6ddress;
1005 uint32_t ttl;
1006 ARecord aRecord;
1007
1008 // If the response does not contain an answer for the IPv6 address
1009 // resolution query and if NAT64 is allowed for this query, we can
1010 // perform IPv4 to IPv6 address translation.
1011
1012 VerifyOrExit(aResponse.FindHostAddress(Response::kAnswerSection, queryName, /* aIndex */ 0, ip6ddress, ttl) !=
1013 kErrorNone);
1014 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1015
1016 // First, we check if the response already contains an A record
1017 // (IPv4 address) for the query name.
1018
1019 if (aResponse.FindARecord(Response::kAdditionalDataSection, queryName, /* aIndex */ 0, aRecord) == kErrorNone)
1020 {
1021 aResponse.mIp6QueryResponseRequiresNat64 = true;
1022 aResponseError = kErrorNone;
1023 ExitNow();
1024 }
1025
1026 // Otherwise, we send a new query for IPv4 address resolution
1027 // for the same host name. We reuse the existing `query`
1028 // instance and keep all the info but clear `mTransmissionCount`
1029 // and `mMessageId` (so that a new random message ID is
1030 // selected). The new `info` will be saved in the query in
1031 // `SendQuery()`. Note that the current query is still in the
1032 // `mQueries` list when `SendQuery()` selects a new random
1033 // message ID, so the existing message ID for this query will
1034 // not be reused. Since the query is not yet resolved, we
1035 // return `kErrorPending`.
1036
1037 info.mQueryType = kIp4AddressQuery;
1038 info.mMessageId = 0;
1039 info.mTransmissionCount = 0;
1040
1041 SendQuery(*aResponse.mQuery, info, /* aUpdateTimer */ true);
1042
1043 error = kErrorPending;
1044 }
1045
1046 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1047
1048 exit:
1049 if (error != kErrorNone)
1050 {
1051 otLogInfoDns("Failed to parse response %s", ErrorToString(error));
1052 }
1053
1054 return error;
1055 }
1056
HandleTimer(Timer & aTimer)1057 void Client::HandleTimer(Timer &aTimer)
1058 {
1059 aTimer.Get<Client>().HandleTimer();
1060 }
1061
HandleTimer(void)1062 void Client::HandleTimer(void)
1063 {
1064 TimeMilli now = TimerMilli::GetNow();
1065 TimeMilli nextTime = now.GetDistantFuture();
1066 Query * nextQuery;
1067 QueryInfo info;
1068
1069 for (Query *query = mQueries.GetHead(); query != nullptr; query = nextQuery)
1070 {
1071 nextQuery = query->GetNext();
1072
1073 info.ReadFrom(*query);
1074
1075 if (now >= info.mRetransmissionTime)
1076 {
1077 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1078 {
1079 FinalizeQuery(*query, kErrorResponseTimeout);
1080 continue;
1081 }
1082
1083 SendQuery(*query, info, /* aUpdateTimer */ false);
1084 }
1085
1086 if (nextTime > info.mRetransmissionTime)
1087 {
1088 nextTime = info.mRetransmissionTime;
1089 }
1090 }
1091
1092 if (nextTime < now.GetDistantFuture())
1093 {
1094 mTimer.FireAt(nextTime);
1095 }
1096 }
1097
1098 } // namespace Dns
1099 } // namespace ot
1100
1101 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1102