1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43
44 /**
45 * @file
46 * This file implements the DNS client.
47 */
48
49 namespace ot {
50 namespace Dns {
51
52 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
53 using ot::Encoding::BigEndian::ReadUint16;
54 using ot::Encoding::BigEndian::WriteUint16;
55 #endif
56
// Registers "DnsClient" as the log module name for this file
// (presumably the tag used by the logging macros - confirm against
// `common/log.hpp`).
RegisterLogModule("DnsClient");
58
59 //---------------------------------------------------------------------------------------------------------------------
60 // Client::QueryConfig
61
// Textual IPv6 address of the default DNS server, taken from the build
// configuration. It is parsed by the `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
63
QueryConfig(InitMode aMode)64 Client::QueryConfig::QueryConfig(InitMode aMode)
65 {
66 OT_UNUSED_VARIABLE(aMode);
67
68 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
69 GetServerSockAddr().SetPort(kDefaultServerPort);
70 SetResponseTimeout(kDefaultResponseTimeout);
71 SetMaxTxAttempts(kDefaultMaxTxAttempts);
72 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
73 SetServiceMode(kDefaultServiceMode);
74 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
75 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
76 #endif
77 SetTransportProto(kDnsTransportUdp);
78 }
79
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)80 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
81 {
82 // This method sets the config from `aConfig` replacing any
83 // unspecified fields (value zero) with the fields from
84 // `aDefaultConfig`. If `aConfig` is `nullptr` then
85 // `aDefaultConfig` is used.
86
87 if (aConfig == nullptr)
88 {
89 *this = aDefaultConfig;
90 ExitNow();
91 }
92
93 *this = *aConfig;
94
95 if (GetServerSockAddr().GetAddress().IsUnspecified())
96 {
97 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
98 }
99
100 if (GetServerSockAddr().GetPort() == 0)
101 {
102 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
103 }
104
105 if (GetResponseTimeout() == 0)
106 {
107 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
108 }
109
110 if (GetMaxTxAttempts() == 0)
111 {
112 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
113 }
114
115 if (GetRecursionFlag() == kFlagUnspecified)
116 {
117 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
118 }
119
120 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
121 if (GetNat64Mode() == kNat64Unspecified)
122 {
123 SetNat64Mode(aDefaultConfig.GetNat64Mode());
124 }
125 #endif
126
127 if (GetServiceMode() == kServiceModeUnspecified)
128 {
129 SetServiceMode(aDefaultConfig.GetServiceMode());
130 }
131
132 if (GetTransportProto() == kDnsTransportUnspecified)
133 {
134 SetTransportProto(aDefaultConfig.GetTransportProto());
135 }
136
137 exit:
138 return;
139 }
140
141 //---------------------------------------------------------------------------------------------------------------------
142 // Client::Response
143
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const144 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
145 {
146 switch (aSection)
147 {
148 case kAnswerSection:
149 aOffset = mAnswerOffset;
150 aNumRecord = mAnswerRecordCount;
151 break;
152 case kAdditionalDataSection:
153 default:
154 aOffset = mAdditionalOffset;
155 aNumRecord = mAdditionalRecordCount;
156 break;
157 }
158 }
159
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const160 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
161 {
162 uint16_t offset = kNameOffsetInQuery;
163
164 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
165 }
166
CheckForHostNameAlias(Section aSection,Name & aHostName) const167 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
168 {
169 // If the response includes a CNAME record mapping the query host
170 // name to a canonical name, we update `aHostName` to the new alias
171 // name. Otherwise `aHostName` remains as before.
172
173 Error error;
174 uint16_t offset;
175 uint16_t numRecords;
176 CnameRecord cnameRecord;
177
178 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
179
180 SelectSection(aSection, offset, numRecords);
181 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182
183 switch (error)
184 {
185 case kErrorNone:
186 // A CNAME record was found. `offset` now points to after the
187 // last read byte within the `mMessage` into the `cnameRecord`
188 // (which is the start of the new canonical name).
189 aHostName.SetFromMessage(*mMessage, offset);
190 error = Name::ParseName(*mMessage, offset);
191 break;
192
193 case kErrorNotFound:
194 error = kErrorNone;
195 break;
196
197 default:
198 break;
199 }
200
201 exit:
202 return error;
203 }
204
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const205 Error Client::Response::FindHostAddress(Section aSection,
206 const Name &aHostName,
207 uint16_t aIndex,
208 Ip6::Address &aAddress,
209 uint32_t &aTtl) const
210 {
211 Error error;
212 uint16_t offset;
213 uint16_t numRecords;
214 Name name = aHostName;
215 AaaaRecord aaaaRecord;
216
217 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
218
219 SelectSection(aSection, offset, numRecords);
220 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
221 aAddress = aaaaRecord.GetAddress();
222 aTtl = aaaaRecord.GetTtl();
223
224 exit:
225 return error;
226 }
227
228 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
229
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const230 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
231 {
232 Error error;
233 uint16_t offset;
234 uint16_t numRecords;
235 Name name = aHostName;
236
237 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
238
239 SelectSection(aSection, offset, numRecords);
240 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
241
242 exit:
243 return error;
244 }
245
246 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
247
248 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
249
InitServiceInfo(ServiceInfo & aServiceInfo) const250 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
251 {
252 // This method initializes `aServiceInfo` setting all
253 // TTLs to zero and host name to empty string.
254
255 aServiceInfo.mTtl = 0;
256 aServiceInfo.mHostAddressTtl = 0;
257 aServiceInfo.mTxtDataTtl = 0;
258 aServiceInfo.mTxtDataTruncated = false;
259
260 AsCoreType(&aServiceInfo.mHostAddress).Clear();
261
262 if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
263 {
264 aServiceInfo.mHostNameBuffer[0] = '\0';
265 }
266 }
267
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const268 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
269 {
270 // This method searches for SRV record in the given `aSection`
271 // matching the record name against `aName`, and updates the
272 // `aServiceInfo` accordingly. It also searches for AAAA record
273 // for host name associated with the service (from SRV record).
274 // The search for AAAA record is always performed in Additional
275 // Data section (independent of the value given in `aSection`).
276
277 Error error = kErrorNone;
278 uint16_t offset;
279 uint16_t numRecords;
280 Name hostName;
281 SrvRecord srvRecord;
282
283 // A non-zero `mTtl` indicates that SRV record is already found
284 // and parsed from a previous response.
285 VerifyOrExit(aServiceInfo.mTtl == 0);
286
287 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
288
289 // Search for a matching SRV record
290 SelectSection(aSection, offset, numRecords);
291 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
292
293 aServiceInfo.mTtl = srvRecord.GetTtl();
294 aServiceInfo.mPort = srvRecord.GetPort();
295 aServiceInfo.mPriority = srvRecord.GetPriority();
296 aServiceInfo.mWeight = srvRecord.GetWeight();
297
298 hostName.SetFromMessage(*mMessage, offset);
299
300 if (aServiceInfo.mHostNameBuffer != nullptr)
301 {
302 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
303 aServiceInfo.mHostNameBufferSize));
304 }
305 else
306 {
307 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
308 }
309
310 // Search in additional section for AAAA record for the host name.
311
312 VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
313
314 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
315 aServiceInfo.mHostAddressTtl);
316
317 if (error == kErrorNotFound)
318 {
319 error = kErrorNone;
320 }
321
322 exit:
323 return error;
324 }
325
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const326 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
327 {
328 // This method searches a TXT record in the given `aSection`
329 // matching the record name against `aName` and updates the TXT
330 // related properties in `aServicesInfo`.
331 //
332 // If no match is found `mTxtDataTtl` (which is initialized to zero)
333 // remains unchanged to indicate this. In this case this method still
334 // returns `kErrorNone`.
335
336 Error error = kErrorNone;
337 uint16_t offset;
338 uint16_t numRecords;
339 TxtRecord txtRecord;
340
341 // A non-zero `mTxtDataTtl` indicates that TXT record is already
342 // found and parsed from a previous response.
343 VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
344
345 // A null `mTxtData` indicates that caller does not want to retrieve
346 // TXT data.
347 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
348
349 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
350
351 SelectSection(aSection, offset, numRecords);
352
353 aServiceInfo.mTxtDataTruncated = false;
354
355 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
356
357 error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
358
359 if (error == kErrorNoBufs)
360 {
361 error = kErrorNone;
362
363 // Mark `mTxtDataTruncated` to indicate that we could not read
364 // the full TXT record into the given `mTxtData` buffer.
365 aServiceInfo.mTxtDataTruncated = true;
366 }
367
368 SuccessOrExit(error);
369 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
370
371 exit:
372 if (error == kErrorNotFound)
373 {
374 error = kErrorNone;
375 }
376
377 return error;
378 }
379
380 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
381
PopulateFrom(const Message & aMessage)382 void Client::Response::PopulateFrom(const Message &aMessage)
383 {
384 // Populate `Response` with info from `aMessage`.
385
386 uint16_t offset = aMessage.GetOffset();
387 Header header;
388
389 mMessage = &aMessage;
390
391 IgnoreError(aMessage.Read(offset, header));
392 offset += sizeof(Header);
393
394 for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
395 {
396 IgnoreError(Name::ParseName(aMessage, offset));
397 offset += sizeof(Question);
398 }
399
400 mAnswerOffset = offset;
401 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
402 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
403 mAdditionalOffset = offset;
404 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
405
406 mAnswerRecordCount = header.GetAnswerCount();
407 mAdditionalRecordCount = header.GetAdditionalRecordCount();
408 }
409
410 //---------------------------------------------------------------------------------------------------------------------
411 // Client::AddressResponse
412
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const413 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
414 {
415 Error error = kErrorNone;
416 Name name(*mQuery, kNameOffsetInQuery);
417
418 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
419
420 // If the response is for an IPv4 address query or if it is an
421 // IPv6 address query response with no IPv6 address but with
422 // an IPv4 in its additional section, we read the IPv4 address
423 // and translate it to an IPv6 address.
424
425 QueryInfo info;
426
427 info.ReadFrom(*mQuery);
428
429 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
430 {
431 Section section;
432 ARecord aRecord;
433 NetworkData::ExternalRouteConfig nat64Prefix;
434
435 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
436 error = kErrorInvalidState);
437
438 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
439 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
440
441 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
442 aTtl = aRecord.GetTtl();
443
444 ExitNow();
445 }
446
447 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
448
449 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
450
451 exit:
452 return error;
453 }
454
455 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
456
457 //---------------------------------------------------------------------------------------------------------------------
458 // Client::BrowseResponse
459
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const460 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
461 {
462 Error error;
463 uint16_t offset;
464 uint16_t numRecords;
465 Name serviceName(*mQuery, kNameOffsetInQuery);
466 PtrRecord ptrRecord;
467
468 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
469
470 SelectSection(kAnswerSection, offset, numRecords);
471 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
472 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
473
474 exit:
475 return error;
476 }
477
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const478 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
479 {
480 Error error;
481 Name instanceName;
482
483 // Find a matching PTR record for the service instance label. Then
484 // search and read SRV, TXT and AAAA records in Additional Data
485 // section matching the same name to populate `aServiceInfo`.
486
487 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
488
489 InitServiceInfo(aServiceInfo);
490 SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
491 SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
492
493 if (aServiceInfo.mTxtDataTtl == 0)
494 {
495 aServiceInfo.mTxtDataSize = 0;
496 }
497
498 exit:
499 return error;
500 }
501
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const502 Error Client::BrowseResponse::GetHostAddress(const char *aHostName,
503 uint16_t aIndex,
504 Ip6::Address &aAddress,
505 uint32_t &aTtl) const
506 {
507 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
508 }
509
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const510 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
511 {
512 // This method searches within the Answer Section for a PTR record
513 // matching a given instance label @aInstanceLabel. If found, the
514 // `aName` is updated to return the name in the message.
515
516 Error error;
517 uint16_t offset;
518 Name serviceName(*mQuery, kNameOffsetInQuery);
519 uint16_t numRecords;
520 uint16_t labelOffset;
521 PtrRecord ptrRecord;
522
523 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
524
525 SelectSection(kAnswerSection, offset, numRecords);
526
527 for (; numRecords > 0; numRecords--)
528 {
529 SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
530
531 error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
532
533 if (error == kErrorNotFound)
534 {
535 // `ReadRecord()` updates `offset` to skip over a
536 // non-matching record.
537 continue;
538 }
539
540 SuccessOrExit(error);
541
542 // It is a PTR record. Check the first label to match the
543 // instance label.
544
545 labelOffset = offset;
546 error = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
547
548 if (error == kErrorNone)
549 {
550 aInstanceName.SetFromMessage(*mMessage, offset);
551 ExitNow();
552 }
553
554 VerifyOrExit(error == kErrorNotFound);
555
556 // Update offset to skip over the PTR record.
557 offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
558 }
559
560 error = kErrorNotFound;
561
562 exit:
563 return error;
564 }
565
566 //---------------------------------------------------------------------------------------------------------------------
567 // Client::ServiceResponse
568
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const569 Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer,
570 uint8_t aLabelBufferSize,
571 char *aNameBuffer,
572 uint16_t aNameBufferSize) const
573 {
574 Error error;
575 uint16_t offset = kNameOffsetInQuery;
576
577 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
578
579 VerifyOrExit(aNameBuffer != nullptr);
580 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
581
582 exit:
583 return error;
584 }
585
GetServiceInfo(ServiceInfo & aServiceInfo) const586 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
587 {
588 // Search and read SRV, TXT records matching name from query.
589
590 Error error = kErrorNotFound;
591
592 InitServiceInfo(aServiceInfo);
593
594 for (const Response *response = this; response != nullptr; response = response->mNext)
595 {
596 Name name(*response->mQuery, kNameOffsetInQuery);
597 QueryInfo info;
598 Section srvSection;
599 Section txtSection;
600
601 info.ReadFrom(*response->mQuery);
602
603 switch (info.mQueryType)
604 {
605 case kIp6AddressQuery:
606 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
607 case kIp4AddressQuery:
608 #endif
609 IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
610 AsCoreType(&aServiceInfo.mHostAddress),
611 aServiceInfo.mHostAddressTtl));
612
613 continue; // to `for()` loop
614
615 case kServiceQuerySrvTxt:
616 case kServiceQuerySrv:
617 case kServiceQueryTxt:
618 break;
619
620 default:
621 continue;
622 }
623
624 // Determine from which section we should try to read the SRV and
625 // TXT records based on the query type.
626 //
627 // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
628 // only one record (SRV or TXT) in the answer section, but we
629 // still try to read the other records from additional data
630 // section in case server provided them.
631
632 srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
633 txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
634
635 error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
636
637 if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
638 {
639 error = kErrorNone;
640 }
641
642 SuccessOrExit(error);
643
644 SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
645 }
646
647 if (aServiceInfo.mTxtDataTtl == 0)
648 {
649 aServiceInfo.mTxtDataSize = 0;
650 }
651
652 exit:
653 return error;
654 }
655
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const656 Error Client::ServiceResponse::GetHostAddress(const char *aHostName,
657 uint16_t aIndex,
658 Ip6::Address &aAddress,
659 uint32_t &aTtl) const
660 {
661 Error error = kErrorNotFound;
662
663 for (const Response *response = this; response != nullptr; response = response->mNext)
664 {
665 Section section = kAdditionalDataSection;
666 QueryInfo info;
667
668 info.ReadFrom(*response->mQuery);
669
670 switch (info.mQueryType)
671 {
672 case kIp6AddressQuery:
673 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
674 case kIp4AddressQuery:
675 #endif
676 section = kAnswerSection;
677 break;
678
679 default:
680 break;
681 }
682
683 error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
684
685 if (error == kErrorNone)
686 {
687 break;
688 }
689 }
690
691 return error;
692 }
693
694 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
695
696 //---------------------------------------------------------------------------------------------------------------------
697 // Client
698
// Resource record type(s) associated with each kind of query. These
// arrays are referenced by the `kQuestionCount` and
// `kQuestionRecordTypes` tables.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
// SRV first, TXT second: the SRV-only and TXT-only entries of
// `kQuestionRecordTypes` point at elements [0] and [1] respectively.
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif
707
// Number of record types (questions) per query type. Entries are listed
// in `QueryType` order; the enumerator values are pinned by the
// `static_assert`s in the `Client` constructor.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes),         // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv -> */ 1,                                           // SRV record only
    /* kServiceQueryTxt -> */ 1,                                           // TXT record only
#endif
};
720
// Pointer to the first record type for each query type, listed in the
// same order as `kQuestionCount`. The SRV-only and TXT-only entries
// point into `kServiceQueryRecordTypes` (element 0 is SRV, element 1 is
// TXT).
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt -> */ &kServiceQueryRecordTypes[1],

#endif
};
734
Client(Instance & aInstance)735 Client::Client(Instance &aInstance)
736 : InstanceLocator(aInstance)
737 , mSocket(aInstance)
738 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
739 , mTcpState(kTcpUninitialized)
740 #endif
741 , mTimer(aInstance)
742 , mDefaultConfig(QueryConfig::kInitFromDefaults)
743 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
744 , mUserDidSetDefaultAddress(false)
745 #endif
746 {
747 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
748 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
749 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
750 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
751 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
752 static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
753 static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
754 static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
755 #endif
756 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
757 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
758 static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
759 static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
760 static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
761 #endif
762 }
763
Start(void)764 Error Client::Start(void)
765 {
766 Error error;
767
768 SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
769 SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
770
771 exit:
772 return error;
773 }
774
Stop(void)775 void Client::Stop(void)
776 {
777 Query *query;
778
779 while ((query = mMainQueries.GetHead()) != nullptr)
780 {
781 FinalizeQuery(*query, kErrorAbort);
782 }
783
784 IgnoreError(mSocket.Close());
785 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
786 if (mTcpState != kTcpUninitialized)
787 {
788 IgnoreError(mEndpoint.Deinitialize());
789 }
790 #endif
791 }
792
793 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)794 Error Client::InitTcpSocket(void)
795 {
796 Error error;
797 otTcpEndpointInitializeArgs endpointArgs;
798
799 memset(&endpointArgs, 0x00, sizeof(endpointArgs));
800 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
801 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
802 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
803 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
804 endpointArgs.mContext = this;
805 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
806 endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes);
807
808 mSendLink.mNext = nullptr;
809 mSendLink.mData = mSendBufferBytes;
810 mSendLink.mLength = 0;
811
812 SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
813 exit:
814 return error;
815 }
816 #endif
817
SetDefaultConfig(const QueryConfig & aQueryConfig)818 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
819 {
820 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
821
822 mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
823
824 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
825 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
826 UpdateDefaultConfigAddress();
827 #endif
828 }
829
ResetDefaultConfig(void)830 void Client::ResetDefaultConfig(void)
831 {
832 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
833
834 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
835 mUserDidSetDefaultAddress = false;
836 UpdateDefaultConfigAddress();
837 #endif
838 }
839
840 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)841 void Client::UpdateDefaultConfigAddress(void)
842 {
843 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
844
845 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
846 !srpServerAddr.IsUnspecified())
847 {
848 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
849 }
850 }
851 #endif
852
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)853 Error Client::ResolveAddress(const char *aHostName,
854 AddressCallback aCallback,
855 void *aContext,
856 const QueryConfig *aConfig)
857 {
858 QueryInfo info;
859
860 info.Clear();
861 info.mQueryType = kIp6AddressQuery;
862 info.mConfig.SetFrom(aConfig, mDefaultConfig);
863 info.mCallback.mAddressCallback = aCallback;
864 info.mCallbackContext = aContext;
865
866 return StartQuery(info, nullptr, aHostName);
867 }
868
869 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)870 Error Client::ResolveIp4Address(const char *aHostName,
871 AddressCallback aCallback,
872 void *aContext,
873 const QueryConfig *aConfig)
874 {
875 QueryInfo info;
876
877 info.Clear();
878 info.mQueryType = kIp4AddressQuery;
879 info.mConfig.SetFrom(aConfig, mDefaultConfig);
880 info.mCallback.mAddressCallback = aCallback;
881 info.mCallbackContext = aContext;
882
883 return StartQuery(info, nullptr, aHostName);
884 }
885 #endif
886
887 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
888
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)889 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
890 {
891 QueryInfo info;
892
893 info.Clear();
894 info.mQueryType = kBrowseQuery;
895 info.mConfig.SetFrom(aConfig, mDefaultConfig);
896 info.mCallback.mBrowseCallback = aCallback;
897 info.mCallbackContext = aContext;
898
899 return StartQuery(info, nullptr, aServiceName);
900 }
901
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)902 Error Client::ResolveService(const char *aInstanceLabel,
903 const char *aServiceName,
904 ServiceCallback aCallback,
905 void *aContext,
906 const QueryConfig *aConfig)
907 {
908 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
909 }
910
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)911 Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel,
912 const char *aServiceName,
913 ServiceCallback aCallback,
914 void *aContext,
915 const QueryConfig *aConfig)
916 {
917 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
918 }
919
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)920 Error Client::Resolve(const char *aInstanceLabel,
921 const char *aServiceName,
922 ServiceCallback aCallback,
923 void *aContext,
924 const QueryConfig *aConfig,
925 bool aShouldResolveHostAddr)
926 {
927 QueryInfo info;
928 Error error;
929 QueryType secondQueryType = kNoQuery;
930
931 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
932
933 info.Clear();
934
935 info.mConfig.SetFrom(aConfig, mDefaultConfig);
936 info.mShouldResolveHostAddr = aShouldResolveHostAddr;
937
938 switch (info.mConfig.GetServiceMode())
939 {
940 case QueryConfig::kServiceModeSrvTxtSeparate:
941 secondQueryType = kServiceQueryTxt;
942
943 OT_FALL_THROUGH;
944
945 case QueryConfig::kServiceModeSrv:
946 info.mQueryType = kServiceQuerySrv;
947 break;
948
949 case QueryConfig::kServiceModeTxt:
950 info.mQueryType = kServiceQueryTxt;
951 VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
952 break;
953
954 case QueryConfig::kServiceModeSrvTxt:
955 case QueryConfig::kServiceModeSrvTxtOptimize:
956 default:
957 info.mQueryType = kServiceQuerySrvTxt;
958 break;
959 }
960
961 info.mCallback.mServiceCallback = aCallback;
962 info.mCallbackContext = aContext;
963
964 error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
965
966 exit:
967 return error;
968 }
969
970 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
971
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)972 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
973 {
974 // The `aLabel` can be `nullptr` and then `aName` provides the
975 // full name, otherwise the name is appended as `{aLabel}.
976 // {aName}`.
977
978 Error error;
979 Query *query;
980
981 VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
982
983 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
984 if (aInfo.mQueryType == kIp4AddressQuery)
985 {
986 NetworkData::ExternalRouteConfig nat64Prefix;
987
988 VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
989 VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
990 error = kErrorInvalidState);
991 }
992 #endif
993
994 SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
995
996 mMainQueries.Enqueue(*query);
997
998 error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
999 VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1000
1001 if (aSecondType != kNoQuery)
1002 {
1003 Query *secondQuery;
1004
1005 aInfo.mQueryType = aSecondType;
1006 aInfo.mMessageId = 0;
1007 aInfo.mTransmissionCount = 0;
1008 aInfo.mMainQuery = query;
1009
1010 // We intentionally do not use `error` here so in the unlikely
1011 // case where we cannot allocate the second query we can proceed
1012 // with the first one.
1013 SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1014
1015 IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1016
1017 // Update first query to link to second one by updating
1018 // its `mNextQuery`.
1019 aInfo.ReadFrom(*query);
1020 aInfo.mNextQuery = secondQuery;
1021 UpdateQuery(*query, aInfo);
1022 }
1023
1024 exit:
1025 return error;
1026 }
1027
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1028 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1029 {
1030 Error error = kErrorNone;
1031
1032 aQuery = nullptr;
1033
1034 VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1035
1036 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1037 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1038
1039 SuccessOrExit(error = aQuery->Append(aInfo));
1040
1041 if (aLabel != nullptr)
1042 {
1043 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1044 }
1045
1046 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1047
1048 exit:
1049 FreeAndNullMessageOnError(aQuery, error);
1050 return error;
1051 }
1052
FindMainQuery(Query & aQuery)1053 Client::Query &Client::FindMainQuery(Query &aQuery)
1054 {
1055 QueryInfo info;
1056
1057 info.ReadFrom(aQuery);
1058
1059 return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1060 }
1061
FreeQuery(Query & aQuery)1062 void Client::FreeQuery(Query &aQuery)
1063 {
1064 Query &mainQuery = FindMainQuery(aQuery);
1065 QueryInfo info;
1066
1067 mMainQueries.Dequeue(mainQuery);
1068
1069 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1070 {
1071 info.ReadFrom(*query);
1072 FreeMessage(info.mSavedResponse);
1073 query->Free();
1074 }
1075 }
1076
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // Returns `kErrorNone` on success or the error from the step
    // that failed. Note that the code after the `exit` label (timer
    // update and `UpdateQuery()`) runs on both success and failure,
    // so the incremented `mTransmissionCount` and the new
    // `mRetransmissionTime` are written back to `aQuery` even when
    // sending fails.

    Error            error   = kErrorNone;
    Message         *message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;
    uint16_t         length = 0;

    // Account for this attempt before anything below can fail.
    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    // A zero `mMessageId` means no ID has been selected yet. Pick a
    // random non-zero ID which is not used by any tracked query.
    // Otherwise (retransmission) reuse the already selected ID.
    if (aInfo.mMessageId == 0)
    {
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section.
    // Each question repeats the name stored in `aQuery` and uses the
    // record type from `kQuestionRecordTypes` for this query type.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    length = message->GetLength() - message->GetOffset();

    // The body of the following `if` is selected by the
    // preprocessor: the TCP send logic when DNS over TCP is enabled,
    // otherwise a block that rejects the request.
    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // No connection yet: start connecting and queue the query
            // in the send buffer. The queued bytes are sent from
            // `HandleTcpEstablished()`.
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Connected and nothing in flight: queue and send now.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Still connecting: only queue the query.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send is in progress: write the length-prefixed query
            // right after the bytes already in the send buffer and
            // extend the ongoing send.
            // NOTE(review): unlike `PrepareTcpMessage()`, this path
            // does not advance `mSendLink.mLength` itself; presumably
            // `SendByExtension()` accounts for the added bytes -
            // confirm.
            WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }

        // The query bytes were copied into the send buffer, so the
        // message is no longer needed. Setting `message` to `nullptr`
        // keeps the exit path from touching it again.
        message->Free();
        message = nullptr;
    }
#else
    {
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        // UDP transport: the message is handed to the socket.
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    // Save the updated `aInfo` back into the query message.
    UpdateQuery(aQuery, aInfo);

    return error;
}
1203
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1204 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1205 {
1206 // The name is encoded and included after the `Info` in `aQuery`
1207 // starting at `kNameOffsetInQuery`.
1208
1209 return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1210 }
1211
FinalizeQuery(Query & aQuery,Error aError)1212 void Client::FinalizeQuery(Query &aQuery, Error aError)
1213 {
1214 Response response;
1215 Query &mainQuery = FindMainQuery(aQuery);
1216
1217 response.mInstance = &Get<Instance>();
1218 response.mQuery = &mainQuery;
1219
1220 FinalizeQuery(response, aError);
1221 }
1222
FinalizeQuery(Response & aResponse,Error aError)1223 void Client::FinalizeQuery(Response &aResponse, Error aError)
1224 {
1225 QueryType type;
1226 Callback callback;
1227 void *context;
1228
1229 GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1230
1231 switch (type)
1232 {
1233 case kIp6AddressQuery:
1234 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1235 case kIp4AddressQuery:
1236 #endif
1237 if (callback.mAddressCallback != nullptr)
1238 {
1239 callback.mAddressCallback(aError, &aResponse, context);
1240 }
1241 break;
1242
1243 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1244 case kBrowseQuery:
1245 if (callback.mBrowseCallback != nullptr)
1246 {
1247 callback.mBrowseCallback(aError, &aResponse, context);
1248 }
1249 break;
1250
1251 case kServiceQuerySrvTxt:
1252 case kServiceQuerySrv:
1253 case kServiceQueryTxt:
1254 if (callback.mServiceCallback != nullptr)
1255 {
1256 callback.mServiceCallback(aError, &aResponse, context);
1257 }
1258 break;
1259 #endif
1260 case kNoQuery:
1261 break;
1262 }
1263
1264 FreeQuery(*aResponse.mQuery);
1265 }
1266
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1267 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1268 {
1269 QueryInfo info;
1270
1271 info.ReadFrom(aQuery);
1272
1273 aType = info.mQueryType;
1274 aCallback = info.mCallback;
1275 aContext = info.mCallbackContext;
1276 }
1277
FindQueryById(uint16_t aMessageId)1278 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1279 {
1280 Query *matchedQuery = nullptr;
1281 QueryInfo info;
1282
1283 for (Query &mainQuery : mMainQueries)
1284 {
1285 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1286 {
1287 info.ReadFrom(*query);
1288
1289 if (info.mMessageId == aMessageId)
1290 {
1291 matchedQuery = query;
1292 ExitNow();
1293 }
1294 }
1295 }
1296
1297 exit:
1298 return matchedQuery;
1299 }
1300
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1301 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1302 {
1303 OT_UNUSED_VARIABLE(aMsgInfo);
1304
1305 static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1306 }
1307
ProcessResponse(const Message & aResponseMessage)1308 void Client::ProcessResponse(const Message &aResponseMessage)
1309 {
1310 Error responseError;
1311 Query *query;
1312
1313 SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1314
1315 if (responseError != kErrorNone)
1316 {
1317 // Received an error from server, check if we can replace
1318 // the query.
1319
1320 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1321 if (ReplaceWithIp4Query(*query) == kErrorNone)
1322 {
1323 ExitNow();
1324 }
1325 #endif
1326 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1327 if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1328 {
1329 ExitNow();
1330 }
1331 #endif
1332
1333 FinalizeQuery(*query, responseError);
1334 ExitNow();
1335 }
1336
1337 // Received successful response from server.
1338
1339 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1340 ResolveHostAddressIfNeeded(*query, aResponseMessage);
1341 #endif
1342
1343 if (!CanFinalizeQuery(*query))
1344 {
1345 SaveQueryResponse(*query, aResponseMessage);
1346 ExitNow();
1347 }
1348
1349 PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1350
1351 exit:
1352 return;
1353 }
1354
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1355 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1356 {
1357 Error error = kErrorNone;
1358 uint16_t offset = aResponseMessage.GetOffset();
1359 Header header;
1360 QueryInfo info;
1361 Name queryName;
1362
1363 SuccessOrExit(error = aResponseMessage.Read(offset, header));
1364 offset += sizeof(Header);
1365
1366 VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1367 !header.IsTruncationFlagSet(),
1368 error = kErrorDrop);
1369
1370 aQuery = FindQueryById(header.GetMessageId());
1371 VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1372
1373 info.ReadFrom(*aQuery);
1374
1375 queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1376
1377 // Check the Question Section
1378
1379 if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1380 {
1381 for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1382 {
1383 SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1384 offset += sizeof(Question);
1385 }
1386 }
1387 else
1388 {
1389 VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1390 error = kErrorParse);
1391 }
1392
1393 // Check the answer, authority and additional record sections
1394
1395 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1396 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1397 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1398
1399 // Read the response code
1400
1401 aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1402
1403 exit:
1404 return error;
1405 }
1406
CanFinalizeQuery(Query & aQuery)1407 bool Client::CanFinalizeQuery(Query &aQuery)
1408 {
1409 // Determines whether we can finalize a main query by checking if
1410 // we have received and saved responses for all other related
1411 // queries associated with `aQuery`. Note that this method is
1412 // called when we receive a response for `aQuery`, so no need to
1413 // check for a saved response for `aQuery` itself.
1414
1415 bool canFinalize = true;
1416 QueryInfo info;
1417
1418 for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1419 {
1420 info.ReadFrom(*query);
1421
1422 if (query == &aQuery)
1423 {
1424 continue;
1425 }
1426
1427 if (info.mSavedResponse == nullptr)
1428 {
1429 canFinalize = false;
1430 ExitNow();
1431 }
1432 }
1433
1434 exit:
1435 return canFinalize;
1436 }
1437
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1438 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1439 {
1440 QueryInfo info;
1441
1442 info.ReadFrom(aQuery);
1443 VerifyOrExit(info.mSavedResponse == nullptr);
1444
1445 // If `Clone()` fails we let retry or timeout handle the error.
1446 info.mSavedResponse = aResponseMessage.Clone();
1447
1448 UpdateQuery(aQuery, info);
1449
1450 exit:
1451 return;
1452 }
1453
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1454 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1455 {
1456 // Populate `aResponse` for `aQuery`. If there is a saved response
1457 // message for `aQuery` we use it, otherwise, we use
1458 // `aResponseMessage`.
1459
1460 QueryInfo info;
1461
1462 info.ReadFrom(aQuery);
1463
1464 aResponse.mInstance = &Get<Instance>();
1465 aResponse.mQuery = &aQuery;
1466 aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1467
1468 return info.mNextQuery;
1469 }
1470
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1471 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1472 {
1473 // This method prepares a list of chained `Response` instances
1474 // corresponding to all related (chained) queries. It uses
1475 // recursion to go through the queries and construct the
1476 // `Response` chain.
1477
1478 Response response;
1479 Query *nextQuery;
1480
1481 nextQuery = PopulateResponse(response, aQuery, aResponseMessage);
1482 response.mNext = aPrevResponse;
1483
1484 if (nextQuery != nullptr)
1485 {
1486 PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1487 }
1488 else
1489 {
1490 FinalizeQuery(response, kErrorNone);
1491 }
1492 }
1493
HandleTimer(void)1494 void Client::HandleTimer(void)
1495 {
1496 TimeMilli now = TimerMilli::GetNow();
1497 TimeMilli nextTime = now.GetDistantFuture();
1498 QueryInfo info;
1499 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1500 bool hasTcpQuery = false;
1501 #endif
1502
1503 for (Query &mainQuery : mMainQueries)
1504 {
1505 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1506 {
1507 info.ReadFrom(*query);
1508
1509 if (info.mSavedResponse != nullptr)
1510 {
1511 continue;
1512 }
1513
1514 if (now >= info.mRetransmissionTime)
1515 {
1516 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1517 {
1518 FinalizeQuery(*query, kErrorResponseTimeout);
1519 break;
1520 }
1521
1522 IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1523 }
1524
1525 if (nextTime > info.mRetransmissionTime)
1526 {
1527 nextTime = info.mRetransmissionTime;
1528 }
1529
1530 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1531 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1532 {
1533 hasTcpQuery = true;
1534 }
1535 #endif
1536 }
1537 }
1538
1539 if (nextTime < now.GetDistantFuture())
1540 {
1541 mTimer.FireAt(nextTime);
1542 }
1543
1544 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1545 if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1546 {
1547 IgnoreError(mEndpoint.SendEndOfStream());
1548 }
1549 #endif
1550 }
1551
1552 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1553
ReplaceWithIp4Query(Query & aQuery)1554 Error Client::ReplaceWithIp4Query(Query &aQuery)
1555 {
1556 Error error = kErrorFailed;
1557 QueryInfo info;
1558
1559 info.ReadFrom(aQuery);
1560
1561 VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1562 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1563
1564 // We send a new query for IPv4 address resolution
1565 // for the same host name. We reuse the existing `aQuery`
1566 // instance and keep all the info but clear `mTransmissionCount`
1567 // and `mMessageId` (so that a new random message ID is
1568 // selected). The new `info` will be saved in the query in
1569 // `SendQuery()`. Note that the current query is still in the
1570 // `mMainQueries` list when `SendQuery()` selects a new random
1571 // message ID, so the existing message ID for this query will
1572 // not be reused.
1573
1574 info.mQueryType = kIp4AddressQuery;
1575 info.mMessageId = 0;
1576 info.mTransmissionCount = 0;
1577
1578 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1579 error = kErrorNone;
1580
1581 exit:
1582 return error;
1583 }
1584
1585 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1586
1587 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1588
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1589 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1590 {
1591 Error error = kErrorFailed;
1592 QueryInfo info;
1593 Query *secondQuery;
1594
1595 info.ReadFrom(aQuery);
1596
1597 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1598 VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1599
1600 secondQuery = aQuery.Clone();
1601 VerifyOrExit(secondQuery != nullptr);
1602
1603 info.mQueryType = kServiceQueryTxt;
1604 info.mMessageId = 0;
1605 info.mTransmissionCount = 0;
1606 info.mMainQuery = &aQuery;
1607 IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1608
1609 info.mQueryType = kServiceQuerySrv;
1610 info.mMessageId = 0;
1611 info.mTransmissionCount = 0;
1612 info.mNextQuery = secondQuery;
1613 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1614 error = kErrorNone;
1615
1616 exit:
1617 return error;
1618 }
1619
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1620 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1621 {
1622 QueryInfo info;
1623 Response response;
1624 ServiceInfo serviceInfo;
1625 char hostName[Name::kMaxNameSize];
1626
1627 info.ReadFrom(aQuery);
1628
1629 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1630 VerifyOrExit(info.mShouldResolveHostAddr);
1631
1632 PopulateResponse(response, aQuery, aResponseMessage);
1633
1634 memset(&serviceInfo, 0, sizeof(serviceInfo));
1635 serviceInfo.mHostNameBuffer = hostName;
1636 serviceInfo.mHostNameBufferSize = sizeof(hostName);
1637 SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1638
1639 // Check whether AAAA record for host address is provided in the SRV query response
1640
1641 if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1642 {
1643 Query *newQuery;
1644
1645 info.mQueryType = kIp6AddressQuery;
1646 info.mMessageId = 0;
1647 info.mTransmissionCount = 0;
1648 info.mMainQuery = &FindMainQuery(aQuery);
1649
1650 SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1651 IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1652
1653 // Update `aQuery` to be linked with new query (inserting
1654 // the `newQuery` into the linked-list after `aQuery`).
1655
1656 info.ReadFrom(aQuery);
1657 info.mNextQuery = newQuery;
1658 UpdateQuery(aQuery, info);
1659 }
1660
1661 exit:
1662 return;
1663 }
1664
1665 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1666
1667 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1668 void Client::PrepareTcpMessage(Message &aMessage)
1669 {
1670 uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1671
1672 // Prepending the DNS query with length of the packet according to RFC1035.
1673 WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1674 SuccessOrAssert(
1675 aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1676 mSendLink.mLength += length + sizeof(uint16_t);
1677 }
1678
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1679 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1680 {
1681 OT_UNUSED_VARIABLE(aEndpoint);
1682 OT_UNUSED_VARIABLE(aData);
1683 OT_ASSERT(mTcpState == kTcpConnectedSending);
1684
1685 mSendLink.mLength = 0;
1686 mTcpState = kTcpConnectedIdle;
1687 }
1688
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1689 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1690 {
1691 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1692 }
1693
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1694 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1695 {
1696 OT_UNUSED_VARIABLE(aEndpoint);
1697 IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1698 mTcpState = kTcpConnectedSending;
1699 }
1700
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1701 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1702 {
1703 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1704 }
1705
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1706 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1707 size_t &aOffset,
1708 Message &aMessage,
1709 uint16_t aLength)
1710 {
1711 // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1712 // and copy the content into `aMessage`. As we read we can move
1713 // to the next `aLinkedBuffer` and update `aOffset`.
1714 // Returns:
1715 // - `kErrorNone` if `aLength` bytes are successfully read and
1716 // `aOffset` and `aLinkedBuffer` are updated.
1717 // - `kErrorNotFound` is not enough bytes available to read
1718 // from `aLinkedBuffer`.
1719 // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1720
1721 Error error = kErrorNone;
1722
1723 while (aLength > 0)
1724 {
1725 uint16_t bytesToRead = aLength;
1726
1727 VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1728
1729 if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1730 {
1731 bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1732 }
1733
1734 SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1735
1736 aLength -= bytesToRead;
1737 aOffset += bytesToRead;
1738
1739 if (aOffset == aLinkedBuffer->mLength)
1740 {
1741 aLinkedBuffer = aLinkedBuffer->mNext;
1742 aOffset = 0;
1743 }
1744 }
1745
1746 exit:
1747 return error;
1748 }
1749
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1750 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1751 size_t aBytesAvailable,
1752 bool aEndOfStream,
1753 size_t aBytesRemaining)
1754 {
1755 OT_UNUSED_VARIABLE(aEndpoint);
1756 OT_UNUSED_VARIABLE(aBytesRemaining);
1757
1758 Message *message = nullptr;
1759 size_t totalRead = 0;
1760 size_t offset = 0;
1761 const otLinkedBuffer *data;
1762
1763 if (aEndOfStream)
1764 {
1765 // Cleanup is done in disconnected callback.
1766 IgnoreError(mEndpoint.SendEndOfStream());
1767 }
1768
1769 SuccessOrExit(mEndpoint.ReceiveByReference(data));
1770 VerifyOrExit(data != nullptr);
1771
1772 message = mSocket.NewMessage();
1773 VerifyOrExit(message != nullptr);
1774
1775 while (aBytesAvailable > totalRead)
1776 {
1777 uint16_t length;
1778
1779 // Read the `length` field.
1780 SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1781
1782 IgnoreError(message->Read(/* aOffset */ 0, length));
1783 length = HostSwap16(length);
1784
1785 // Try to read `length` bytes.
1786 IgnoreError(message->SetLength(0));
1787 SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1788
1789 totalRead += length + sizeof(uint16_t);
1790
1791 // Now process the read message as query response.
1792 ProcessResponse(*message);
1793
1794 IgnoreError(message->SetLength(0));
1795
1796 // Loop again to see if we can read another response.
1797 }
1798
1799 exit:
1800 // Inform `mEndPoint` about the total read and processed bytes
1801 IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1802 FreeMessage(message);
1803 }
1804
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1805 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1806 size_t aBytesAvailable,
1807 bool aEndOfStream,
1808 size_t aBytesRemaining)
1809 {
1810 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1811 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1812 }
1813
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1814 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1815 {
1816 OT_UNUSED_VARIABLE(aEndpoint);
1817 OT_UNUSED_VARIABLE(aReason);
1818 QueryInfo info;
1819
1820 IgnoreError(mEndpoint.Deinitialize());
1821 mTcpState = kTcpUninitialized;
1822
1823 // Abort queries in case of connection failures
1824 for (Query &mainQuery : mMainQueries)
1825 {
1826 info.ReadFrom(mainQuery);
1827
1828 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1829 {
1830 FinalizeQuery(mainQuery, kErrorAbort);
1831 }
1832 }
1833 }
1834
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1835 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1836 {
1837 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1838 }
1839
1840 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1841
1842 } // namespace Dns
1843 } // namespace ot
1844
1845 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1846