1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43
44 /**
45 * @file
46 * This file implements the DNS client.
47 */
48
49 namespace ot {
50 namespace Dns {
51
// Log module name used by the logging macros in this file (see `RegisterLogModule` in common/log.hpp).
RegisterLogModule("DnsClient");

//---------------------------------------------------------------------------------------------------------------------
// Client::QueryConfig

// Default DNS server IPv6 address in string form, taken from the build-time option
// `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS`. It is parsed into the server
// socket address by the `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61 OT_UNUSED_VARIABLE(aMode);
62
63 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64 GetServerSockAddr().SetPort(kDefaultServerPort);
65 SetResponseTimeout(kDefaultResponseTimeout);
66 SetMaxTxAttempts(kDefaultMaxTxAttempts);
67 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68 SetServiceMode(kDefaultServiceMode);
69 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
70 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
71 #endif
72 SetTransportProto(kDnsTransportUdp);
73 }
74
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)75 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
76 {
77 // This method sets the config from `aConfig` replacing any
78 // unspecified fields (value zero) with the fields from
79 // `aDefaultConfig`. If `aConfig` is `nullptr` then
80 // `aDefaultConfig` is used.
81
82 if (aConfig == nullptr)
83 {
84 *this = aDefaultConfig;
85 ExitNow();
86 }
87
88 *this = *aConfig;
89
90 if (GetServerSockAddr().GetAddress().IsUnspecified())
91 {
92 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
93 }
94
95 if (GetServerSockAddr().GetPort() == 0)
96 {
97 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
98 }
99
100 if (GetResponseTimeout() == 0)
101 {
102 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
103 }
104
105 if (GetMaxTxAttempts() == 0)
106 {
107 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
108 }
109
110 if (GetRecursionFlag() == kFlagUnspecified)
111 {
112 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
113 }
114
115 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
116 if (GetNat64Mode() == kNat64Unspecified)
117 {
118 SetNat64Mode(aDefaultConfig.GetNat64Mode());
119 }
120 #endif
121
122 if (GetServiceMode() == kServiceModeUnspecified)
123 {
124 SetServiceMode(aDefaultConfig.GetServiceMode());
125 }
126
127 if (GetTransportProto() == kDnsTransportUnspecified)
128 {
129 SetTransportProto(aDefaultConfig.GetTransportProto());
130 }
131
132 exit:
133 return;
134 }
135
136 //---------------------------------------------------------------------------------------------------------------------
137 // Client::Response
138
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const139 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
140 {
141 switch (aSection)
142 {
143 case kAnswerSection:
144 aOffset = mAnswerOffset;
145 aNumRecord = mAnswerRecordCount;
146 break;
147 case kAdditionalDataSection:
148 default:
149 aOffset = mAdditionalOffset;
150 aNumRecord = mAdditionalRecordCount;
151 break;
152 }
153 }
154
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const155 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
156 {
157 uint16_t offset = kNameOffsetInQuery;
158
159 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
160 }
161
CheckForHostNameAlias(Section aSection,Name & aHostName) const162 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
163 {
164 // If the response includes a CNAME record mapping the query host
165 // name to a canonical name, we update `aHostName` to the new alias
166 // name. Otherwise `aHostName` remains as before. This method handles
167 // when there are multiple CNAME records mapping the host name multiple
168 // times. We limit number of changes to `kMaxCnameAliasNameChanges`
169 // to detect and handle if the response contains CNAME record loops.
170
171 Error error;
172 uint16_t offset;
173 uint16_t numRecords;
174 CnameRecord cnameRecord;
175
176 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
177
178 for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
179 {
180 SelectSection(aSection, offset, numRecords);
181 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182
183 if (error == kErrorNotFound)
184 {
185 error = kErrorNone;
186 ExitNow();
187 }
188
189 SuccessOrExit(error);
190
191 // A CNAME record was found. `offset` now points to after the
192 // last read byte within the `mMessage` into the `cnameRecord`
193 // (which is the start of the new canonical name).
194
195 aHostName.SetFromMessage(*mMessage, offset);
196 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
197
198 // Loop back to check if there may be a CNAME record for the
199 // new `aHostName`.
200 }
201
202 error = kErrorParse;
203
204 exit:
205 return error;
206 }
207
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const208 Error Client::Response::FindHostAddress(Section aSection,
209 const Name &aHostName,
210 uint16_t aIndex,
211 Ip6::Address &aAddress,
212 uint32_t &aTtl) const
213 {
214 Error error;
215 uint16_t offset;
216 uint16_t numRecords;
217 Name name = aHostName;
218 AaaaRecord aaaaRecord;
219
220 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
221
222 SelectSection(aSection, offset, numRecords);
223 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
224 aAddress = aaaaRecord.GetAddress();
225 aTtl = aaaaRecord.GetTtl();
226
227 exit:
228 return error;
229 }
230
231 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
232
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const233 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
234 {
235 Error error;
236 uint16_t offset;
237 uint16_t numRecords;
238 Name name = aHostName;
239
240 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
241
242 SelectSection(aSection, offset, numRecords);
243 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
244
245 exit:
246 return error;
247 }
248
249 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
250
251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
252
InitServiceInfo(ServiceInfo & aServiceInfo) const253 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
254 {
255 // This method initializes `aServiceInfo` setting all
256 // TTLs to zero and host name to empty string.
257
258 aServiceInfo.mTtl = 0;
259 aServiceInfo.mHostAddressTtl = 0;
260 aServiceInfo.mTxtDataTtl = 0;
261 aServiceInfo.mTxtDataTruncated = false;
262
263 AsCoreType(&aServiceInfo.mHostAddress).Clear();
264
265 if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
266 {
267 aServiceInfo.mHostNameBuffer[0] = '\0';
268 }
269 }
270
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const271 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
272 {
273 // This method searches for SRV record in the given `aSection`
274 // matching the record name against `aName`, and updates the
275 // `aServiceInfo` accordingly. It also searches for AAAA record
276 // for host name associated with the service (from SRV record).
277 // The search for AAAA record is always performed in Additional
278 // Data section (independent of the value given in `aSection`).
279
280 Error error = kErrorNone;
281 uint16_t offset;
282 uint16_t numRecords;
283 Name hostName;
284 SrvRecord srvRecord;
285
286 // A non-zero `mTtl` indicates that SRV record is already found
287 // and parsed from a previous response.
288 VerifyOrExit(aServiceInfo.mTtl == 0);
289
290 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
291
292 // Search for a matching SRV record
293 SelectSection(aSection, offset, numRecords);
294 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
295
296 aServiceInfo.mTtl = srvRecord.GetTtl();
297 aServiceInfo.mPort = srvRecord.GetPort();
298 aServiceInfo.mPriority = srvRecord.GetPriority();
299 aServiceInfo.mWeight = srvRecord.GetWeight();
300
301 hostName.SetFromMessage(*mMessage, offset);
302
303 if (aServiceInfo.mHostNameBuffer != nullptr)
304 {
305 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
306 aServiceInfo.mHostNameBufferSize));
307 }
308 else
309 {
310 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
311 }
312
313 // Search in additional section for AAAA record for the host name.
314
315 VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
316
317 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
318 aServiceInfo.mHostAddressTtl);
319
320 if (error == kErrorNotFound)
321 {
322 error = kErrorNone;
323 }
324
325 exit:
326 return error;
327 }
328
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const329 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
330 {
331 // This method searches a TXT record in the given `aSection`
332 // matching the record name against `aName` and updates the TXT
333 // related properties in `aServicesInfo`.
334 //
335 // If no match is found `mTxtDataTtl` (which is initialized to zero)
336 // remains unchanged to indicate this. In this case this method still
337 // returns `kErrorNone`.
338
339 Error error = kErrorNone;
340 uint16_t offset;
341 uint16_t numRecords;
342 TxtRecord txtRecord;
343
344 // A non-zero `mTxtDataTtl` indicates that TXT record is already
345 // found and parsed from a previous response.
346 VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
347
348 // A null `mTxtData` indicates that caller does not want to retrieve
349 // TXT data.
350 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
351
352 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
353
354 SelectSection(aSection, offset, numRecords);
355
356 aServiceInfo.mTxtDataTruncated = false;
357
358 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
359
360 error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
361
362 if (error == kErrorNoBufs)
363 {
364 error = kErrorNone;
365
366 // Mark `mTxtDataTruncated` to indicate that we could not read
367 // the full TXT record into the given `mTxtData` buffer.
368 aServiceInfo.mTxtDataTruncated = true;
369 }
370
371 SuccessOrExit(error);
372 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
373
374 exit:
375 if (error == kErrorNotFound)
376 {
377 error = kErrorNone;
378 }
379
380 return error;
381 }
382
383 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
384
PopulateFrom(const Message & aMessage)385 void Client::Response::PopulateFrom(const Message &aMessage)
386 {
387 // Populate `Response` with info from `aMessage`.
388
389 uint16_t offset = aMessage.GetOffset();
390 Header header;
391
392 mMessage = &aMessage;
393
394 IgnoreError(aMessage.Read(offset, header));
395 offset += sizeof(Header);
396
397 for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
398 {
399 IgnoreError(Name::ParseName(aMessage, offset));
400 offset += sizeof(Question);
401 }
402
403 mAnswerOffset = offset;
404 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
405 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
406 mAdditionalOffset = offset;
407 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
408
409 mAnswerRecordCount = header.GetAnswerCount();
410 mAdditionalRecordCount = header.GetAdditionalRecordCount();
411 }
412
413 //---------------------------------------------------------------------------------------------------------------------
414 // Client::AddressResponse
415
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const416 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
417 {
418 Error error = kErrorNone;
419 Name name(*mQuery, kNameOffsetInQuery);
420
421 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
422
423 // If the response is for an IPv4 address query or if it is an
424 // IPv6 address query response with no IPv6 address but with
425 // an IPv4 in its additional section, we read the IPv4 address
426 // and translate it to an IPv6 address.
427
428 QueryInfo info;
429
430 info.ReadFrom(*mQuery);
431
432 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
433 {
434 Section section;
435 ARecord aRecord;
436 NetworkData::ExternalRouteConfig nat64Prefix;
437
438 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
439 error = kErrorInvalidState);
440
441 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
442 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
443
444 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
445 aTtl = aRecord.GetTtl();
446
447 ExitNow();
448 }
449
450 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
451
452 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
453
454 exit:
455 return error;
456 }
457
458 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
459
460 //---------------------------------------------------------------------------------------------------------------------
461 // Client::BrowseResponse
462
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const463 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
464 {
465 Error error;
466 uint16_t offset;
467 uint16_t numRecords;
468 Name serviceName(*mQuery, kNameOffsetInQuery);
469 PtrRecord ptrRecord;
470
471 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
472
473 SelectSection(kAnswerSection, offset, numRecords);
474 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
475 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
476
477 exit:
478 return error;
479 }
480
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const481 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
482 {
483 Error error;
484 Name instanceName;
485
486 // Find a matching PTR record for the service instance label. Then
487 // search and read SRV, TXT and AAAA records in Additional Data
488 // section matching the same name to populate `aServiceInfo`.
489
490 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
491
492 InitServiceInfo(aServiceInfo);
493 SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
494 SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
495
496 if (aServiceInfo.mTxtDataTtl == 0)
497 {
498 aServiceInfo.mTxtDataSize = 0;
499 }
500
501 exit:
502 return error;
503 }
504
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const505 Error Client::BrowseResponse::GetHostAddress(const char *aHostName,
506 uint16_t aIndex,
507 Ip6::Address &aAddress,
508 uint32_t &aTtl) const
509 {
510 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
511 }
512
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const513 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
514 {
515 // This method searches within the Answer Section for a PTR record
516 // matching a given instance label @aInstanceLabel. If found, the
517 // `aName` is updated to return the name in the message.
518
519 Error error;
520 uint16_t offset;
521 Name serviceName(*mQuery, kNameOffsetInQuery);
522 uint16_t numRecords;
523 uint16_t labelOffset;
524 PtrRecord ptrRecord;
525
526 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
527
528 SelectSection(kAnswerSection, offset, numRecords);
529
530 for (; numRecords > 0; numRecords--)
531 {
532 SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
533
534 error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
535
536 if (error == kErrorNotFound)
537 {
538 // `ReadRecord()` updates `offset` to skip over a
539 // non-matching record.
540 continue;
541 }
542
543 SuccessOrExit(error);
544
545 // It is a PTR record. Check the first label to match the
546 // instance label.
547
548 labelOffset = offset;
549 error = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
550
551 if (error == kErrorNone)
552 {
553 aInstanceName.SetFromMessage(*mMessage, offset);
554 ExitNow();
555 }
556
557 VerifyOrExit(error == kErrorNotFound);
558
559 // Update offset to skip over the PTR record.
560 offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
561 }
562
563 error = kErrorNotFound;
564
565 exit:
566 return error;
567 }
568
569 //---------------------------------------------------------------------------------------------------------------------
570 // Client::ServiceResponse
571
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const572 Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer,
573 uint8_t aLabelBufferSize,
574 char *aNameBuffer,
575 uint16_t aNameBufferSize) const
576 {
577 Error error;
578 uint16_t offset = kNameOffsetInQuery;
579
580 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
581
582 VerifyOrExit(aNameBuffer != nullptr);
583 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
584
585 exit:
586 return error;
587 }
588
GetServiceInfo(ServiceInfo & aServiceInfo) const589 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
590 {
591 // Search and read SRV, TXT records matching name from query.
592
593 Error error = kErrorNotFound;
594
595 InitServiceInfo(aServiceInfo);
596
597 for (const Response *response = this; response != nullptr; response = response->mNext)
598 {
599 Name name(*response->mQuery, kNameOffsetInQuery);
600 QueryInfo info;
601 Section srvSection;
602 Section txtSection;
603
604 info.ReadFrom(*response->mQuery);
605
606 switch (info.mQueryType)
607 {
608 case kIp6AddressQuery:
609 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
610 case kIp4AddressQuery:
611 #endif
612 IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
613 AsCoreType(&aServiceInfo.mHostAddress),
614 aServiceInfo.mHostAddressTtl));
615
616 continue; // to `for()` loop
617
618 case kServiceQuerySrvTxt:
619 case kServiceQuerySrv:
620 case kServiceQueryTxt:
621 break;
622
623 default:
624 continue;
625 }
626
627 // Determine from which section we should try to read the SRV and
628 // TXT records based on the query type.
629 //
630 // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
631 // only one record (SRV or TXT) in the answer section, but we
632 // still try to read the other records from additional data
633 // section in case server provided them.
634
635 srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
636 txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
637
638 error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
639
640 if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
641 {
642 error = kErrorNone;
643 }
644
645 SuccessOrExit(error);
646
647 SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
648 }
649
650 if (aServiceInfo.mTxtDataTtl == 0)
651 {
652 aServiceInfo.mTxtDataSize = 0;
653 }
654
655 exit:
656 return error;
657 }
658
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const659 Error Client::ServiceResponse::GetHostAddress(const char *aHostName,
660 uint16_t aIndex,
661 Ip6::Address &aAddress,
662 uint32_t &aTtl) const
663 {
664 Error error = kErrorNotFound;
665
666 for (const Response *response = this; response != nullptr; response = response->mNext)
667 {
668 Section section = kAdditionalDataSection;
669 QueryInfo info;
670
671 info.ReadFrom(*response->mQuery);
672
673 switch (info.mQueryType)
674 {
675 case kIp6AddressQuery:
676 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
677 case kIp4AddressQuery:
678 #endif
679 section = kAnswerSection;
680 break;
681
682 default:
683 break;
684 }
685
686 error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
687
688 if (error == kErrorNone)
689 {
690 break;
691 }
692 }
693
694 return error;
695 }
696
697 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
698
699 //---------------------------------------------------------------------------------------------------------------------
700 // Client
701
// Resource record types asked for by each kind of query (one question per entry).
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[] = {ResourceRecord::kTypePtr};
// SRV first, then TXT: the SRV-only and TXT-only entries below point into this array.
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query type. Entries are in `QueryType` order (see the
// `/* kXxx -> */` annotations and the `static_assert`s in the `Client` constructor),
// so entries must stay aligned with `kQuestionRecordTypes[]` below.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes), // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv -> */ 1, // SRV record only
    /* kServiceQueryTxt -> */ 1, // TXT record only
#endif
};

// Pointer to the first question record type for each query type, in the same
// `QueryType` order as `kQuestionCount[]`. SRV-only and TXT-only queries reuse
// single elements of `kServiceQueryRecordTypes[]`.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt -> */ &kServiceQueryRecordTypes[1],

#endif
};
737
Client(Instance & aInstance)738 Client::Client(Instance &aInstance)
739 : InstanceLocator(aInstance)
740 , mSocket(aInstance)
741 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
742 , mTcpState(kTcpUninitialized)
743 #endif
744 , mTimer(aInstance)
745 , mDefaultConfig(QueryConfig::kInitFromDefaults)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
747 , mUserDidSetDefaultAddress(false)
748 #endif
749 {
750 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
752 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
753 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
754 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
755 static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
756 static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
757 static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
758 #endif
759 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
760 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
761 static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
762 static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
763 static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
764 #endif
765 }
766
Start(void)767 Error Client::Start(void)
768 {
769 Error error;
770
771 SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
772 SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
773
774 exit:
775 return error;
776 }
777
Stop(void)778 void Client::Stop(void)
779 {
780 Query *query;
781
782 while ((query = mMainQueries.GetHead()) != nullptr)
783 {
784 FinalizeQuery(*query, kErrorAbort);
785 }
786
787 IgnoreError(mSocket.Close());
788 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
789 if (mTcpState != kTcpUninitialized)
790 {
791 IgnoreError(mEndpoint.Deinitialize());
792 }
793 #endif
794 }
795
796 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)797 Error Client::InitTcpSocket(void)
798 {
799 Error error;
800 otTcpEndpointInitializeArgs endpointArgs;
801
802 ClearAllBytes(endpointArgs);
803 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
804 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
805 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
806 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
807 endpointArgs.mContext = this;
808 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
809 endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes);
810
811 mSendLink.mNext = nullptr;
812 mSendLink.mData = mSendBufferBytes;
813 mSendLink.mLength = 0;
814
815 SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
816 exit:
817 return error;
818 }
819 #endif
820
SetDefaultConfig(const QueryConfig & aQueryConfig)821 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
822 {
823 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
824
825 mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
826
827 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
828 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
829 UpdateDefaultConfigAddress();
830 #endif
831 }
832
ResetDefaultConfig(void)833 void Client::ResetDefaultConfig(void)
834 {
835 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
836
837 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
838 mUserDidSetDefaultAddress = false;
839 UpdateDefaultConfigAddress();
840 #endif
841 }
842
843 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)844 void Client::UpdateDefaultConfigAddress(void)
845 {
846 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
847
848 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
849 !srpServerAddr.IsUnspecified())
850 {
851 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
852 }
853 }
854 #endif
855
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)856 Error Client::ResolveAddress(const char *aHostName,
857 AddressCallback aCallback,
858 void *aContext,
859 const QueryConfig *aConfig)
860 {
861 QueryInfo info;
862
863 info.Clear();
864 info.mQueryType = kIp6AddressQuery;
865 info.mConfig.SetFrom(aConfig, mDefaultConfig);
866 info.mCallback.mAddressCallback = aCallback;
867 info.mCallbackContext = aContext;
868
869 return StartQuery(info, nullptr, aHostName);
870 }
871
872 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)873 Error Client::ResolveIp4Address(const char *aHostName,
874 AddressCallback aCallback,
875 void *aContext,
876 const QueryConfig *aConfig)
877 {
878 QueryInfo info;
879
880 info.Clear();
881 info.mQueryType = kIp4AddressQuery;
882 info.mConfig.SetFrom(aConfig, mDefaultConfig);
883 info.mCallback.mAddressCallback = aCallback;
884 info.mCallbackContext = aContext;
885
886 return StartQuery(info, nullptr, aHostName);
887 }
888 #endif
889
890 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
891
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)892 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
893 {
894 QueryInfo info;
895
896 info.Clear();
897 info.mQueryType = kBrowseQuery;
898 info.mConfig.SetFrom(aConfig, mDefaultConfig);
899 info.mCallback.mBrowseCallback = aCallback;
900 info.mCallbackContext = aContext;
901
902 return StartQuery(info, nullptr, aServiceName);
903 }
904
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)905 Error Client::ResolveService(const char *aInstanceLabel,
906 const char *aServiceName,
907 ServiceCallback aCallback,
908 void *aContext,
909 const QueryConfig *aConfig)
910 {
911 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
912 }
913
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)914 Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel,
915 const char *aServiceName,
916 ServiceCallback aCallback,
917 void *aContext,
918 const QueryConfig *aConfig)
919 {
920 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
921 }
922
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)923 Error Client::Resolve(const char *aInstanceLabel,
924 const char *aServiceName,
925 ServiceCallback aCallback,
926 void *aContext,
927 const QueryConfig *aConfig,
928 bool aShouldResolveHostAddr)
929 {
930 QueryInfo info;
931 Error error;
932 QueryType secondQueryType = kNoQuery;
933
934 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
935
936 info.Clear();
937
938 info.mConfig.SetFrom(aConfig, mDefaultConfig);
939 info.mShouldResolveHostAddr = aShouldResolveHostAddr;
940
941 switch (info.mConfig.GetServiceMode())
942 {
943 case QueryConfig::kServiceModeSrvTxtSeparate:
944 secondQueryType = kServiceQueryTxt;
945
946 OT_FALL_THROUGH;
947
948 case QueryConfig::kServiceModeSrv:
949 info.mQueryType = kServiceQuerySrv;
950 break;
951
952 case QueryConfig::kServiceModeTxt:
953 info.mQueryType = kServiceQueryTxt;
954 VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
955 break;
956
957 case QueryConfig::kServiceModeSrvTxt:
958 case QueryConfig::kServiceModeSrvTxtOptimize:
959 default:
960 info.mQueryType = kServiceQuerySrvTxt;
961 break;
962 }
963
964 info.mCallback.mServiceCallback = aCallback;
965 info.mCallbackContext = aContext;
966
967 error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
968
969 exit:
970 return error;
971 }
972
973 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
974
/**
 * Allocates, enqueues and sends a new main query described by `aInfo`.
 *
 * When `aSecondType` is not `kNoQuery`, a second query of that type for the
 * same name is also allocated, sent, and linked after the main query via
 * the main query's `mNextQuery`. Failing to allocate the second query is not
 * treated as an error; the main query proceeds alone.
 *
 * @param[in,out] aInfo        Info for the main query. It is modified in place
 *                             (message ID, transmission count, retransmission
 *                             time, and, for a second query, its fields).
 * @param[in]     aLabel       Optional first label of the name, or `nullptr`.
 * @param[in]     aName        The name (or the rest of the name after `aLabel`).
 * @param[in]     aSecondType  Type of an optional second query, or `kNoQuery`.
 *
 * @retval kErrorInvalidState  The socket is not bound, or (with NAT64 enabled)
 *                             an IPv4 query was requested but the leader has no
 *                             preferred NAT64 prefix.
 * @retval kErrorInvalidArgs   (With NAT64 enabled) an IPv4 query was requested
 *                             but the config does not allow NAT64.
 * @returns Otherwise, the error from `AllocateQuery()` or `SendQuery()` for the
 *          main query. If sending fails, the main query is freed.
 */
Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
{
    // The `aLabel` can be `nullptr` and then `aName` provides the
    // full name, otherwise the name is appended as `{aLabel}.
    // {aName}`.

    Error error;
    Query *query;

    VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    if (aInfo.mQueryType == kIp4AddressQuery)
    {
        // An IPv4 query requires NAT64 to be allowed by the config and a
        // preferred NAT64 prefix to be available from the Network Data leader.
        NetworkData::ExternalRouteConfig nat64Prefix;

        VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
        VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
                     error = kErrorInvalidState);
    }
#endif

    SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));

    // The query must be in `mMainQueries` before `SendQuery()`, because
    // on failure `FreeQuery()` dequeues it from that list.
    mMainQueries.Enqueue(*query);

    error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
    VerifyOrExit(error == kErrorNone, FreeQuery(*query));

    if (aSecondType != kNoQuery)
    {
        Query *secondQuery;

        // Reuse `aInfo` for the second query. A zero message ID makes
        // `SendQuery()` pick a new random one. `mMainQuery` ties the second
        // query to the main query's chain.
        aInfo.mQueryType = aSecondType;
        aInfo.mMessageId = 0;
        aInfo.mTransmissionCount = 0;
        aInfo.mMainQuery = query;

        // We intentionally do not use `error` here so in the unlikely
        // case where we cannot allocate the second query we can proceed
        // with the first one.
        SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));

        IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));

        // Update first query to link to second one by updating
        // its `mNextQuery`.
        aInfo.ReadFrom(*query);
        aInfo.mNextQuery = secondQuery;
        UpdateQuery(*query, aInfo);
    }

exit:
    return error;
}
1030
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1031 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1032 {
1033 Error error = kErrorNone;
1034
1035 aQuery = nullptr;
1036
1037 VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1038
1039 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1040 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1041
1042 SuccessOrExit(error = aQuery->Append(aInfo));
1043
1044 if (aLabel != nullptr)
1045 {
1046 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1047 }
1048
1049 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1050
1051 exit:
1052 FreeAndNullMessageOnError(aQuery, error);
1053 return error;
1054 }
1055
FindMainQuery(Query & aQuery)1056 Client::Query &Client::FindMainQuery(Query &aQuery)
1057 {
1058 QueryInfo info;
1059
1060 info.ReadFrom(aQuery);
1061
1062 return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1063 }
1064
FreeQuery(Query & aQuery)1065 void Client::FreeQuery(Query &aQuery)
1066 {
1067 Query &mainQuery = FindMainQuery(aQuery);
1068 QueryInfo info;
1069
1070 mMainQueries.Dequeue(mainQuery);
1071
1072 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1073 {
1074 info.ReadFrom(*query);
1075 FreeMessage(info.mSavedResponse);
1076 query->Free();
1077 }
1078 }
1079
/**
 * Builds and sends (or, for TCP, buffers) the DNS query message for
 * `aQuery`/`aInfo`.
 *
 * Before building the message, this method:
 *  - increments `aInfo.mTransmissionCount`;
 *  - sets `aInfo.mRetransmissionTime` to now plus the configured response timeout;
 *  - when `aInfo.mMessageId` is zero, picks a random non-zero message ID not
 *    used by any outstanding query.
 * The updated `aInfo` is written back into `aQuery` on every path, including
 * failures.
 *
 * Transport:
 *  - UDP: the message is sent to the configured server.
 *  - TCP (if enabled): the query is written length-prefixed into
 *    `mSendBufferBytes`. Depending on `mTcpState`, it is then connected and
 *    held, sent now, added to the pending data, or appended to an in-flight
 *    send.
 *
 * @param[in]     aQuery        The query (holds `QueryInfo` and the encoded name).
 * @param[in,out] aInfo         The query info; updated as described above.
 * @param[in]     aUpdateTimer  Whether to arm `mTimer` for the new retransmission
 *                              time. Callers that re-arm the timer themselves
 *                              pass `false`.
 *
 * @retval kErrorNone         Query sent or buffered.
 * @retval kErrorNoBufs       No message buffer, or the TCP send buffer would overflow.
 * @retval kErrorInvalidArgs  UDP query exceeds `kUdpQueryMaxSize`, or TCP was requested
 *                            but DNS-over-TCP support is compiled out.
 * @retval kErrorFailed       An existing TCP connection is to a different server.
 * @returns Other errors are propagated from message, socket or TCP operations.
 */
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).

    Error error = kErrorNone;
    Message *message = nullptr;
    Header header;
    Ip6::MessageInfo messageInfo;
    uint16_t length = 0;

    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // Zero means "no ID assigned yet". Draw random IDs until one is
        // non-zero and not already used by an outstanding query.
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Retransmission: keep the previously assigned message ID.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section: every question repeats the query's
    // name with the record type listed for this query type in
    // `kQuestionRecordTypes`.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    // Size of the DNS message payload (from the message offset to its end).
    length = message->GetLength() - message->GetOffset();

    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        // (The extra `uint16_t` is the DNS-over-TCP length prefix.)
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // No connection yet: start connecting and buffer the query.
            // `HandleTcpEstablished()` sends the buffered data later.
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Connected with nothing in flight: buffer and send right away.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Still connecting: add to the data waiting to be sent.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send of `mSendLink` is in progress: write the length-prefixed
            // query right after the current data and extend the send.
            // NOTE(review): `mSendLink.mLength` is not advanced here;
            // presumably `SendByExtension()` grows the last linked buffer's
            // length itself. Confirm against the TCP API.
            BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }
        // The query bytes now live in `mSendBufferBytes`; the message is no longer needed.
        message->Free();
        message = nullptr;
    }
#else
    {
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    // Persist the updated info (transmission count, retransmission time,
    // message ID) in `aQuery`, even when sending failed.
    UpdateQuery(aQuery, aInfo);

    return error;
}
1206
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1207 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1208 {
1209 // The name is encoded and included after the `Info` in `aQuery`
1210 // starting at `kNameOffsetInQuery`.
1211
1212 return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1213 }
1214
FinalizeQuery(Query & aQuery,Error aError)1215 void Client::FinalizeQuery(Query &aQuery, Error aError)
1216 {
1217 Response response;
1218 Query &mainQuery = FindMainQuery(aQuery);
1219
1220 response.mInstance = &Get<Instance>();
1221 response.mQuery = &mainQuery;
1222
1223 FinalizeQuery(response, aError);
1224 }
1225
FinalizeQuery(Response & aResponse,Error aError)1226 void Client::FinalizeQuery(Response &aResponse, Error aError)
1227 {
1228 QueryType type;
1229 Callback callback;
1230 void *context;
1231
1232 GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1233
1234 switch (type)
1235 {
1236 case kIp6AddressQuery:
1237 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1238 case kIp4AddressQuery:
1239 #endif
1240 if (callback.mAddressCallback != nullptr)
1241 {
1242 callback.mAddressCallback(aError, &aResponse, context);
1243 }
1244 break;
1245
1246 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1247 case kBrowseQuery:
1248 if (callback.mBrowseCallback != nullptr)
1249 {
1250 callback.mBrowseCallback(aError, &aResponse, context);
1251 }
1252 break;
1253
1254 case kServiceQuerySrvTxt:
1255 case kServiceQuerySrv:
1256 case kServiceQueryTxt:
1257 if (callback.mServiceCallback != nullptr)
1258 {
1259 callback.mServiceCallback(aError, &aResponse, context);
1260 }
1261 break;
1262 #endif
1263 case kNoQuery:
1264 break;
1265 }
1266
1267 FreeQuery(*aResponse.mQuery);
1268 }
1269
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1270 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1271 {
1272 QueryInfo info;
1273
1274 info.ReadFrom(aQuery);
1275
1276 aType = info.mQueryType;
1277 aCallback = info.mCallback;
1278 aContext = info.mCallbackContext;
1279 }
1280
FindQueryById(uint16_t aMessageId)1281 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1282 {
1283 Query *matchedQuery = nullptr;
1284 QueryInfo info;
1285
1286 for (Query &mainQuery : mMainQueries)
1287 {
1288 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1289 {
1290 info.ReadFrom(*query);
1291
1292 if (info.mMessageId == aMessageId)
1293 {
1294 matchedQuery = query;
1295 ExitNow();
1296 }
1297 }
1298 }
1299
1300 exit:
1301 return matchedQuery;
1302 }
1303
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1304 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1305 {
1306 OT_UNUSED_VARIABLE(aMsgInfo);
1307
1308 static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1309 }
1310
ProcessResponse(const Message & aResponseMessage)1311 void Client::ProcessResponse(const Message &aResponseMessage)
1312 {
1313 Error responseError;
1314 Query *query;
1315
1316 SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1317
1318 if (responseError != kErrorNone)
1319 {
1320 // Received an error from server, check if we can replace
1321 // the query.
1322
1323 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1324 if (ReplaceWithIp4Query(*query) == kErrorNone)
1325 {
1326 ExitNow();
1327 }
1328 #endif
1329 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1330 if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1331 {
1332 ExitNow();
1333 }
1334 #endif
1335
1336 FinalizeQuery(*query, responseError);
1337 ExitNow();
1338 }
1339
1340 // Received successful response from server.
1341
1342 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1343 ResolveHostAddressIfNeeded(*query, aResponseMessage);
1344 #endif
1345
1346 if (!CanFinalizeQuery(*query))
1347 {
1348 SaveQueryResponse(*query, aResponseMessage);
1349 ExitNow();
1350 }
1351
1352 PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1353
1354 exit:
1355 return;
1356 }
1357
/**
 * Validates a received DNS message and matches it to an outstanding query.
 *
 * @param[in]  aResponseMessage  The received message (DNS header at its offset).
 * @param[out] aQuery            Set to the outstanding query with the same
 *                               message ID. It is not set if the header check fails.
 * @param[out] aResponseError    Set to the error mapped from the response code.
 *                               It is only set when `kErrorNone` is returned.
 *
 * @retval kErrorNone      The message is a well-formed response for `aQuery`.
 * @retval kErrorDrop      Not a standard, non-truncated response.
 * @retval kErrorNotFound  No outstanding query has this message ID.
 * @retval kErrorParse     The question count matches neither the query nor an
 *                         error response without questions.
 * @returns Other errors from reading the header, comparing question names, or
 *          parsing the record sections are propagated.
 */
Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
{
    Error error = kErrorNone;
    uint16_t offset = aResponseMessage.GetOffset();
    Header header;
    QueryInfo info;
    Name queryName;

    SuccessOrExit(error = aResponseMessage.Read(offset, header));
    offset += sizeof(Header);

    VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
                     !header.IsTruncationFlagSet(),
                 error = kErrorDrop);

    aQuery = FindQueryById(header.GetMessageId());
    VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);

    info.ReadFrom(*aQuery);

    // The query name is stored in the query message after its `QueryInfo`.
    queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);

    // Check the Question Section

    if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
    {
        // Each echoed question must carry our query name. `CompareName()` is
        // expected to advance `offset` past the name, after which the
        // fixed-size `Question` part is skipped. Only names are compared
        // here; the question record types are not checked.
        for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
        {
            SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
            offset += sizeof(Question);
        }
    }
    else
    {
        // A mismatched question count is only acceptable for an error
        // response that carries no questions at all.
        VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
                     error = kErrorParse);
    }

    // Check the answer, authority and additional record sections
    // (presumably `ParseRecords()` walks each record and advances `offset`,
    // rejecting malformed sections; confirm against its definition).

    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));

    // Read the response code

    aResponseError = Header::ResponseCodeToError(header.GetResponseCode());

exit:
    return error;
}
1409
CanFinalizeQuery(Query & aQuery)1410 bool Client::CanFinalizeQuery(Query &aQuery)
1411 {
1412 // Determines whether we can finalize a main query by checking if
1413 // we have received and saved responses for all other related
1414 // queries associated with `aQuery`. Note that this method is
1415 // called when we receive a response for `aQuery`, so no need to
1416 // check for a saved response for `aQuery` itself.
1417
1418 bool canFinalize = true;
1419 QueryInfo info;
1420
1421 for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1422 {
1423 info.ReadFrom(*query);
1424
1425 if (query == &aQuery)
1426 {
1427 continue;
1428 }
1429
1430 if (info.mSavedResponse == nullptr)
1431 {
1432 canFinalize = false;
1433 ExitNow();
1434 }
1435 }
1436
1437 exit:
1438 return canFinalize;
1439 }
1440
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1441 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1442 {
1443 QueryInfo info;
1444
1445 info.ReadFrom(aQuery);
1446 VerifyOrExit(info.mSavedResponse == nullptr);
1447
1448 // If `Clone()` fails we let retry or timeout handle the error.
1449 info.mSavedResponse = aResponseMessage.Clone();
1450
1451 UpdateQuery(aQuery, info);
1452
1453 exit:
1454 return;
1455 }
1456
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1457 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1458 {
1459 // Populate `aResponse` for `aQuery`. If there is a saved response
1460 // message for `aQuery` we use it, otherwise, we use
1461 // `aResponseMessage`.
1462
1463 QueryInfo info;
1464
1465 info.ReadFrom(aQuery);
1466
1467 aResponse.mInstance = &Get<Instance>();
1468 aResponse.mQuery = &aQuery;
1469 aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1470
1471 return info.mNextQuery;
1472 }
1473
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1474 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1475 {
1476 // This method prepares a list of chained `Response` instances
1477 // corresponding to all related (chained) queries. It uses
1478 // recursion to go through the queries and construct the
1479 // `Response` chain.
1480
1481 Response response;
1482 Query *nextQuery;
1483
1484 nextQuery = PopulateResponse(response, aQuery, aResponseMessage);
1485 response.mNext = aPrevResponse;
1486
1487 if (nextQuery != nullptr)
1488 {
1489 PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1490 }
1491 else
1492 {
1493 FinalizeQuery(response, kErrorNone);
1494 }
1495 }
1496
/**
 * Retransmission/timeout timer handler.
 *
 * Walks every query chain and skips queries that already have a saved
 * response. For each query whose retransmission time has passed:
 *  - if it has used up its maximum number of transmission attempts, the
 *    whole chain is finalized with `kErrorResponseTimeout`;
 *  - otherwise it is re-sent.
 *
 * The timer is then re-armed for the earliest pending retransmission time.
 * With DNS-over-TCP enabled, if no remaining query uses TCP but a TCP
 * connection exists, end-of-stream is sent on it.
 */
void Client::HandleTimer(void)
{
    NextFireTime nextTime;
    QueryInfo info;
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    bool hasTcpQuery = false;
#endif

    for (Query &mainQuery : mMainQueries)
    {
        for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
        {
            info.ReadFrom(*query);

            // Already answered and waiting for the rest of its chain:
            // nothing to retransmit or time out.
            if (info.mSavedResponse != nullptr)
            {
                continue;
            }

            if (nextTime.GetNow() >= info.mRetransmissionTime)
            {
                if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
                {
                    // `FinalizeQuery()` frees every query in this chain
                    // (including `query`), so stop walking it.
                    // NOTE(review): the outer loop relies on the
                    // `mMainQueries` iterator tolerating removal of
                    // `mainQuery` during iteration.
                    FinalizeQuery(*query, kErrorResponseTimeout);
                    break;
                }

                // The timer is re-armed once at the end, so `SendQuery()`
                // must not touch it. `SendQuery()` also updates
                // `info.mRetransmissionTime`, which is used just below.
                IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
            }

            nextTime.UpdateIfEarlier(info.mRetransmissionTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
            if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
            {
                hasTcpQuery = true;
            }
#endif
        }
    }

    mTimer.FireAt(nextTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    // No TCP query remains: close our sending side of the connection.
    if (!hasTcpQuery && mTcpState != kTcpUninitialized)
    {
        IgnoreError(mEndpoint.SendEndOfStream());
    }
#endif
}
1547
1548 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1549
ReplaceWithIp4Query(Query & aQuery)1550 Error Client::ReplaceWithIp4Query(Query &aQuery)
1551 {
1552 Error error = kErrorFailed;
1553 QueryInfo info;
1554
1555 info.ReadFrom(aQuery);
1556
1557 VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1558 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1559
1560 // We send a new query for IPv4 address resolution
1561 // for the same host name. We reuse the existing `aQuery`
1562 // instance and keep all the info but clear `mTransmissionCount`
1563 // and `mMessageId` (so that a new random message ID is
1564 // selected). The new `info` will be saved in the query in
1565 // `SendQuery()`. Note that the current query is still in the
1566 // `mMainQueries` list when `SendQuery()` selects a new random
1567 // message ID, so the existing message ID for this query will
1568 // not be reused.
1569
1570 info.mQueryType = kIp4AddressQuery;
1571 info.mMessageId = 0;
1572 info.mTransmissionCount = 0;
1573
1574 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1575 error = kErrorNone;
1576
1577 exit:
1578 return error;
1579 }
1580
1581 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1582
1583 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1584
/**
 * Tries to replace a failed combined SRV+TXT query with two separate
 * queries.
 *
 * This is called when the server returns an error response for `aQuery`.
 * It applies only to a `kServiceQuerySrvTxt` query whose config uses
 * `kServiceModeSrvTxtOptimize`. In that case:
 *  - a clone of `aQuery` is sent as a TXT query;
 *  - `aQuery` itself is re-sent as an SRV query, with the clone linked
 *    after it via `mNextQuery`.
 *
 * @param[in] aQuery  The query that received an error response.
 *
 * @retval kErrorNone    Replaced; the new queries were sent and the caller
 *                       must not finalize `aQuery`.
 * @retval kErrorFailed  `aQuery` is not eligible, or cloning it failed.
 */
Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
{
    Error error = kErrorFailed;
    QueryInfo info;
    Query *secondQuery;

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
    VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);

    // The clone carries the same encoded name. Its stored info is replaced
    // by `SendQuery()` (which writes `info` back into the query).
    secondQuery = aQuery.Clone();
    VerifyOrExit(secondQuery != nullptr);

    // TXT query: a fresh message ID and count, and tied to `aQuery`'s chain.
    // Its `mNextQuery` is whatever `aQuery` had (presumably `nullptr` for a
    // combined query at this point; confirm).
    info.mQueryType = kServiceQueryTxt;
    info.mMessageId = 0;
    info.mTransmissionCount = 0;
    info.mMainQuery = &aQuery;
    IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));

    // SRV query reusing `aQuery`, now linked to the TXT query. `info`
    // still holds `mMainQuery = &aQuery`, so `aQuery` points at itself;
    // `FindMainQuery(aQuery)` still resolves to `aQuery`.
    info.mQueryType = kServiceQuerySrv;
    info.mMessageId = 0;
    info.mTransmissionCount = 0;
    info.mNextQuery = secondQuery;
    IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
    error = kErrorNone;

exit:
    return error;
}
1615
/**
 * After a successful response to an SRV or SRV+TXT query, starts an IPv6
 * address query for the service's host if one is needed.
 *
 * This only acts when the query requested host address resolution
 * (`mShouldResolveHostAddr`) and the response's answer section yields the
 * service info, but that info has no host address. The new
 * `kIp6AddressQuery` for the host name is then inserted into the chain right
 * after `aQuery`.
 *
 * Failures are not reported: if the service info cannot be read or the new
 * query cannot be allocated, the chain continues without the extra query.
 */
void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
{
    QueryInfo info;
    Response response;
    ServiceInfo serviceInfo;
    char hostName[Name::kMaxNameSize];

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
    VerifyOrExit(info.mShouldResolveHostAddr);

    PopulateResponse(response, aQuery, aResponseMessage);

    // Read the service info (host name into `hostName`) for the queried
    // service instance name stored in `aQuery`.
    ClearAllBytes(serviceInfo);
    serviceInfo.mHostNameBuffer = hostName;
    serviceInfo.mHostNameBufferSize = sizeof(hostName);
    SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));

    // Check whether AAAA record for host address is provided in the SRV query response

    if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
    {
        Query *newQuery;

        // `info` still holds `aQuery`'s fields (config, callback,
        // `mNextQuery`, ...), so the new query inherits them. In
        // particular, its `mNextQuery` is `aQuery`'s old successor.
        info.mQueryType = kIp6AddressQuery;
        info.mMessageId = 0;
        info.mTransmissionCount = 0;
        info.mMainQuery = &FindMainQuery(aQuery);

        SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
        IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));

        // Update `aQuery` to be linked with new query (inserting
        // the `newQuery` into the linked-list after `aQuery`).

        info.ReadFrom(aQuery);
        info.mNextQuery = newQuery;
        UpdateQuery(aQuery, info);
    }

exit:
    return;
}
1660
1661 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1662
1663 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1664 void Client::PrepareTcpMessage(Message &aMessage)
1665 {
1666 uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1667
1668 // Prepending the DNS query with length of the packet according to RFC1035.
1669 BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1670 SuccessOrAssert(
1671 aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1672 mSendLink.mLength += length + sizeof(uint16_t);
1673 }
1674
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1675 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1676 {
1677 OT_UNUSED_VARIABLE(aEndpoint);
1678 OT_UNUSED_VARIABLE(aData);
1679 OT_ASSERT(mTcpState == kTcpConnectedSending);
1680
1681 mSendLink.mLength = 0;
1682 mTcpState = kTcpConnectedIdle;
1683 }
1684
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1685 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1686 {
1687 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1688 }
1689
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1690 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1691 {
1692 OT_UNUSED_VARIABLE(aEndpoint);
1693 IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1694 mTcpState = kTcpConnectedSending;
1695 }
1696
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1697 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1698 {
1699 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1700 }
1701
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1702 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1703 size_t &aOffset,
1704 Message &aMessage,
1705 uint16_t aLength)
1706 {
1707 // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1708 // and copy the content into `aMessage`. As we read we can move
1709 // to the next `aLinkedBuffer` and update `aOffset`.
1710 // Returns:
1711 // - `kErrorNone` if `aLength` bytes are successfully read and
1712 // `aOffset` and `aLinkedBuffer` are updated.
1713 // - `kErrorNotFound` is not enough bytes available to read
1714 // from `aLinkedBuffer`.
1715 // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1716
1717 Error error = kErrorNone;
1718
1719 while (aLength > 0)
1720 {
1721 uint16_t bytesToRead = aLength;
1722
1723 VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1724
1725 if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1726 {
1727 bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1728 }
1729
1730 SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1731
1732 aLength -= bytesToRead;
1733 aOffset += bytesToRead;
1734
1735 if (aOffset == aLinkedBuffer->mLength)
1736 {
1737 aLinkedBuffer = aLinkedBuffer->mNext;
1738 aOffset = 0;
1739 }
1740 }
1741
1742 exit:
1743 return error;
1744 }
1745
/**
 * Handles data available on the DNS-over-TCP connection.
 *
 * Reads as many complete length-prefixed DNS messages (a 2-byte big-endian
 * length followed by that many bytes) as the receive buffer holds, and passes
 * each one to `ProcessResponse()`. Only fully read messages are committed
 * back to the endpoint; a trailing partial message stays uncommitted.
 *
 * If the peer ended the stream, end-of-stream is sent in return. Cleanup
 * happens in the disconnected callback.
 */
void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
                                       size_t aBytesAvailable,
                                       bool aEndOfStream,
                                       size_t aBytesRemaining)
{
    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aBytesRemaining);

    Message *message = nullptr;
    size_t totalRead = 0;
    size_t offset = 0;
    const otLinkedBuffer *data;

    if (aEndOfStream)
    {
        // Cleanup is done in disconnected callback.
        IgnoreError(mEndpoint.SendEndOfStream());
    }

    SuccessOrExit(mEndpoint.ReceiveByReference(data));
    VerifyOrExit(data != nullptr);

    // Scratch message, reused for each length field and each DNS message.
    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr);

    while (aBytesAvailable > totalRead)
    {
        uint16_t length;

        // Read the `length` field.
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));

        // The scratch message holds only the two length bytes at this point.
        IgnoreError(message->Read(/* aOffset */ 0, length));
        length = BigEndian::HostSwap16(length);

        // Try to read `length` bytes. If they are not all available yet,
        // exit without counting this message in `totalRead`.
        IgnoreError(message->SetLength(0));
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));

        totalRead += length + sizeof(uint16_t);

        // Now process the read message as query response.
        ProcessResponse(*message);

        IgnoreError(message->SetLength(0));

        // Loop again to see if we can read another response.
    }

exit:
    // Inform `mEndPoint` about the total read and processed bytes
    IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
    FreeMessage(message);
}
1800
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1801 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1802 size_t aBytesAvailable,
1803 bool aEndOfStream,
1804 size_t aBytesRemaining)
1805 {
1806 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1807 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1808 }
1809
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1810 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1811 {
1812 OT_UNUSED_VARIABLE(aEndpoint);
1813 OT_UNUSED_VARIABLE(aReason);
1814 QueryInfo info;
1815
1816 IgnoreError(mEndpoint.Deinitialize());
1817 mTcpState = kTcpUninitialized;
1818
1819 // Abort queries in case of connection failures
1820 for (Query &mainQuery : mMainQueries)
1821 {
1822 info.ReadFrom(mainQuery);
1823
1824 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1825 {
1826 FinalizeQuery(mainQuery, kErrorAbort);
1827 }
1828 }
1829 }
1830
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1831 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1832 {
1833 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1834 }
1835
1836 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1837
1838 } // namespace Dns
1839 } // namespace ot
1840
1841 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1842