1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "instance/instance.hpp"
34 #include "utils/static_counter.hpp"
35
36 /**
37 * @file
38 * This file implements the DNS client.
39 */
40
41 namespace ot {
42 namespace Dns {
43
// Registers "DnsClient" as the log module name for the logging done in this file.
RegisterLogModule("DnsClient");

//---------------------------------------------------------------------------------------------------------------------
// Client::QueryConfig

// Default DNS server IPv6 address, as a string, taken from the build-time config
// `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS`. It is parsed with
// `FromString()` when a `QueryConfig` is constructed from the defaults.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
50
QueryConfig(InitMode aMode)51 Client::QueryConfig::QueryConfig(InitMode aMode)
52 {
53 OT_UNUSED_VARIABLE(aMode);
54
55 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
56 GetServerSockAddr().SetPort(kDefaultServerPort);
57 SetResponseTimeout(kDefaultResponseTimeout);
58 SetMaxTxAttempts(kDefaultMaxTxAttempts);
59 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
60 SetServiceMode(kDefaultServiceMode);
61 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
62 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
63 #endif
64 SetTransportProto(kDnsTransportUdp);
65 }
66
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)67 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
68 {
69 // This method sets the config from `aConfig` replacing any
70 // unspecified fields (value zero) with the fields from
71 // `aDefaultConfig`. If `aConfig` is `nullptr` then
72 // `aDefaultConfig` is used.
73
74 if (aConfig == nullptr)
75 {
76 *this = aDefaultConfig;
77 ExitNow();
78 }
79
80 *this = *aConfig;
81
82 if (GetServerSockAddr().GetAddress().IsUnspecified())
83 {
84 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
85 }
86
87 if (GetServerSockAddr().GetPort() == 0)
88 {
89 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
90 }
91
92 if (GetResponseTimeout() == 0)
93 {
94 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
95 }
96
97 if (GetMaxTxAttempts() == 0)
98 {
99 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
100 }
101
102 if (GetRecursionFlag() == kFlagUnspecified)
103 {
104 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
105 }
106
107 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
108 if (GetNat64Mode() == kNat64Unspecified)
109 {
110 SetNat64Mode(aDefaultConfig.GetNat64Mode());
111 }
112 #endif
113
114 if (GetServiceMode() == kServiceModeUnspecified)
115 {
116 SetServiceMode(aDefaultConfig.GetServiceMode());
117 }
118
119 if (GetTransportProto() == kDnsTransportUnspecified)
120 {
121 SetTransportProto(aDefaultConfig.GetTransportProto());
122 }
123
124 exit:
125 return;
126 }
127
128 //---------------------------------------------------------------------------------------------------------------------
129 // Client::Response
130
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const131 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
132 {
133 switch (aSection)
134 {
135 case kAnswerSection:
136 aOffset = mAnswerOffset;
137 aNumRecord = mAnswerRecordCount;
138 break;
139 case kAdditionalDataSection:
140 default:
141 aOffset = mAdditionalOffset;
142 aNumRecord = mAdditionalRecordCount;
143 break;
144 }
145 }
146
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const147 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
148 {
149 uint16_t offset = kNameOffsetInQuery;
150
151 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
152 }
153
CheckForHostNameAlias(Section aSection,Name & aHostName) const154 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
155 {
156 // If the response includes a CNAME record mapping the query host
157 // name to a canonical name, we update `aHostName` to the new alias
158 // name. Otherwise `aHostName` remains as before. This method handles
159 // when there are multiple CNAME records mapping the host name multiple
160 // times. We limit number of changes to `kMaxCnameAliasNameChanges`
161 // to detect and handle if the response contains CNAME record loops.
162
163 Error error;
164 uint16_t offset;
165 uint16_t numRecords;
166 CnameRecord cnameRecord;
167
168 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
169
170 for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
171 {
172 SelectSection(aSection, offset, numRecords);
173 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
174
175 if (error == kErrorNotFound)
176 {
177 error = kErrorNone;
178 ExitNow();
179 }
180
181 SuccessOrExit(error);
182
183 // A CNAME record was found. `offset` now points to after the
184 // last read byte within the `mMessage` into the `cnameRecord`
185 // (which is the start of the new canonical name).
186
187 aHostName.SetFromMessage(*mMessage, offset);
188 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
189
190 // Loop back to check if there may be a CNAME record for the
191 // new `aHostName`.
192 }
193
194 error = kErrorParse;
195
196 exit:
197 return error;
198 }
199
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const200 Error Client::Response::FindHostAddress(Section aSection,
201 const Name &aHostName,
202 uint16_t aIndex,
203 Ip6::Address &aAddress,
204 uint32_t &aTtl) const
205 {
206 Error error;
207 uint16_t offset;
208 uint16_t numRecords;
209 Name name = aHostName;
210 AaaaRecord aaaaRecord;
211
212 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
213
214 SelectSection(aSection, offset, numRecords);
215 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
216 aAddress = aaaaRecord.GetAddress();
217 aTtl = aaaaRecord.GetTtl();
218
219 exit:
220 return error;
221 }
222
223 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
224
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const225 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
226 {
227 Error error;
228 uint16_t offset;
229 uint16_t numRecords;
230 Name name = aHostName;
231
232 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
233
234 SelectSection(aSection, offset, numRecords);
235 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
236
237 exit:
238 return error;
239 }
240
241 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
242
243 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
244
InitServiceInfo(ServiceInfo & aServiceInfo) const245 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
246 {
247 // This method initializes `aServiceInfo` setting all
248 // TTLs to zero and host name to empty string.
249
250 aServiceInfo.mTtl = 0;
251 aServiceInfo.mHostAddressTtl = 0;
252 aServiceInfo.mTxtDataTtl = 0;
253 aServiceInfo.mTxtDataTruncated = false;
254
255 AsCoreType(&aServiceInfo.mHostAddress).Clear();
256
257 if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
258 {
259 aServiceInfo.mHostNameBuffer[0] = '\0';
260 }
261 }
262
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const263 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
264 {
265 // This method searches for SRV record in the given `aSection`
266 // matching the record name against `aName`, and updates the
267 // `aServiceInfo` accordingly. It also searches for AAAA record
268 // for host name associated with the service (from SRV record).
269 // The search for AAAA record is always performed in Additional
270 // Data section (independent of the value given in `aSection`).
271
272 Error error = kErrorNone;
273 uint16_t offset;
274 uint16_t numRecords;
275 Name hostName;
276 SrvRecord srvRecord;
277
278 // A non-zero `mTtl` indicates that SRV record is already found
279 // and parsed from a previous response.
280 VerifyOrExit(aServiceInfo.mTtl == 0);
281
282 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
283
284 // Search for a matching SRV record
285 SelectSection(aSection, offset, numRecords);
286 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
287
288 aServiceInfo.mTtl = srvRecord.GetTtl();
289 aServiceInfo.mPort = srvRecord.GetPort();
290 aServiceInfo.mPriority = srvRecord.GetPriority();
291 aServiceInfo.mWeight = srvRecord.GetWeight();
292
293 hostName.SetFromMessage(*mMessage, offset);
294
295 if (aServiceInfo.mHostNameBuffer != nullptr)
296 {
297 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
298 aServiceInfo.mHostNameBufferSize));
299 }
300 else
301 {
302 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
303 }
304
305 // Search in additional section for AAAA record for the host name.
306
307 VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
308
309 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
310 aServiceInfo.mHostAddressTtl);
311
312 if (error == kErrorNotFound)
313 {
314 error = kErrorNone;
315 }
316
317 exit:
318 return error;
319 }
320
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const321 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
322 {
323 // This method searches a TXT record in the given `aSection`
324 // matching the record name against `aName` and updates the TXT
325 // related properties in `aServicesInfo`.
326 //
327 // If no match is found `mTxtDataTtl` (which is initialized to zero)
328 // remains unchanged to indicate this. In this case this method still
329 // returns `kErrorNone`.
330
331 Error error = kErrorNone;
332 uint16_t offset;
333 uint16_t numRecords;
334 TxtRecord txtRecord;
335
336 // A non-zero `mTxtDataTtl` indicates that TXT record is already
337 // found and parsed from a previous response.
338 VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
339
340 // A null `mTxtData` indicates that caller does not want to retrieve
341 // TXT data.
342 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
343
344 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
345
346 SelectSection(aSection, offset, numRecords);
347
348 aServiceInfo.mTxtDataTruncated = false;
349
350 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
351
352 error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
353
354 if (error == kErrorNoBufs)
355 {
356 error = kErrorNone;
357
358 // Mark `mTxtDataTruncated` to indicate that we could not read
359 // the full TXT record into the given `mTxtData` buffer.
360 aServiceInfo.mTxtDataTruncated = true;
361 }
362
363 SuccessOrExit(error);
364 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
365
366 exit:
367 if (error == kErrorNotFound)
368 {
369 error = kErrorNone;
370 }
371
372 return error;
373 }
374
375 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
376
PopulateFrom(const Message & aMessage)377 void Client::Response::PopulateFrom(const Message &aMessage)
378 {
379 // Populate `Response` with info from `aMessage`.
380
381 uint16_t offset = aMessage.GetOffset();
382 Header header;
383
384 mMessage = &aMessage;
385
386 IgnoreError(aMessage.Read(offset, header));
387 offset += sizeof(Header);
388
389 for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
390 {
391 IgnoreError(Name::ParseName(aMessage, offset));
392 offset += sizeof(Question);
393 }
394
395 mAnswerOffset = offset;
396 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
397 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
398 mAdditionalOffset = offset;
399 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
400
401 mAnswerRecordCount = header.GetAnswerCount();
402 mAdditionalRecordCount = header.GetAdditionalRecordCount();
403 }
404
405 //---------------------------------------------------------------------------------------------------------------------
406 // Client::AddressResponse
407
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const408 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
409 {
410 Error error = kErrorNone;
411 Name name(*mQuery, kNameOffsetInQuery);
412
413 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
414
415 // If the response is for an IPv4 address query or if it is an
416 // IPv6 address query response with no IPv6 address but with
417 // an IPv4 in its additional section, we read the IPv4 address
418 // and translate it to an IPv6 address.
419
420 QueryInfo info;
421
422 info.ReadFrom(*mQuery);
423
424 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
425 {
426 Section section;
427 ARecord aRecord;
428 NetworkData::ExternalRouteConfig nat64Prefix;
429
430 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
431 error = kErrorInvalidState);
432
433 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
434 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
435
436 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
437 aTtl = aRecord.GetTtl();
438
439 ExitNow();
440 }
441
442 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
443
444 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
445
446 exit:
447 return error;
448 }
449
450 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
451
452 //---------------------------------------------------------------------------------------------------------------------
453 // Client::BrowseResponse
454
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const455 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
456 {
457 Error error;
458 uint16_t offset;
459 uint16_t numRecords;
460 Name serviceName(*mQuery, kNameOffsetInQuery);
461 PtrRecord ptrRecord;
462
463 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
464
465 SelectSection(kAnswerSection, offset, numRecords);
466 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
467 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
468
469 exit:
470 return error;
471 }
472
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const473 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
474 {
475 Error error;
476 Name instanceName;
477
478 // Find a matching PTR record for the service instance label. Then
479 // search and read SRV, TXT and AAAA records in Additional Data
480 // section matching the same name to populate `aServiceInfo`.
481
482 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
483
484 InitServiceInfo(aServiceInfo);
485 SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
486 SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
487
488 if (aServiceInfo.mTxtDataTtl == 0)
489 {
490 aServiceInfo.mTxtDataSize = 0;
491 }
492
493 exit:
494 return error;
495 }
496
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const497 Error Client::BrowseResponse::GetHostAddress(const char *aHostName,
498 uint16_t aIndex,
499 Ip6::Address &aAddress,
500 uint32_t &aTtl) const
501 {
502 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
503 }
504
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const505 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
506 {
507 // This method searches within the Answer Section for a PTR record
508 // matching a given instance label @aInstanceLabel. If found, the
509 // `aName` is updated to return the name in the message.
510
511 Error error;
512 uint16_t offset;
513 Name serviceName(*mQuery, kNameOffsetInQuery);
514 uint16_t numRecords;
515 uint16_t labelOffset;
516 PtrRecord ptrRecord;
517
518 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
519
520 SelectSection(kAnswerSection, offset, numRecords);
521
522 for (; numRecords > 0; numRecords--)
523 {
524 SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
525
526 error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
527
528 if (error == kErrorNotFound)
529 {
530 // `ReadRecord()` updates `offset` to skip over a
531 // non-matching record.
532 continue;
533 }
534
535 SuccessOrExit(error);
536
537 // It is a PTR record. Check the first label to match the
538 // instance label.
539
540 labelOffset = offset;
541 error = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
542
543 if (error == kErrorNone)
544 {
545 aInstanceName.SetFromMessage(*mMessage, offset);
546 ExitNow();
547 }
548
549 VerifyOrExit(error == kErrorNotFound);
550
551 // Update offset to skip over the PTR record.
552 offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
553 }
554
555 error = kErrorNotFound;
556
557 exit:
558 return error;
559 }
560
561 //---------------------------------------------------------------------------------------------------------------------
562 // Client::ServiceResponse
563
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const564 Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer,
565 uint8_t aLabelBufferSize,
566 char *aNameBuffer,
567 uint16_t aNameBufferSize) const
568 {
569 Error error;
570 uint16_t offset = kNameOffsetInQuery;
571
572 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
573
574 VerifyOrExit(aNameBuffer != nullptr);
575 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
576
577 exit:
578 return error;
579 }
580
GetServiceInfo(ServiceInfo & aServiceInfo) const581 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
582 {
583 // Search and read SRV, TXT records matching name from query.
584
585 Error error = kErrorNotFound;
586
587 InitServiceInfo(aServiceInfo);
588
589 for (const Response *response = this; response != nullptr; response = response->mNext)
590 {
591 Name name(*response->mQuery, kNameOffsetInQuery);
592 QueryInfo info;
593 Section srvSection;
594 Section txtSection;
595
596 info.ReadFrom(*response->mQuery);
597
598 switch (info.mQueryType)
599 {
600 case kIp6AddressQuery:
601 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
602 case kIp4AddressQuery:
603 #endif
604 IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
605 AsCoreType(&aServiceInfo.mHostAddress),
606 aServiceInfo.mHostAddressTtl));
607
608 continue; // to `for()` loop
609
610 case kServiceQuerySrvTxt:
611 case kServiceQuerySrv:
612 case kServiceQueryTxt:
613 break;
614
615 default:
616 continue;
617 }
618
619 // Determine from which section we should try to read the SRV and
620 // TXT records based on the query type.
621 //
622 // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
623 // only one record (SRV or TXT) in the answer section, but we
624 // still try to read the other records from additional data
625 // section in case server provided them.
626
627 srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
628 txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
629
630 error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
631
632 if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
633 {
634 error = kErrorNone;
635 }
636
637 SuccessOrExit(error);
638
639 SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
640 }
641
642 if (aServiceInfo.mTxtDataTtl == 0)
643 {
644 aServiceInfo.mTxtDataSize = 0;
645 }
646
647 exit:
648 return error;
649 }
650
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const651 Error Client::ServiceResponse::GetHostAddress(const char *aHostName,
652 uint16_t aIndex,
653 Ip6::Address &aAddress,
654 uint32_t &aTtl) const
655 {
656 Error error = kErrorNotFound;
657
658 for (const Response *response = this; response != nullptr; response = response->mNext)
659 {
660 Section section = kAdditionalDataSection;
661 QueryInfo info;
662
663 info.ReadFrom(*response->mQuery);
664
665 switch (info.mQueryType)
666 {
667 case kIp6AddressQuery:
668 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
669 case kIp4AddressQuery:
670 #endif
671 section = kAnswerSection;
672 break;
673
674 default:
675 break;
676 }
677
678 error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
679
680 if (error == kErrorNone)
681 {
682 break;
683 }
684 }
685
686 return error;
687 }
688
689 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
690
691 //---------------------------------------------------------------------------------------------------------------------
692 // Client
693
// Record types placed in the Question section for each kind of query.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
// SRV must stay at index 0 and TXT at index 1: the SRV-only and TXT-only
// entries of `kQuestionRecordTypes` point into this array.
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query type. Entries are positional, one per
// `QueryType` value (see the `/* kXxx -> */` markers), so the order must follow
// the `QueryType` enum order validated in the `Client` constructor.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes),         // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv -> */ 1,                                           // SRV record only
    /* kServiceQueryTxt -> */ 1,                                           // TXT record only
#endif
};

// Start of the question record-type list for each query type, positionally
// parallel to `kQuestionCount` (same `QueryType` order). The SRV-only and
// TXT-only entries reuse single elements of `kServiceQueryRecordTypes` with a
// count of 1.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt -> */ &kServiceQueryRecordTypes[1],

#endif
};
729
// Constructs the DNS client. The default query config starts from the
// compile-time defaults (`QueryConfig::kInitFromDefaults`).
Client::Client(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance, *this)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    , mTcpState(kTcpUninitialized)
#endif
    , mTimer(aInstance)
    , mDefaultConfig(QueryConfig::kInitFromDefaults)
#if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
    , mUserDidSetDefaultAddress(false)
#endif
{
    // Compile-time check, using the `utils/static_counter.hpp` helpers, that
    // the `QueryType` enumerators are sequential in the order listed here.
    // `kQuestionCount` and `kQuestionRecordTypes` are indexed by `QueryType`
    // and rely on this order. The order of the statements below matters.
    struct QueryTypeChecker
    {
        InitEnumValidatorCounter();

        ValidateNextEnum(kIp6AddressQuery);
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        ValidateNextEnum(kIp4AddressQuery);
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
        ValidateNextEnum(kBrowseQuery);
        ValidateNextEnum(kServiceQuerySrvTxt);
        ValidateNextEnum(kServiceQuerySrv);
        ValidateNextEnum(kServiceQueryTxt);
#endif
    };

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    ClearAllBytes(mSendLink);
#endif
}
762
Start(void)763 Error Client::Start(void)
764 {
765 Error error;
766
767 SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified));
768 SuccessOrExit(error = mSocket.Bind(0));
769
770 exit:
771 return error;
772 }
773
Stop(void)774 void Client::Stop(void)
775 {
776 Query *query;
777
778 while ((query = mMainQueries.GetHead()) != nullptr)
779 {
780 FinalizeQuery(*query, kErrorAbort);
781 }
782
783 IgnoreError(mSocket.Close());
784 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
785 if (mTcpState != kTcpUninitialized)
786 {
787 IgnoreError(mEndpoint.Deinitialize());
788 }
789 #endif
790
791 mLimitedQueryServers.Clear();
792 }
793
794 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)795 Error Client::InitTcpSocket(void)
796 {
797 Error error;
798 otTcpEndpointInitializeArgs endpointArgs;
799
800 ClearAllBytes(endpointArgs);
801 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
802 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
803 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
804 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
805 endpointArgs.mContext = this;
806 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
807 endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes);
808
809 mSendLink.mNext = nullptr;
810 mSendLink.mData = mSendBufferBytes;
811 mSendLink.mLength = 0;
812
813 SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
814 exit:
815 return error;
816 }
817 #endif
818
SetDefaultConfig(const QueryConfig & aQueryConfig)819 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
820 {
821 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
822
823 mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
824
825 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
826 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
827 UpdateDefaultConfigAddress();
828 #endif
829 }
830
ResetDefaultConfig(void)831 void Client::ResetDefaultConfig(void)
832 {
833 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
834
835 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
836 mUserDidSetDefaultAddress = false;
837 UpdateDefaultConfigAddress();
838 #endif
839 }
840
841 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)842 void Client::UpdateDefaultConfigAddress(void)
843 {
844 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
845
846 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
847 !srpServerAddr.IsUnspecified())
848 {
849 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
850 }
851 }
852 #endif
853
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)854 Error Client::ResolveAddress(const char *aHostName,
855 AddressCallback aCallback,
856 void *aContext,
857 const QueryConfig *aConfig)
858 {
859 QueryInfo info;
860
861 info.Clear();
862 info.mQueryType = kIp6AddressQuery;
863 info.mConfig.SetFrom(aConfig, mDefaultConfig);
864 info.mCallback.mAddressCallback = aCallback;
865 info.mCallbackContext = aContext;
866
867 return StartQuery(info, nullptr, aHostName);
868 }
869
870 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)871 Error Client::ResolveIp4Address(const char *aHostName,
872 AddressCallback aCallback,
873 void *aContext,
874 const QueryConfig *aConfig)
875 {
876 QueryInfo info;
877
878 info.Clear();
879 info.mQueryType = kIp4AddressQuery;
880 info.mConfig.SetFrom(aConfig, mDefaultConfig);
881 info.mCallback.mAddressCallback = aCallback;
882 info.mCallbackContext = aContext;
883
884 return StartQuery(info, nullptr, aHostName);
885 }
886 #endif
887
888 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
889
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)890 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
891 {
892 QueryInfo info;
893
894 info.Clear();
895 info.mQueryType = kBrowseQuery;
896 info.mConfig.SetFrom(aConfig, mDefaultConfig);
897 info.mCallback.mBrowseCallback = aCallback;
898 info.mCallbackContext = aContext;
899
900 return StartQuery(info, nullptr, aServiceName);
901 }
902
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)903 Error Client::ResolveService(const char *aInstanceLabel,
904 const char *aServiceName,
905 ServiceCallback aCallback,
906 void *aContext,
907 const QueryConfig *aConfig)
908 {
909 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
910 }
911
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)912 Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel,
913 const char *aServiceName,
914 ServiceCallback aCallback,
915 void *aContext,
916 const QueryConfig *aConfig)
917 {
918 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
919 }
920
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)921 Error Client::Resolve(const char *aInstanceLabel,
922 const char *aServiceName,
923 ServiceCallback aCallback,
924 void *aContext,
925 const QueryConfig *aConfig,
926 bool aShouldResolveHostAddr)
927 {
928 QueryInfo info;
929 Error error;
930 QueryType secondQueryType = kNoQuery;
931
932 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
933
934 info.Clear();
935
936 info.mConfig.SetFrom(aConfig, mDefaultConfig);
937 info.mShouldResolveHostAddr = aShouldResolveHostAddr;
938
939 CheckAndUpdateServiceMode(info.mConfig, aConfig);
940
941 switch (info.mConfig.GetServiceMode())
942 {
943 case QueryConfig::kServiceModeSrvTxtSeparate:
944 secondQueryType = kServiceQueryTxt;
945
946 OT_FALL_THROUGH;
947
948 case QueryConfig::kServiceModeSrv:
949 info.mQueryType = kServiceQuerySrv;
950 break;
951
952 case QueryConfig::kServiceModeTxt:
953 info.mQueryType = kServiceQueryTxt;
954 VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
955 break;
956
957 case QueryConfig::kServiceModeSrvTxt:
958 case QueryConfig::kServiceModeSrvTxtOptimize:
959 default:
960 info.mQueryType = kServiceQuerySrvTxt;
961 break;
962 }
963
964 info.mCallback.mServiceCallback = aCallback;
965 info.mCallbackContext = aContext;
966
967 error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
968
969 exit:
970 return error;
971 }
972
973 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
974
Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
{
    // Starts a new main query described by `aInfo`. It validates the
    // request, allocates the query (info followed by the encoded
    // name), adds it to `mMainQueries`, and sends it. If `aSecondType`
    // is not `kNoQuery`, a second query of that type for the same name
    // is then allocated, sent, and chained after the main query (e.g.
    // separate SRV/TXT resolution requested by `Resolve()`).
    //
    // The `aLabel` can be `nullptr` and then `aName` provides the
    // full name, otherwise the name is appended as `{aLabel}.
    // {aName}`.
    //
    // `aInfo` is modified by this method. The returned error reflects
    // only the main query. A failure to allocate or send the second
    // query is not reported.

    Error error;
    Query *query;

    // The UDP socket must be bound. This is checked regardless of the
    // configured transport.
    VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    if (aInfo.mQueryType == kIp4AddressQuery)
    {
        NetworkData::ExternalRouteConfig nat64Prefix;

        // IPv4 resolution needs two things: NAT64 must be allowed by the
        // config, and a preferred NAT64 prefix must exist in Network Data.
        // The prefix value itself is not used here.
        VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
        VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
                     error = kErrorInvalidState);
    }
#endif

    SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));

    mMainQueries.Enqueue(*query);

    // If the first send fails, the query is dequeued and freed again.
    error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
    VerifyOrExit(error == kErrorNone, FreeQuery(*query));

    if (aSecondType != kNoQuery)
    {
        Query *secondQuery;

        // The second query reuses the same info with these changes:
        //  - its own query type;
        //  - a fresh message ID (0 makes `SendQuery()` pick a new one);
        //  - a reset transmission count;
        //  - a link back to the main query.
        aInfo.mQueryType = aSecondType;
        aInfo.mMessageId = 0;
        aInfo.mTransmissionCount = 0;
        aInfo.mMainQuery = query;

        // We intentionally do not use `error` here so in the unlikely
        // case where we cannot allocate the second query we can proceed
        // with the first one.
        SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));

        IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));

        // Update first query to link to second one by updating
        // its `mNextQuery`.
        aInfo.ReadFrom(*query);
        aInfo.mNextQuery = secondQuery;
        UpdateQuery(*query, aInfo);
    }

exit:
    return error;
}
1030
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1031 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1032 {
1033 Error error = kErrorNone;
1034
1035 aQuery = nullptr;
1036
1037 VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1038
1039 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1040 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1041
1042 SuccessOrExit(error = aQuery->Append(aInfo));
1043
1044 if (aLabel != nullptr)
1045 {
1046 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1047 }
1048
1049 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1050
1051 exit:
1052 FreeAndNullMessageOnError(aQuery, error);
1053 return error;
1054 }
1055
FindMainQuery(Query & aQuery)1056 Client::Query &Client::FindMainQuery(Query &aQuery)
1057 {
1058 QueryInfo info;
1059
1060 info.ReadFrom(aQuery);
1061
1062 return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1063 }
1064
FreeQuery(Query & aQuery)1065 void Client::FreeQuery(Query &aQuery)
1066 {
1067 Query &mainQuery = FindMainQuery(aQuery);
1068 QueryInfo info;
1069
1070 mMainQueries.Dequeue(mainQuery);
1071
1072 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1073 {
1074 info.ReadFrom(*query);
1075 FreeMessage(info.mSavedResponse);
1076 query->Free();
1077 }
1078 }
1079
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // The transmission count and retransmission time are updated, and
    // `aInfo` is written back to `aQuery`, even when sending fails. A
    // failed attempt therefore still counts towards the max TX attempts
    // and is retried from the timer handler.
    //
    // Transport handling:
    //  - UDP: the message is sent to the configured server socket
    //    address.
    //  - TCP: the DNS message is copied, with a 2-byte length prefix,
    //    into the TCP send buffer and the local `message` is freed.
    //    Depending on `mTcpState`, the connection is started, a send is
    //    started, or the pending send is extended.

    Error error = kErrorNone;
    Message *message = nullptr;
    Header header;
    Ip6::MessageInfo messageInfo;
    uint16_t length = 0;

    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // Pick a random, non-zero message ID that is not used by any
        // outstanding query. Zero means "no ID assigned yet".
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Retransmission: reuse the existing message ID.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section. Every question carries the same
    // name (stored in `aQuery`) with a record type that depends on the
    // query type.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    // Size of the DNS message (header and questions).
    length = message->GetLength() - message->GetOffset();

    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // Open the connection and buffer the query. The buffer is sent
            // once the connection is established (`HandleTcpEstablished()`).
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Buffer the query and start sending right away.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Only buffer the query. It is sent when the connection is
            // established.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send is already in progress. Append the length-prefixed
            // message after the pending data and extend the ongoing send.
            // NOTE(review): unlike `PrepareTcpMessage()`, this does not
            // increase `mSendLink.mLength`. Presumably `SendByExtension()`
            // grows the last linked buffer; confirm.
            BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }
        // The bytes were copied into the TCP send buffer, so the
        // message itself is no longer needed.
        message->Free();
        message = nullptr;
    }
#else
    {
        // TCP transport requested but not compiled in.
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        // UDP queries are limited to `kUdpQueryMaxSize` bytes.
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    // Frees `message` on error. On TCP success it was already freed and
    // set to `nullptr`.
    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    UpdateQuery(aQuery, aInfo);

    return error;
}
1206
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1207 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1208 {
1209 // The name is encoded and included after the `Info` in `aQuery`
1210 // starting at `kNameOffsetInQuery`.
1211
1212 return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1213 }
1214
FinalizeQuery(Query & aQuery,Error aError)1215 void Client::FinalizeQuery(Query &aQuery, Error aError)
1216 {
1217 Response response;
1218 Query &mainQuery = FindMainQuery(aQuery);
1219
1220 response.mInstance = &Get<Instance>();
1221 response.mQuery = &mainQuery;
1222
1223 FinalizeQuery(response, aError);
1224 }
1225
FinalizeQuery(Response & aResponse,Error aError)1226 void Client::FinalizeQuery(Response &aResponse, Error aError)
1227 {
1228 QueryType type;
1229 Callback callback;
1230 void *context;
1231
1232 GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1233
1234 switch (type)
1235 {
1236 case kIp6AddressQuery:
1237 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1238 case kIp4AddressQuery:
1239 #endif
1240 if (callback.mAddressCallback != nullptr)
1241 {
1242 callback.mAddressCallback(aError, &aResponse, context);
1243 }
1244 break;
1245
1246 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1247 case kBrowseQuery:
1248 if (callback.mBrowseCallback != nullptr)
1249 {
1250 callback.mBrowseCallback(aError, &aResponse, context);
1251 }
1252 break;
1253
1254 case kServiceQuerySrvTxt:
1255 case kServiceQuerySrv:
1256 case kServiceQueryTxt:
1257 if (callback.mServiceCallback != nullptr)
1258 {
1259 callback.mServiceCallback(aError, &aResponse, context);
1260 }
1261 break;
1262 #endif
1263 case kNoQuery:
1264 break;
1265 }
1266
1267 FreeQuery(*aResponse.mQuery);
1268 }
1269
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1270 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1271 {
1272 QueryInfo info;
1273
1274 info.ReadFrom(aQuery);
1275
1276 aType = info.mQueryType;
1277 aCallback = info.mCallback;
1278 aContext = info.mCallbackContext;
1279 }
1280
FindQueryById(uint16_t aMessageId)1281 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1282 {
1283 Query *matchedQuery = nullptr;
1284 QueryInfo info;
1285
1286 for (Query &mainQuery : mMainQueries)
1287 {
1288 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1289 {
1290 info.ReadFrom(*query);
1291
1292 if (info.mMessageId == aMessageId)
1293 {
1294 matchedQuery = query;
1295 ExitNow();
1296 }
1297 }
1298 }
1299
1300 exit:
1301 return matchedQuery;
1302 }
1303
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMsgInfo)1304 void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMsgInfo)
1305 {
1306 OT_UNUSED_VARIABLE(aMsgInfo);
1307 ProcessResponse(aMessage);
1308 }
1309
void Client::ProcessResponse(const Message &aResponseMessage)
{
    // Handles a DNS response message received over UDP or TCP:
    //  1. Validate it and match it to an outstanding query. The message
    //     is ignored if `ParseResponse()` fails.
    //  2. (NAT64) An IPv6 address query may be replaced by an IPv4
    //     address query. In that case nothing else is done.
    //  3. If the server returned an error, either switch to separate
    //     SRV/TXT queries or finalize the query chain with the error.
    //  4. On success, start a host address query if needed. Then either
    //     save the response (other chained queries are still pending)
    //     or build the `Response` chain and finalize.

    Error responseError;
    Query *query;

    SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    // This is checked before `responseError` because the IPv4 fallback
    // also applies to some error response codes.
    if (ReplaceWithIp4Query(*query, aResponseMessage) == kErrorNone)
    {
        ExitNow();
    }
#endif

    if (responseError != kErrorNone)
    {
        // Received an error from server, check if we can replace
        // the query.

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
        if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
        {
            ExitNow();
        }
#endif

        FinalizeQuery(*query, responseError);
        ExitNow();
    }

    // Received successful response from server.

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    // May chain a new AAAA query after `query`. That makes
    // `CanFinalizeQuery()` below return false until the new query is
    // answered.
    ResolveHostAddressIfNeeded(*query, aResponseMessage);
#endif

    if (!CanFinalizeQuery(*query))
    {
        // Keep a copy of this response until all related queries are
        // answered.
        SaveQueryResponse(*query, aResponseMessage);
        ExitNow();
    }

    PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);

exit:
    return;
}
1357
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1358 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1359 {
1360 Error error = kErrorNone;
1361 uint16_t offset = aResponseMessage.GetOffset();
1362 Header header;
1363 QueryInfo info;
1364 Name queryName;
1365
1366 SuccessOrExit(error = aResponseMessage.Read(offset, header));
1367 offset += sizeof(Header);
1368
1369 VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1370 !header.IsTruncationFlagSet(),
1371 error = kErrorDrop);
1372
1373 aQuery = FindQueryById(header.GetMessageId());
1374 VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1375
1376 info.ReadFrom(*aQuery);
1377
1378 queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1379
1380 // Check the Question Section
1381
1382 if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1383 {
1384 for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1385 {
1386 SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1387 offset += sizeof(Question);
1388 }
1389 }
1390 else
1391 {
1392 VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1393 error = kErrorParse);
1394 }
1395
1396 // Check the answer, authority and additional record sections
1397
1398 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1399 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1400 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1401
1402 // Read the response code
1403
1404 aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1405
1406 if ((aResponseError == kErrorNone) && (info.mQueryType == kServiceQuerySrvTxt))
1407 {
1408 RecordServerAsCapableOfMultiQuestions(info.mConfig.GetServerSockAddr().GetAddress());
1409 }
1410
1411 exit:
1412 return error;
1413 }
1414
CanFinalizeQuery(Query & aQuery)1415 bool Client::CanFinalizeQuery(Query &aQuery)
1416 {
1417 // Determines whether we can finalize a main query by checking if
1418 // we have received and saved responses for all other related
1419 // queries associated with `aQuery`. Note that this method is
1420 // called when we receive a response for `aQuery`, so no need to
1421 // check for a saved response for `aQuery` itself.
1422
1423 bool canFinalize = true;
1424 QueryInfo info;
1425
1426 for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1427 {
1428 info.ReadFrom(*query);
1429
1430 if (query == &aQuery)
1431 {
1432 continue;
1433 }
1434
1435 if (info.mSavedResponse == nullptr)
1436 {
1437 canFinalize = false;
1438 ExitNow();
1439 }
1440 }
1441
1442 exit:
1443 return canFinalize;
1444 }
1445
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1446 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1447 {
1448 QueryInfo info;
1449
1450 info.ReadFrom(aQuery);
1451 VerifyOrExit(info.mSavedResponse == nullptr);
1452
1453 // If `Clone()` fails we let retry or timeout handle the error.
1454 info.mSavedResponse = aResponseMessage.Clone();
1455
1456 UpdateQuery(aQuery, info);
1457
1458 exit:
1459 return;
1460 }
1461
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1462 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1463 {
1464 // Populate `aResponse` for `aQuery`. If there is a saved response
1465 // message for `aQuery` we use it, otherwise, we use
1466 // `aResponseMessage`.
1467
1468 QueryInfo info;
1469
1470 info.ReadFrom(aQuery);
1471
1472 aResponse.mInstance = &Get<Instance>();
1473 aResponse.mQuery = &aQuery;
1474 aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1475
1476 return info.mNextQuery;
1477 }
1478
void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
{
    // This method prepares a list of chained `Response` instances
    // corresponding to all related (chained) queries. It uses
    // recursion to go through the queries and construct the
    // `Response` chain.
    //
    // Each `Response` lives in the stack frame of its own call. This is
    // why recursion is used instead of a loop: all of them must remain
    // alive until the innermost frame calls `FinalizeQuery()`. The list
    // is built in reverse. The `Response` passed to `FinalizeQuery()`
    // belongs to the last query in the chain, and its `mNext` leads
    // back towards the main query's `Response`. The recursion depth
    // equals the number of queries in the chain.

    Response response;
    Query *nextQuery;

    nextQuery = PopulateResponse(response, aQuery, aResponseMessage);
    response.mNext = aPrevResponse;

    if (nextQuery != nullptr)
    {
        PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
    }
    else
    {
        FinalizeQuery(response, kErrorNone);
    }
}
1501
void Client::HandleTimer(void)
{
    // Retransmission/timeout timer handler. It walks every query of
    // every chain. For each query without a saved response whose
    // retransmission time has passed, it does one of the following:
    //  - finalizes the whole chain with `kErrorResponseTimeout` once
    //    `GetMaxTxAttempts()` transmissions were made;
    //  - (service discovery) replaces an "optimize" SRV+TXT query with
    //    separate SRV and TXT queries;
    //  - otherwise retransmits the query.
    // The timer is then re-armed for the earliest pending
    // retransmission time. With TCP support, if none of the examined
    // queries uses TCP while a TCP connection exists, end-of-stream is
    // sent on it.

    NextFireTime nextTime;
    QueryInfo info;
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    bool hasTcpQuery = false;
#endif

    for (Query &mainQuery : mMainQueries)
    {
        for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
        {
            info.ReadFrom(*query);

            // This query was already answered; it is only waiting for
            // the other queries in its chain.
            if (info.mSavedResponse != nullptr)
            {
                continue;
            }

            if (nextTime.GetNow() >= info.mRetransmissionTime)
            {
                if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
                {
                    // This frees the whole chain, including `query`, so
                    // stop walking it.
                    // NOTE(review): the outer range-for relies on the
                    // `mMainQueries` iterator tolerating removal of the
                    // current entry; confirm.
                    FinalizeQuery(*query, kErrorResponseTimeout);
                    break;
                }

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
                if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
                {
                    LogInfo("Switching to separate SRV/TXT on response timeout");
                    // Re-read the updated info (new retransmission time and
                    // the `mNextQuery` link to the new TXT query).
                    info.ReadFrom(*query);
                }
                else
#endif
                {
                    // The timer is re-armed once after the loop, so
                    // `SendQuery()` must not update it.
                    IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
                }
            }

            nextTime.UpdateIfEarlier(info.mRetransmissionTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
            if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
            {
                hasTcpQuery = true;
            }
#endif
        }
    }

    mTimer.FireAtIfEarlier(nextTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    if (!hasTcpQuery && mTcpState != kTcpUninitialized)
    {
        IgnoreError(mEndpoint.SendEndOfStream());
    }
#endif
}
1562
1563 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1564
ReplaceWithIp4Query(Query & aQuery,const Message & aResponseMessage)1565 Error Client::ReplaceWithIp4Query(Query &aQuery, const Message &aResponseMessage)
1566 {
1567 Error error = kErrorFailed;
1568 QueryInfo info;
1569 Header header;
1570
1571 info.ReadFrom(aQuery);
1572
1573 VerifyOrExit(info.mQueryType == kIp6AddressQuery);
1574 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1575
1576 // Check the response to the IPv6 query from the server. If the
1577 // response code is success but the answer section is empty
1578 // (indicating the name exists but has no IPv6 address), or the
1579 // response code indicates an error other than `NameError`, we
1580 // replace the query with an IPv4 address resolution query for
1581 // the same name. If the server responded with `NameError`
1582 // (RCode=3), it indicates that the name doesn't exist, so there
1583 // is no need to try an IPv4 query.
1584
1585 SuccessOrExit(aResponseMessage.Read(aResponseMessage.GetOffset(), header));
1586
1587 switch (header.GetResponseCode())
1588 {
1589 case Header::kResponseSuccess:
1590 VerifyOrExit(header.GetAnswerCount() == 0);
1591 OT_FALL_THROUGH;
1592
1593 default:
1594 break;
1595
1596 case Header::kResponseNameError:
1597 ExitNow();
1598 }
1599
1600 // We send a new query for IPv4 address resolution
1601 // for the same host name. We reuse the existing `aQuery`
1602 // instance and keep all the info but clear `mTransmissionCount`
1603 // and `mMessageId` (so that a new random message ID is
1604 // selected). The new `info` will be saved in the query in
1605 // `SendQuery()`. Note that the current query is still in the
1606 // `mMainQueries` list when `SendQuery()` selects a new random
1607 // message ID, so the existing message ID for this query will
1608 // not be reused.
1609
1610 info.mQueryType = kIp4AddressQuery;
1611 info.mMessageId = 0;
1612 info.mTransmissionCount = 0;
1613
1614 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1615 error = kErrorNone;
1616
1617 exit:
1618 return error;
1619 }
1620
1621 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1622
1623 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1624
CheckAndUpdateServiceMode(QueryConfig & aConfig,const QueryConfig * aRequestConfig) const1625 void Client::CheckAndUpdateServiceMode(QueryConfig &aConfig, const QueryConfig *aRequestConfig) const
1626 {
1627 // If the user explicitly requested "optimize" mode, we honor that
1628 // request. Otherwise, if "optimize" is chosen from the default
1629 // config, we check if the DNS server is known to have trouble
1630 // with multiple-question queries. If so, we switch to "separate"
1631 // mode.
1632
1633 if ((aRequestConfig != nullptr) && (aRequestConfig->GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize))
1634 {
1635 ExitNow();
1636 }
1637
1638 VerifyOrExit(aConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1639
1640 if (mLimitedQueryServers.Contains(aConfig.GetServerSockAddr().GetAddress()))
1641 {
1642 aConfig.SetServiceMode(QueryConfig::kServiceModeSrvTxtSeparate);
1643 }
1644
1645 exit:
1646 return;
1647 }
1648
RecordServerAsLimitedToSingleQuestion(const Ip6::Address & aServerAddress)1649 void Client::RecordServerAsLimitedToSingleQuestion(const Ip6::Address &aServerAddress)
1650 {
1651 VerifyOrExit(!aServerAddress.IsUnspecified());
1652
1653 VerifyOrExit(!mLimitedQueryServers.Contains(aServerAddress));
1654
1655 if (mLimitedQueryServers.IsFull())
1656 {
1657 uint8_t randomIndex = Random::NonCrypto::GetUint8InRange(0, mLimitedQueryServers.GetMaxSize());
1658
1659 mLimitedQueryServers.Remove(mLimitedQueryServers[randomIndex]);
1660 }
1661
1662 IgnoreError(mLimitedQueryServers.PushBack(aServerAddress));
1663
1664 exit:
1665 return;
1666 }
1667
RecordServerAsCapableOfMultiQuestions(const Ip6::Address & aServerAddress)1668 void Client::RecordServerAsCapableOfMultiQuestions(const Ip6::Address &aServerAddress)
1669 {
1670 Ip6::Address *entry = mLimitedQueryServers.Find(aServerAddress);
1671
1672 VerifyOrExit(entry != nullptr);
1673 mLimitedQueryServers.Remove(*entry);
1674
1675 exit:
1676 return;
1677 }
1678
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1679 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1680 {
1681 Error error = kErrorFailed;
1682 QueryInfo info;
1683 Query *secondQuery;
1684
1685 info.ReadFrom(aQuery);
1686
1687 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1688 VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1689
1690 RecordServerAsLimitedToSingleQuestion(info.mConfig.GetServerSockAddr().GetAddress());
1691
1692 secondQuery = aQuery.Clone();
1693 VerifyOrExit(secondQuery != nullptr);
1694
1695 info.mQueryType = kServiceQueryTxt;
1696 info.mMessageId = 0;
1697 info.mTransmissionCount = 0;
1698 info.mMainQuery = &aQuery;
1699 IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1700
1701 info.mQueryType = kServiceQuerySrv;
1702 info.mMessageId = 0;
1703 info.mTransmissionCount = 0;
1704 info.mNextQuery = secondQuery;
1705 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1706 error = kErrorNone;
1707
1708 exit:
1709 return error;
1710 }
1711
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1712 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1713 {
1714 QueryInfo info;
1715 Response response;
1716 ServiceInfo serviceInfo;
1717 char hostName[Name::kMaxNameSize];
1718
1719 info.ReadFrom(aQuery);
1720
1721 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1722 VerifyOrExit(info.mShouldResolveHostAddr);
1723
1724 PopulateResponse(response, aQuery, aResponseMessage);
1725
1726 ClearAllBytes(serviceInfo);
1727 serviceInfo.mHostNameBuffer = hostName;
1728 serviceInfo.mHostNameBufferSize = sizeof(hostName);
1729 SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1730
1731 // Check whether AAAA record for host address is provided in the SRV query response
1732
1733 if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1734 {
1735 Query *newQuery;
1736
1737 info.mQueryType = kIp6AddressQuery;
1738 info.mMessageId = 0;
1739 info.mTransmissionCount = 0;
1740 info.mMainQuery = &FindMainQuery(aQuery);
1741
1742 SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1743 IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1744
1745 // Update `aQuery` to be linked with new query (inserting
1746 // the `newQuery` into the linked-list after `aQuery`).
1747
1748 info.ReadFrom(aQuery);
1749 info.mNextQuery = newQuery;
1750 UpdateQuery(aQuery, info);
1751 }
1752
1753 exit:
1754 return;
1755 }
1756
1757 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1758
1759 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1760 void Client::PrepareTcpMessage(Message &aMessage)
1761 {
1762 uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1763
1764 // Prepending the DNS query with length of the packet according to RFC1035.
1765 BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1766 SuccessOrAssert(
1767 aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1768 mSendLink.mLength += length + sizeof(uint16_t);
1769 }
1770
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1771 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1772 {
1773 OT_UNUSED_VARIABLE(aEndpoint);
1774 OT_UNUSED_VARIABLE(aData);
1775 OT_ASSERT(mTcpState == kTcpConnectedSending);
1776
1777 mSendLink.mLength = 0;
1778 mTcpState = kTcpConnectedIdle;
1779 }
1780
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1781 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1782 {
1783 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1784 }
1785
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1786 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1787 {
1788 OT_UNUSED_VARIABLE(aEndpoint);
1789 IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1790 mTcpState = kTcpConnectedSending;
1791 }
1792
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1793 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1794 {
1795 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1796 }
1797
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1798 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1799 size_t &aOffset,
1800 Message &aMessage,
1801 uint16_t aLength)
1802 {
1803 // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1804 // and copy the content into `aMessage`. As we read we can move
1805 // to the next `aLinkedBuffer` and update `aOffset`.
1806 // Returns:
1807 // - `kErrorNone` if `aLength` bytes are successfully read and
1808 // `aOffset` and `aLinkedBuffer` are updated.
1809 // - `kErrorNotFound` is not enough bytes available to read
1810 // from `aLinkedBuffer`.
1811 // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1812
1813 Error error = kErrorNone;
1814
1815 while (aLength > 0)
1816 {
1817 uint16_t bytesToRead = aLength;
1818
1819 VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1820
1821 if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1822 {
1823 bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1824 }
1825
1826 SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1827
1828 aLength -= bytesToRead;
1829 aOffset += bytesToRead;
1830
1831 if (aOffset == aLinkedBuffer->mLength)
1832 {
1833 aLinkedBuffer = aLinkedBuffer->mNext;
1834 aOffset = 0;
1835 }
1836 }
1837
1838 exit:
1839 return error;
1840 }
1841
void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
                                       size_t aBytesAvailable,
                                       bool aEndOfStream,
                                       size_t aBytesRemaining)
{
    // Called by the TCP endpoint when received data is available. Each
    // DNS message over TCP is prefixed by a 2-byte big-endian length
    // (RFC 1035). This method extracts as many complete messages as are
    // available and processes each one with `ProcessResponse()`. Only
    // the bytes of fully processed messages are committed, so a
    // partially received message stays in the receive buffer.

    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aBytesRemaining);

    Message *message = nullptr;
    size_t totalRead = 0;
    size_t offset = 0;
    const otLinkedBuffer *data;

    if (aEndOfStream)
    {
        // The peer closed its side; close ours too. Data already
        // received is still processed below.
        // Cleanup is done in disconnected callback.
        IgnoreError(mEndpoint.SendEndOfStream());
    }

    SuccessOrExit(mEndpoint.ReceiveByReference(data));
    VerifyOrExit(data != nullptr);

    // One scratch message is reused for every response in this batch.
    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr);

    while (aBytesAvailable > totalRead)
    {
        uint16_t length;

        // Read the `length` field.
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));

        IgnoreError(message->Read(/* aOffset */ 0, length));
        length = BigEndian::HostSwap16(length);

        // Try to read `length` bytes. If the full message is not there
        // yet, exit without counting it in `totalRead`.
        IgnoreError(message->SetLength(0));
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));

        totalRead += length + sizeof(uint16_t);

        // Now process the read message as query response.
        ProcessResponse(*message);

        IgnoreError(message->SetLength(0));

        // Loop again to see if we can read another response.
    }

exit:
    // Inform `mEndPoint` about the total read and processed bytes
    IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
    FreeMessage(message);
}
1896
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1897 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1898 size_t aBytesAvailable,
1899 bool aEndOfStream,
1900 size_t aBytesRemaining)
1901 {
1902 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1903 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1904 }
1905
void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
{
    // Called when the TCP connection is closed or fails. It
    // deinitializes the endpoint, returns to `kTcpUninitialized`, and
    // finalizes every main query that uses TCP transport with
    // `kErrorAbort`. Finalizing invokes the user callback and frees the
    // query chain.
    //
    // NOTE(review): `mSendLink.mLength` is not reset here. Presumably
    // `InitTcpSocket()` resets the send buffer before the next
    // connection; confirm.

    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aReason);
    QueryInfo info;

    IgnoreError(mEndpoint.Deinitialize());
    mTcpState = kTcpUninitialized;

    // Abort queries in case of connection failures
    // NOTE(review): `FinalizeQuery()` dequeues and frees `mainQuery`
    // while `mMainQueries` is being iterated. This relies on the queue
    // iterator tolerating removal of the current entry; confirm.
    for (Query &mainQuery : mMainQueries)
    {
        info.ReadFrom(mainQuery);

        if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
        {
            FinalizeQuery(mainQuery, kErrorAbort);
        }
    }
}
1926
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1927 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1928 {
1929 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1930 }
1931
1932 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1933
1934 } // namespace Dns
1935 } // namespace ot
1936
1937 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1938