1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
52 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
53 using ot::Encoding::BigEndian::ReadUint16;
54 using ot::Encoding::BigEndian::WriteUint16;
55 #endif
56 
57 RegisterLogModule("DnsClient");
58 
59 //---------------------------------------------------------------------------------------------------------------------
60 // Client::QueryConfig
61 
62 const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
63 
QueryConfig(InitMode aMode)64 Client::QueryConfig::QueryConfig(InitMode aMode)
65 {
66     OT_UNUSED_VARIABLE(aMode);
67 
68     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
69     GetServerSockAddr().SetPort(kDefaultServerPort);
70     SetResponseTimeout(kDefaultResponseTimeout);
71     SetMaxTxAttempts(kDefaultMaxTxAttempts);
72     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
73     SetServiceMode(kDefaultServiceMode);
74 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
75     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
76 #endif
77     SetTransportProto(kDnsTransportUdp);
78 }
79 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)80 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
81 {
82     // This method sets the config from `aConfig` replacing any
83     // unspecified fields (value zero) with the fields from
84     // `aDefaultConfig`. If `aConfig` is `nullptr` then
85     // `aDefaultConfig` is used.
86 
87     if (aConfig == nullptr)
88     {
89         *this = aDefaultConfig;
90         ExitNow();
91     }
92 
93     *this = *aConfig;
94 
95     if (GetServerSockAddr().GetAddress().IsUnspecified())
96     {
97         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
98     }
99 
100     if (GetServerSockAddr().GetPort() == 0)
101     {
102         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
103     }
104 
105     if (GetResponseTimeout() == 0)
106     {
107         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
108     }
109 
110     if (GetMaxTxAttempts() == 0)
111     {
112         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
113     }
114 
115     if (GetRecursionFlag() == kFlagUnspecified)
116     {
117         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
118     }
119 
120 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
121     if (GetNat64Mode() == kNat64Unspecified)
122     {
123         SetNat64Mode(aDefaultConfig.GetNat64Mode());
124     }
125 #endif
126 
127     if (GetServiceMode() == kServiceModeUnspecified)
128     {
129         SetServiceMode(aDefaultConfig.GetServiceMode());
130     }
131 
132     if (GetTransportProto() == kDnsTransportUnspecified)
133     {
134         SetTransportProto(aDefaultConfig.GetTransportProto());
135     }
136 
137 exit:
138     return;
139 }
140 
141 //---------------------------------------------------------------------------------------------------------------------
142 // Client::Response
143 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const144 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
145 {
146     switch (aSection)
147     {
148     case kAnswerSection:
149         aOffset    = mAnswerOffset;
150         aNumRecord = mAnswerRecordCount;
151         break;
152     case kAdditionalDataSection:
153     default:
154         aOffset    = mAdditionalOffset;
155         aNumRecord = mAdditionalRecordCount;
156         break;
157     }
158 }
159 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const160 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
161 {
162     uint16_t offset = kNameOffsetInQuery;
163 
164     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
165 }
166 
CheckForHostNameAlias(Section aSection,Name & aHostName) const167 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
168 {
169     // If the response includes a CNAME record mapping the query host
170     // name to a canonical name, we update `aHostName` to the new alias
171     // name. Otherwise `aHostName` remains as before. This method handles
172     // when there are multiple CNAME records mapping the host name multiple
173     // times. We limit number of changes to `kMaxCnameAliasNameChanges`
174     // to detect and handle if the response contains CNAME record loops.
175 
176     Error       error;
177     uint16_t    offset;
178     uint16_t    numRecords;
179     CnameRecord cnameRecord;
180 
181     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
182 
183     for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
184     {
185         SelectSection(aSection, offset, numRecords);
186         error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
187 
188         if (error == kErrorNotFound)
189         {
190             error = kErrorNone;
191             ExitNow();
192         }
193 
194         SuccessOrExit(error);
195 
196         // A CNAME record was found. `offset` now points to after the
197         // last read byte within the `mMessage` into the `cnameRecord`
198         // (which is the start of the new canonical name).
199 
200         aHostName.SetFromMessage(*mMessage, offset);
201         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
202 
203         // Loop back to check if there may be a CNAME record for the
204         // new `aHostName`.
205     }
206 
207     error = kErrorParse;
208 
209 exit:
210     return error;
211 }
212 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const213 Error Client::Response::FindHostAddress(Section       aSection,
214                                         const Name   &aHostName,
215                                         uint16_t      aIndex,
216                                         Ip6::Address &aAddress,
217                                         uint32_t     &aTtl) const
218 {
219     Error      error;
220     uint16_t   offset;
221     uint16_t   numRecords;
222     Name       name = aHostName;
223     AaaaRecord aaaaRecord;
224 
225     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
226 
227     SelectSection(aSection, offset, numRecords);
228     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
229     aAddress = aaaaRecord.GetAddress();
230     aTtl     = aaaaRecord.GetTtl();
231 
232 exit:
233     return error;
234 }
235 
236 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
237 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const238 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
239 {
240     Error    error;
241     uint16_t offset;
242     uint16_t numRecords;
243     Name     name = aHostName;
244 
245     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
246 
247     SelectSection(aSection, offset, numRecords);
248     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
249 
250 exit:
251     return error;
252 }
253 
254 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
255 
256 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
257 
InitServiceInfo(ServiceInfo & aServiceInfo) const258 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
259 {
260     // This method initializes `aServiceInfo` setting all
261     // TTLs to zero and host name to empty string.
262 
263     aServiceInfo.mTtl              = 0;
264     aServiceInfo.mHostAddressTtl   = 0;
265     aServiceInfo.mTxtDataTtl       = 0;
266     aServiceInfo.mTxtDataTruncated = false;
267 
268     AsCoreType(&aServiceInfo.mHostAddress).Clear();
269 
270     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
271     {
272         aServiceInfo.mHostNameBuffer[0] = '\0';
273     }
274 }
275 
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const276 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
277 {
278     // This method searches for SRV record in the given `aSection`
279     // matching the record name against `aName`, and updates the
280     // `aServiceInfo` accordingly. It also searches for AAAA record
281     // for host name associated with the service (from SRV record).
282     // The search for AAAA record is always performed in Additional
283     // Data section (independent of the value given in `aSection`).
284 
285     Error     error = kErrorNone;
286     uint16_t  offset;
287     uint16_t  numRecords;
288     Name      hostName;
289     SrvRecord srvRecord;
290 
291     // A non-zero `mTtl` indicates that SRV record is already found
292     // and parsed from a previous response.
293     VerifyOrExit(aServiceInfo.mTtl == 0);
294 
295     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
296 
297     // Search for a matching SRV record
298     SelectSection(aSection, offset, numRecords);
299     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
300 
301     aServiceInfo.mTtl      = srvRecord.GetTtl();
302     aServiceInfo.mPort     = srvRecord.GetPort();
303     aServiceInfo.mPriority = srvRecord.GetPriority();
304     aServiceInfo.mWeight   = srvRecord.GetWeight();
305 
306     hostName.SetFromMessage(*mMessage, offset);
307 
308     if (aServiceInfo.mHostNameBuffer != nullptr)
309     {
310         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
311                                                            aServiceInfo.mHostNameBufferSize));
312     }
313     else
314     {
315         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
316     }
317 
318     // Search in additional section for AAAA record for the host name.
319 
320     VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
321 
322     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
323                             aServiceInfo.mHostAddressTtl);
324 
325     if (error == kErrorNotFound)
326     {
327         error = kErrorNone;
328     }
329 
330 exit:
331     return error;
332 }
333 
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const334 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
335 {
336     // This method searches a TXT record in the given `aSection`
337     // matching the record name against `aName` and updates the TXT
338     // related properties in `aServicesInfo`.
339     //
340     // If no match is found `mTxtDataTtl` (which is initialized to zero)
341     // remains unchanged to indicate this. In this case this method still
342     // returns `kErrorNone`.
343 
344     Error     error = kErrorNone;
345     uint16_t  offset;
346     uint16_t  numRecords;
347     TxtRecord txtRecord;
348 
349     // A non-zero `mTxtDataTtl` indicates that TXT record is already
350     // found and parsed from a previous response.
351     VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
352 
353     // A null `mTxtData` indicates that caller does not want to retrieve
354     // TXT data.
355     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
356 
357     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
358 
359     SelectSection(aSection, offset, numRecords);
360 
361     aServiceInfo.mTxtDataTruncated = false;
362 
363     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
364 
365     error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
366 
367     if (error == kErrorNoBufs)
368     {
369         error = kErrorNone;
370 
371         // Mark `mTxtDataTruncated` to indicate that we could not read
372         // the full TXT record into the given `mTxtData` buffer.
373         aServiceInfo.mTxtDataTruncated = true;
374     }
375 
376     SuccessOrExit(error);
377     aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
378 
379 exit:
380     if (error == kErrorNotFound)
381     {
382         error = kErrorNone;
383     }
384 
385     return error;
386 }
387 
388 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
389 
PopulateFrom(const Message & aMessage)390 void Client::Response::PopulateFrom(const Message &aMessage)
391 {
392     // Populate `Response` with info from `aMessage`.
393 
394     uint16_t offset = aMessage.GetOffset();
395     Header   header;
396 
397     mMessage = &aMessage;
398 
399     IgnoreError(aMessage.Read(offset, header));
400     offset += sizeof(Header);
401 
402     for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
403     {
404         IgnoreError(Name::ParseName(aMessage, offset));
405         offset += sizeof(Question);
406     }
407 
408     mAnswerOffset = offset;
409     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
410     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
411     mAdditionalOffset = offset;
412     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
413 
414     mAnswerRecordCount     = header.GetAnswerCount();
415     mAdditionalRecordCount = header.GetAdditionalRecordCount();
416 }
417 
418 //---------------------------------------------------------------------------------------------------------------------
419 // Client::AddressResponse
420 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const421 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
422 {
423     Error error = kErrorNone;
424     Name  name(*mQuery, kNameOffsetInQuery);
425 
426 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
427 
428     // If the response is for an IPv4 address query or if it is an
429     // IPv6 address query response with no IPv6 address but with
430     // an IPv4 in its additional section, we read the IPv4 address
431     // and translate it to an IPv6 address.
432 
433     QueryInfo info;
434 
435     info.ReadFrom(*mQuery);
436 
437     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
438     {
439         Section                          section;
440         ARecord                          aRecord;
441         NetworkData::ExternalRouteConfig nat64Prefix;
442 
443         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
444                      error = kErrorInvalidState);
445 
446         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
447         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
448 
449         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
450         aTtl = aRecord.GetTtl();
451 
452         ExitNow();
453     }
454 
455 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
456 
457     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
458 
459 exit:
460     return error;
461 }
462 
463 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
464 
465 //---------------------------------------------------------------------------------------------------------------------
466 // Client::BrowseResponse
467 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const468 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
469 {
470     Error     error;
471     uint16_t  offset;
472     uint16_t  numRecords;
473     Name      serviceName(*mQuery, kNameOffsetInQuery);
474     PtrRecord ptrRecord;
475 
476     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
477 
478     SelectSection(kAnswerSection, offset, numRecords);
479     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
480     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
481 
482 exit:
483     return error;
484 }
485 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const486 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
487 {
488     Error error;
489     Name  instanceName;
490 
491     // Find a matching PTR record for the service instance label. Then
492     // search and read SRV, TXT and AAAA records in Additional Data
493     // section matching the same name to populate `aServiceInfo`.
494 
495     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
496 
497     InitServiceInfo(aServiceInfo);
498     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
499     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
500 
501     if (aServiceInfo.mTxtDataTtl == 0)
502     {
503         aServiceInfo.mTxtDataSize = 0;
504     }
505 
506 exit:
507     return error;
508 }
509 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const510 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
511                                              uint16_t      aIndex,
512                                              Ip6::Address &aAddress,
513                                              uint32_t     &aTtl) const
514 {
515     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
516 }
517 
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const518 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
519 {
520     // This method searches within the Answer Section for a PTR record
521     // matching a given instance label @aInstanceLabel. If found, the
522     // `aName` is updated to return the name in the message.
523 
524     Error     error;
525     uint16_t  offset;
526     Name      serviceName(*mQuery, kNameOffsetInQuery);
527     uint16_t  numRecords;
528     uint16_t  labelOffset;
529     PtrRecord ptrRecord;
530 
531     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
532 
533     SelectSection(kAnswerSection, offset, numRecords);
534 
535     for (; numRecords > 0; numRecords--)
536     {
537         SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
538 
539         error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
540 
541         if (error == kErrorNotFound)
542         {
543             // `ReadRecord()` updates `offset` to skip over a
544             // non-matching record.
545             continue;
546         }
547 
548         SuccessOrExit(error);
549 
550         // It is a PTR record. Check the first label to match the
551         // instance label.
552 
553         labelOffset = offset;
554         error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
555 
556         if (error == kErrorNone)
557         {
558             aInstanceName.SetFromMessage(*mMessage, offset);
559             ExitNow();
560         }
561 
562         VerifyOrExit(error == kErrorNotFound);
563 
564         // Update offset to skip over the PTR record.
565         offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
566     }
567 
568     error = kErrorNotFound;
569 
570 exit:
571     return error;
572 }
573 
574 //---------------------------------------------------------------------------------------------------------------------
575 // Client::ServiceResponse
576 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const577 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
578                                               uint8_t  aLabelBufferSize,
579                                               char    *aNameBuffer,
580                                               uint16_t aNameBufferSize) const
581 {
582     Error    error;
583     uint16_t offset = kNameOffsetInQuery;
584 
585     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
586 
587     VerifyOrExit(aNameBuffer != nullptr);
588     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
589 
590 exit:
591     return error;
592 }
593 
GetServiceInfo(ServiceInfo & aServiceInfo) const594 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
595 {
596     // Search and read SRV, TXT records matching name from query.
597 
598     Error error = kErrorNotFound;
599 
600     InitServiceInfo(aServiceInfo);
601 
602     for (const Response *response = this; response != nullptr; response = response->mNext)
603     {
604         Name      name(*response->mQuery, kNameOffsetInQuery);
605         QueryInfo info;
606         Section   srvSection;
607         Section   txtSection;
608 
609         info.ReadFrom(*response->mQuery);
610 
611         switch (info.mQueryType)
612         {
613         case kIp6AddressQuery:
614 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
615         case kIp4AddressQuery:
616 #endif
617             IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
618                                                   AsCoreType(&aServiceInfo.mHostAddress),
619                                                   aServiceInfo.mHostAddressTtl));
620 
621             continue; // to `for()` loop
622 
623         case kServiceQuerySrvTxt:
624         case kServiceQuerySrv:
625         case kServiceQueryTxt:
626             break;
627 
628         default:
629             continue;
630         }
631 
632         // Determine from which section we should try to read the SRV and
633         // TXT records based on the query type.
634         //
635         // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
636         // only one record (SRV or TXT) in the answer section, but we
637         // still try to read the other records from additional data
638         // section in case server provided them.
639 
640         srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
641         txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
642 
643         error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
644 
645         if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
646         {
647             error = kErrorNone;
648         }
649 
650         SuccessOrExit(error);
651 
652         SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
653     }
654 
655     if (aServiceInfo.mTxtDataTtl == 0)
656     {
657         aServiceInfo.mTxtDataSize = 0;
658     }
659 
660 exit:
661     return error;
662 }
663 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const664 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
665                                               uint16_t      aIndex,
666                                               Ip6::Address &aAddress,
667                                               uint32_t     &aTtl) const
668 {
669     Error error = kErrorNotFound;
670 
671     for (const Response *response = this; response != nullptr; response = response->mNext)
672     {
673         Section   section = kAdditionalDataSection;
674         QueryInfo info;
675 
676         info.ReadFrom(*response->mQuery);
677 
678         switch (info.mQueryType)
679         {
680         case kIp6AddressQuery:
681 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
682         case kIp4AddressQuery:
683 #endif
684             section = kAnswerSection;
685             break;
686 
687         default:
688             break;
689         }
690 
691         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
692 
693         if (error == kErrorNone)
694         {
695             break;
696         }
697     }
698 
699     return error;
700 }
701 
702 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
703 
704 //---------------------------------------------------------------------------------------------------------------------
705 // Client
706 
707 const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
708 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
709 const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
710 #endif
711 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
712 const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
713 const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
714 #endif
715 
716 const uint8_t Client::kQuestionCount[] = {
717     /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
718 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
719     /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
720 #endif
721 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
722     /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
723     /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
724     /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
725     /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
726 #endif
727 };
728 
729 const uint16_t *const Client::kQuestionRecordTypes[] = {
730     /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
731 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
732     /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
733 #endif
734 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
735     /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
736     /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
737     /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
738     /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],
739 
740 #endif
741 };
742 
Client(Instance & aInstance)743 Client::Client(Instance &aInstance)
744     : InstanceLocator(aInstance)
745     , mSocket(aInstance)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
747     , mTcpState(kTcpUninitialized)
748 #endif
749     , mTimer(aInstance)
750     , mDefaultConfig(QueryConfig::kInitFromDefaults)
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
752     , mUserDidSetDefaultAddress(false)
753 #endif
754 {
755     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
756 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
757     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
758 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
759     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
760     static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
761     static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
762     static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
763 #endif
764 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
765     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
766     static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
767     static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
768     static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
769 #endif
770 }
771 
Start(void)772 Error Client::Start(void)
773 {
774     Error error;
775 
776     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
777     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
778 
779 exit:
780     return error;
781 }
782 
Stop(void)783 void Client::Stop(void)
784 {
785     Query *query;
786 
787     while ((query = mMainQueries.GetHead()) != nullptr)
788     {
789         FinalizeQuery(*query, kErrorAbort);
790     }
791 
792     IgnoreError(mSocket.Close());
793 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
794     if (mTcpState != kTcpUninitialized)
795     {
796         IgnoreError(mEndpoint.Deinitialize());
797     }
798 #endif
799 }
800 
801 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)802 Error Client::InitTcpSocket(void)
803 {
804     Error                       error;
805     otTcpEndpointInitializeArgs endpointArgs;
806 
807     memset(&endpointArgs, 0x00, sizeof(endpointArgs));
808     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
809     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
810     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
811     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
812     endpointArgs.mContext                  = this;
813     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
814     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
815 
816     mSendLink.mNext   = nullptr;
817     mSendLink.mData   = mSendBufferBytes;
818     mSendLink.mLength = 0;
819 
820     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
821 exit:
822     return error;
823 }
824 #endif
825 
SetDefaultConfig(const QueryConfig & aQueryConfig)826 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
827 {
828     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
829 
830     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
831 
832 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
833     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
834     UpdateDefaultConfigAddress();
835 #endif
836 }
837 
ResetDefaultConfig(void)838 void Client::ResetDefaultConfig(void)
839 {
840     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
841 
842 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
843     mUserDidSetDefaultAddress = false;
844     UpdateDefaultConfigAddress();
845 #endif
846 }
847 
848 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)849 void Client::UpdateDefaultConfigAddress(void)
850 {
851     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
852 
853     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
854         !srpServerAddr.IsUnspecified())
855     {
856         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
857     }
858 }
859 #endif
860 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)861 Error Client::ResolveAddress(const char        *aHostName,
862                              AddressCallback    aCallback,
863                              void              *aContext,
864                              const QueryConfig *aConfig)
865 {
866     QueryInfo info;
867 
868     info.Clear();
869     info.mQueryType = kIp6AddressQuery;
870     info.mConfig.SetFrom(aConfig, mDefaultConfig);
871     info.mCallback.mAddressCallback = aCallback;
872     info.mCallbackContext           = aContext;
873 
874     return StartQuery(info, nullptr, aHostName);
875 }
876 
877 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)878 Error Client::ResolveIp4Address(const char        *aHostName,
879                                 AddressCallback    aCallback,
880                                 void              *aContext,
881                                 const QueryConfig *aConfig)
882 {
883     QueryInfo info;
884 
885     info.Clear();
886     info.mQueryType = kIp4AddressQuery;
887     info.mConfig.SetFrom(aConfig, mDefaultConfig);
888     info.mCallback.mAddressCallback = aCallback;
889     info.mCallbackContext           = aContext;
890 
891     return StartQuery(info, nullptr, aHostName);
892 }
893 #endif
894 
895 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
896 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)897 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
898 {
899     QueryInfo info;
900 
901     info.Clear();
902     info.mQueryType = kBrowseQuery;
903     info.mConfig.SetFrom(aConfig, mDefaultConfig);
904     info.mCallback.mBrowseCallback = aCallback;
905     info.mCallbackContext          = aContext;
906 
907     return StartQuery(info, nullptr, aServiceName);
908 }
909 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)910 Error Client::ResolveService(const char        *aInstanceLabel,
911                              const char        *aServiceName,
912                              ServiceCallback    aCallback,
913                              void              *aContext,
914                              const QueryConfig *aConfig)
915 {
916     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
917 }
918 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)919 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
920                                            const char        *aServiceName,
921                                            ServiceCallback    aCallback,
922                                            void              *aContext,
923                                            const QueryConfig *aConfig)
924 {
925     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
926 }
927 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)928 Error Client::Resolve(const char        *aInstanceLabel,
929                       const char        *aServiceName,
930                       ServiceCallback    aCallback,
931                       void              *aContext,
932                       const QueryConfig *aConfig,
933                       bool               aShouldResolveHostAddr)
934 {
935     QueryInfo info;
936     Error     error;
937     QueryType secondQueryType = kNoQuery;
938 
939     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
940 
941     info.Clear();
942 
943     info.mConfig.SetFrom(aConfig, mDefaultConfig);
944     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
945 
946     switch (info.mConfig.GetServiceMode())
947     {
948     case QueryConfig::kServiceModeSrvTxtSeparate:
949         secondQueryType = kServiceQueryTxt;
950 
951         OT_FALL_THROUGH;
952 
953     case QueryConfig::kServiceModeSrv:
954         info.mQueryType = kServiceQuerySrv;
955         break;
956 
957     case QueryConfig::kServiceModeTxt:
958         info.mQueryType = kServiceQueryTxt;
959         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
960         break;
961 
962     case QueryConfig::kServiceModeSrvTxt:
963     case QueryConfig::kServiceModeSrvTxtOptimize:
964     default:
965         info.mQueryType = kServiceQuerySrvTxt;
966         break;
967     }
968 
969     info.mCallback.mServiceCallback = aCallback;
970     info.mCallbackContext           = aContext;
971 
972     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
973 
974 exit:
975     return error;
976 }
977 
978 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
979 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)980 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
981 {
982     // The `aLabel` can be `nullptr` and then `aName` provides the
983     // full name, otherwise the name is appended as `{aLabel}.
984     // {aName}`.
985 
986     Error  error;
987     Query *query;
988 
989     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
990 
991 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
992     if (aInfo.mQueryType == kIp4AddressQuery)
993     {
994         NetworkData::ExternalRouteConfig nat64Prefix;
995 
996         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
997         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
998                      error = kErrorInvalidState);
999     }
1000 #endif
1001 
1002     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
1003 
1004     mMainQueries.Enqueue(*query);
1005 
1006     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1007     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1008 
1009     if (aSecondType != kNoQuery)
1010     {
1011         Query *secondQuery;
1012 
1013         aInfo.mQueryType         = aSecondType;
1014         aInfo.mMessageId         = 0;
1015         aInfo.mTransmissionCount = 0;
1016         aInfo.mMainQuery         = query;
1017 
1018         // We intentionally do not use `error` here so in the unlikely
1019         // case where we cannot allocate the second query we can proceed
1020         // with the first one.
1021         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1022 
1023         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1024 
1025         // Update first query to link to second one by updating
1026         // its `mNextQuery`.
1027         aInfo.ReadFrom(*query);
1028         aInfo.mNextQuery = secondQuery;
1029         UpdateQuery(*query, aInfo);
1030     }
1031 
1032 exit:
1033     return error;
1034 }
1035 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1036 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1037 {
1038     Error error = kErrorNone;
1039 
1040     aQuery = nullptr;
1041 
1042     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1043 
1044     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1045     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1046 
1047     SuccessOrExit(error = aQuery->Append(aInfo));
1048 
1049     if (aLabel != nullptr)
1050     {
1051         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1052     }
1053 
1054     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1055 
1056 exit:
1057     FreeAndNullMessageOnError(aQuery, error);
1058     return error;
1059 }
1060 
FindMainQuery(Query & aQuery)1061 Client::Query &Client::FindMainQuery(Query &aQuery)
1062 {
1063     QueryInfo info;
1064 
1065     info.ReadFrom(aQuery);
1066 
1067     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1068 }
1069 
FreeQuery(Query & aQuery)1070 void Client::FreeQuery(Query &aQuery)
1071 {
1072     Query    &mainQuery = FindMainQuery(aQuery);
1073     QueryInfo info;
1074 
1075     mMainQueries.Dequeue(mainQuery);
1076 
1077     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1078     {
1079         info.ReadFrom(*query);
1080         FreeMessage(info.mSavedResponse);
1081         query->Free();
1082     }
1083 }
1084 
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)1085 Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
1086 {
1087     // This method prepares and sends a query message represented by
1088     // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
1089     // the new `mRetransmissionTime`) and updates it in `aQuery` as
1090     // well. `aUpdateTimer` indicates whether the timer should be
1091     // updated when query is sent or not (used in the case where timer
1092     // is handled by caller).
1093 
1094     Error            error   = kErrorNone;
1095     Message         *message = nullptr;
1096     Header           header;
1097     Ip6::MessageInfo messageInfo;
1098     uint16_t         length = 0;
1099 
1100     aInfo.mTransmissionCount++;
1101     aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
1102 
1103     if (aInfo.mMessageId == 0)
1104     {
1105         do
1106         {
1107             SuccessOrExit(error = header.SetRandomMessageId());
1108         } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
1109 
1110         aInfo.mMessageId = header.GetMessageId();
1111     }
1112     else
1113     {
1114         header.SetMessageId(aInfo.mMessageId);
1115     }
1116 
1117     header.SetType(Header::kTypeQuery);
1118     header.SetQueryType(Header::kQueryTypeStandard);
1119 
1120     if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
1121     {
1122         header.SetRecursionDesiredFlag();
1123     }
1124 
1125     header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
1126 
1127     message = mSocket.NewMessage();
1128     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
1129 
1130     SuccessOrExit(error = message->Append(header));
1131 
1132     // Prepare the question section.
1133 
1134     for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
1135     {
1136         SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
1137         SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
1138     }
1139 
1140     length = message->GetLength() - message->GetOffset();
1141 
1142     if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1143 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1144     {
1145         // Check if query will fit into tcp buffer if not return error.
1146         VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
1147                          OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
1148                      error = kErrorNoBufs);
1149 
1150         // In case of initialized connection check if connected peer and new query have the same address.
1151         if (mTcpState != kTcpUninitialized)
1152         {
1153             VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
1154                          error = kErrorFailed);
1155         }
1156 
1157         switch (mTcpState)
1158         {
1159         case kTcpUninitialized:
1160             SuccessOrExit(error = InitTcpSocket());
1161             SuccessOrExit(
1162                 error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
1163             mTcpState = kTcpConnecting;
1164             PrepareTcpMessage(*message);
1165             break;
1166         case kTcpConnectedIdle:
1167             PrepareTcpMessage(*message);
1168             SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1169             mTcpState = kTcpConnectedSending;
1170             break;
1171         case kTcpConnecting:
1172             PrepareTcpMessage(*message);
1173             break;
1174         case kTcpConnectedSending:
1175             WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1176             SuccessOrAssert(error = message->Read(message->GetOffset(),
1177                                                   (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1178             IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
1179             break;
1180         }
1181         message->Free();
1182         message = nullptr;
1183     }
1184 #else
1185     {
1186         error = kErrorInvalidArgs;
1187         LogWarn("DNS query over TCP not supported.");
1188         ExitNow();
1189     }
1190 #endif
1191     else
1192     {
1193         VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
1194         messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
1195         messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
1196         SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
1197     }
1198 
1199 exit:
1200 
1201     FreeMessageOnError(message, error);
1202     if (aUpdateTimer)
1203     {
1204         mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
1205     }
1206 
1207     UpdateQuery(aQuery, aInfo);
1208 
1209     return error;
1210 }
1211 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1212 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1213 {
1214     // The name is encoded and included after the `Info` in `aQuery`
1215     // starting at `kNameOffsetInQuery`.
1216 
1217     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1218 }
1219 
FinalizeQuery(Query & aQuery,Error aError)1220 void Client::FinalizeQuery(Query &aQuery, Error aError)
1221 {
1222     Response response;
1223     Query   &mainQuery = FindMainQuery(aQuery);
1224 
1225     response.mInstance = &Get<Instance>();
1226     response.mQuery    = &mainQuery;
1227 
1228     FinalizeQuery(response, aError);
1229 }
1230 
FinalizeQuery(Response & aResponse,Error aError)1231 void Client::FinalizeQuery(Response &aResponse, Error aError)
1232 {
1233     QueryType type;
1234     Callback  callback;
1235     void     *context;
1236 
1237     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1238 
1239     switch (type)
1240     {
1241     case kIp6AddressQuery:
1242 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1243     case kIp4AddressQuery:
1244 #endif
1245         if (callback.mAddressCallback != nullptr)
1246         {
1247             callback.mAddressCallback(aError, &aResponse, context);
1248         }
1249         break;
1250 
1251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1252     case kBrowseQuery:
1253         if (callback.mBrowseCallback != nullptr)
1254         {
1255             callback.mBrowseCallback(aError, &aResponse, context);
1256         }
1257         break;
1258 
1259     case kServiceQuerySrvTxt:
1260     case kServiceQuerySrv:
1261     case kServiceQueryTxt:
1262         if (callback.mServiceCallback != nullptr)
1263         {
1264             callback.mServiceCallback(aError, &aResponse, context);
1265         }
1266         break;
1267 #endif
1268     case kNoQuery:
1269         break;
1270     }
1271 
1272     FreeQuery(*aResponse.mQuery);
1273 }
1274 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1275 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1276 {
1277     QueryInfo info;
1278 
1279     info.ReadFrom(aQuery);
1280 
1281     aType     = info.mQueryType;
1282     aCallback = info.mCallback;
1283     aContext  = info.mCallbackContext;
1284 }
1285 
FindQueryById(uint16_t aMessageId)1286 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1287 {
1288     Query    *matchedQuery = nullptr;
1289     QueryInfo info;
1290 
1291     for (Query &mainQuery : mMainQueries)
1292     {
1293         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1294         {
1295             info.ReadFrom(*query);
1296 
1297             if (info.mMessageId == aMessageId)
1298             {
1299                 matchedQuery = query;
1300                 ExitNow();
1301             }
1302         }
1303     }
1304 
1305 exit:
1306     return matchedQuery;
1307 }
1308 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1309 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1310 {
1311     OT_UNUSED_VARIABLE(aMsgInfo);
1312 
1313     static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1314 }
1315 
ProcessResponse(const Message & aResponseMessage)1316 void Client::ProcessResponse(const Message &aResponseMessage)
1317 {
1318     Error  responseError;
1319     Query *query;
1320 
1321     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1322 
1323     if (responseError != kErrorNone)
1324     {
1325         // Received an error from server, check if we can replace
1326         // the query.
1327 
1328 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1329         if (ReplaceWithIp4Query(*query) == kErrorNone)
1330         {
1331             ExitNow();
1332         }
1333 #endif
1334 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1335         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1336         {
1337             ExitNow();
1338         }
1339 #endif
1340 
1341         FinalizeQuery(*query, responseError);
1342         ExitNow();
1343     }
1344 
1345     // Received successful response from server.
1346 
1347 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1348     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1349 #endif
1350 
1351     if (!CanFinalizeQuery(*query))
1352     {
1353         SaveQueryResponse(*query, aResponseMessage);
1354         ExitNow();
1355     }
1356 
1357     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1358 
1359 exit:
1360     return;
1361 }
1362 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1363 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1364 {
1365     Error     error  = kErrorNone;
1366     uint16_t  offset = aResponseMessage.GetOffset();
1367     Header    header;
1368     QueryInfo info;
1369     Name      queryName;
1370 
1371     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1372     offset += sizeof(Header);
1373 
1374     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1375                      !header.IsTruncationFlagSet(),
1376                  error = kErrorDrop);
1377 
1378     aQuery = FindQueryById(header.GetMessageId());
1379     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1380 
1381     info.ReadFrom(*aQuery);
1382 
1383     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1384 
1385     // Check the Question Section
1386 
1387     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1388     {
1389         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1390         {
1391             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1392             offset += sizeof(Question);
1393         }
1394     }
1395     else
1396     {
1397         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1398                      error = kErrorParse);
1399     }
1400 
1401     // Check the answer, authority and additional record sections
1402 
1403     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1404     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1405     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1406 
1407     // Read the response code
1408 
1409     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1410 
1411 exit:
1412     return error;
1413 }
1414 
CanFinalizeQuery(Query & aQuery)1415 bool Client::CanFinalizeQuery(Query &aQuery)
1416 {
1417     // Determines whether we can finalize a main query by checking if
1418     // we have received and saved responses for all other related
1419     // queries associated with `aQuery`. Note that this method is
1420     // called when we receive a response for `aQuery`, so no need to
1421     // check for a saved response for `aQuery` itself.
1422 
1423     bool      canFinalize = true;
1424     QueryInfo info;
1425 
1426     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1427     {
1428         info.ReadFrom(*query);
1429 
1430         if (query == &aQuery)
1431         {
1432             continue;
1433         }
1434 
1435         if (info.mSavedResponse == nullptr)
1436         {
1437             canFinalize = false;
1438             ExitNow();
1439         }
1440     }
1441 
1442 exit:
1443     return canFinalize;
1444 }
1445 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1446 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1447 {
1448     QueryInfo info;
1449 
1450     info.ReadFrom(aQuery);
1451     VerifyOrExit(info.mSavedResponse == nullptr);
1452 
1453     // If `Clone()` fails we let retry or timeout handle the error.
1454     info.mSavedResponse = aResponseMessage.Clone();
1455 
1456     UpdateQuery(aQuery, info);
1457 
1458 exit:
1459     return;
1460 }
1461 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1462 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1463 {
1464     // Populate `aResponse` for `aQuery`. If there is a saved response
1465     // message for `aQuery` we use it, otherwise, we use
1466     // `aResponseMessage`.
1467 
1468     QueryInfo info;
1469 
1470     info.ReadFrom(aQuery);
1471 
1472     aResponse.mInstance = &Get<Instance>();
1473     aResponse.mQuery    = &aQuery;
1474     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1475 
1476     return info.mNextQuery;
1477 }
1478 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1479 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1480 {
1481     // This method prepares a list of chained `Response` instances
1482     // corresponding to all related (chained) queries. It uses
1483     // recursion to go through the queries and construct the
1484     // `Response` chain.
1485 
1486     Response response;
1487     Query   *nextQuery;
1488 
1489     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1490     response.mNext = aPrevResponse;
1491 
1492     if (nextQuery != nullptr)
1493     {
1494         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1495     }
1496     else
1497     {
1498         FinalizeQuery(response, kErrorNone);
1499     }
1500 }
1501 
HandleTimer(void)1502 void Client::HandleTimer(void)
1503 {
1504     TimeMilli now      = TimerMilli::GetNow();
1505     TimeMilli nextTime = now.GetDistantFuture();
1506     QueryInfo info;
1507 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1508     bool hasTcpQuery = false;
1509 #endif
1510 
1511     for (Query &mainQuery : mMainQueries)
1512     {
1513         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1514         {
1515             info.ReadFrom(*query);
1516 
1517             if (info.mSavedResponse != nullptr)
1518             {
1519                 continue;
1520             }
1521 
1522             if (now >= info.mRetransmissionTime)
1523             {
1524                 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1525                 {
1526                     FinalizeQuery(*query, kErrorResponseTimeout);
1527                     break;
1528                 }
1529 
1530                 IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1531             }
1532 
1533             if (nextTime > info.mRetransmissionTime)
1534             {
1535                 nextTime = info.mRetransmissionTime;
1536             }
1537 
1538 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1539             if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1540             {
1541                 hasTcpQuery = true;
1542             }
1543 #endif
1544         }
1545     }
1546 
1547     if (nextTime < now.GetDistantFuture())
1548     {
1549         mTimer.FireAt(nextTime);
1550     }
1551 
1552 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1553     if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1554     {
1555         IgnoreError(mEndpoint.SendEndOfStream());
1556     }
1557 #endif
1558 }
1559 
1560 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1561 
ReplaceWithIp4Query(Query & aQuery)1562 Error Client::ReplaceWithIp4Query(Query &aQuery)
1563 {
1564     Error     error = kErrorFailed;
1565     QueryInfo info;
1566 
1567     info.ReadFrom(aQuery);
1568 
1569     VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1570     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1571 
1572     // We send a new query for IPv4 address resolution
1573     // for the same host name. We reuse the existing `aQuery`
1574     // instance and keep all the info but clear `mTransmissionCount`
1575     // and `mMessageId` (so that a new random message ID is
1576     // selected). The new `info` will be saved in the query in
1577     // `SendQuery()`. Note that the current query is still in the
1578     // `mMainQueries` list when `SendQuery()` selects a new random
1579     // message ID, so the existing message ID for this query will
1580     // not be reused.
1581 
1582     info.mQueryType         = kIp4AddressQuery;
1583     info.mMessageId         = 0;
1584     info.mTransmissionCount = 0;
1585 
1586     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1587     error = kErrorNone;
1588 
1589 exit:
1590     return error;
1591 }
1592 
1593 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1594 
1595 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1596 
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1597 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1598 {
1599     Error     error = kErrorFailed;
1600     QueryInfo info;
1601     Query    *secondQuery;
1602 
1603     info.ReadFrom(aQuery);
1604 
1605     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1606     VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1607 
1608     secondQuery = aQuery.Clone();
1609     VerifyOrExit(secondQuery != nullptr);
1610 
1611     info.mQueryType         = kServiceQueryTxt;
1612     info.mMessageId         = 0;
1613     info.mTransmissionCount = 0;
1614     info.mMainQuery         = &aQuery;
1615     IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1616 
1617     info.mQueryType         = kServiceQuerySrv;
1618     info.mMessageId         = 0;
1619     info.mTransmissionCount = 0;
1620     info.mNextQuery         = secondQuery;
1621     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1622     error = kErrorNone;
1623 
1624 exit:
1625     return error;
1626 }
1627 
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1628 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1629 {
1630     QueryInfo   info;
1631     Response    response;
1632     ServiceInfo serviceInfo;
1633     char        hostName[Name::kMaxNameSize];
1634 
1635     info.ReadFrom(aQuery);
1636 
1637     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1638     VerifyOrExit(info.mShouldResolveHostAddr);
1639 
1640     PopulateResponse(response, aQuery, aResponseMessage);
1641 
1642     memset(&serviceInfo, 0, sizeof(serviceInfo));
1643     serviceInfo.mHostNameBuffer     = hostName;
1644     serviceInfo.mHostNameBufferSize = sizeof(hostName);
1645     SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1646 
1647     // Check whether AAAA record for host address is provided in the SRV query response
1648 
1649     if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1650     {
1651         Query *newQuery;
1652 
1653         info.mQueryType         = kIp6AddressQuery;
1654         info.mMessageId         = 0;
1655         info.mTransmissionCount = 0;
1656         info.mMainQuery         = &FindMainQuery(aQuery);
1657 
1658         SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1659         IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1660 
1661         // Update `aQuery` to be linked with new query (inserting
1662         // the `newQuery` into the linked-list after `aQuery`).
1663 
1664         info.ReadFrom(aQuery);
1665         info.mNextQuery = newQuery;
1666         UpdateQuery(aQuery, info);
1667     }
1668 
1669 exit:
1670     return;
1671 }
1672 
1673 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1674 
1675 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1676 void Client::PrepareTcpMessage(Message &aMessage)
1677 {
1678     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1679 
1680     // Prepending the DNS query with length of the packet according to RFC1035.
1681     WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1682     SuccessOrAssert(
1683         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1684     mSendLink.mLength += length + sizeof(uint16_t);
1685 }
1686 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1687 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1688 {
1689     OT_UNUSED_VARIABLE(aEndpoint);
1690     OT_UNUSED_VARIABLE(aData);
1691     OT_ASSERT(mTcpState == kTcpConnectedSending);
1692 
1693     mSendLink.mLength = 0;
1694     mTcpState         = kTcpConnectedIdle;
1695 }
1696 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1697 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1698 {
1699     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1700 }
1701 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1702 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1703 {
1704     OT_UNUSED_VARIABLE(aEndpoint);
1705     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1706     mTcpState = kTcpConnectedSending;
1707 }
1708 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1709 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1710 {
1711     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1712 }
1713 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1714 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1715                                  size_t                &aOffset,
1716                                  Message               &aMessage,
1717                                  uint16_t               aLength)
1718 {
1719     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1720     // and copy the content into `aMessage`. As we read we can move
1721     // to the next `aLinkedBuffer` and update `aOffset`.
1722     // Returns:
1723     // - `kErrorNone` if `aLength` bytes are successfully read and
1724     //    `aOffset` and `aLinkedBuffer` are updated.
1725     // - `kErrorNotFound` is not enough bytes available to read
1726     //    from `aLinkedBuffer`.
1727     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1728 
1729     Error error = kErrorNone;
1730 
1731     while (aLength > 0)
1732     {
1733         uint16_t bytesToRead = aLength;
1734 
1735         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1736 
1737         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1738         {
1739             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1740         }
1741 
1742         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1743 
1744         aLength -= bytesToRead;
1745         aOffset += bytesToRead;
1746 
1747         if (aOffset == aLinkedBuffer->mLength)
1748         {
1749             aLinkedBuffer = aLinkedBuffer->mNext;
1750             aOffset       = 0;
1751         }
1752     }
1753 
1754 exit:
1755     return error;
1756 }
1757 
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1758 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1759                                        size_t         aBytesAvailable,
1760                                        bool           aEndOfStream,
1761                                        size_t         aBytesRemaining)
1762 {
1763     OT_UNUSED_VARIABLE(aEndpoint);
1764     OT_UNUSED_VARIABLE(aBytesRemaining);
1765 
1766     Message              *message   = nullptr;
1767     size_t                totalRead = 0;
1768     size_t                offset    = 0;
1769     const otLinkedBuffer *data;
1770 
1771     if (aEndOfStream)
1772     {
1773         // Cleanup is done in disconnected callback.
1774         IgnoreError(mEndpoint.SendEndOfStream());
1775     }
1776 
1777     SuccessOrExit(mEndpoint.ReceiveByReference(data));
1778     VerifyOrExit(data != nullptr);
1779 
1780     message = mSocket.NewMessage();
1781     VerifyOrExit(message != nullptr);
1782 
1783     while (aBytesAvailable > totalRead)
1784     {
1785         uint16_t length;
1786 
1787         // Read the `length` field.
1788         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1789 
1790         IgnoreError(message->Read(/* aOffset */ 0, length));
1791         length = HostSwap16(length);
1792 
1793         // Try to read `length` bytes.
1794         IgnoreError(message->SetLength(0));
1795         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1796 
1797         totalRead += length + sizeof(uint16_t);
1798 
1799         // Now process the read message as query response.
1800         ProcessResponse(*message);
1801 
1802         IgnoreError(message->SetLength(0));
1803 
1804         // Loop again to see if we can read another response.
1805     }
1806 
1807 exit:
1808     // Inform `mEndPoint` about the total read and processed bytes
1809     IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1810     FreeMessage(message);
1811 }
1812 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1813 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1814                                                size_t         aBytesAvailable,
1815                                                bool           aEndOfStream,
1816                                                size_t         aBytesRemaining)
1817 {
1818     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1819         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1820 }
1821 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1822 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1823 {
1824     OT_UNUSED_VARIABLE(aEndpoint);
1825     OT_UNUSED_VARIABLE(aReason);
1826     QueryInfo info;
1827 
1828     IgnoreError(mEndpoint.Deinitialize());
1829     mTcpState = kTcpUninitialized;
1830 
1831     // Abort queries in case of connection failures
1832     for (Query &mainQuery : mMainQueries)
1833     {
1834         info.ReadFrom(mainQuery);
1835 
1836         if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1837         {
1838             FinalizeQuery(mainQuery, kErrorAbort);
1839         }
1840     }
1841 }
1842 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1843 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1844 {
1845     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1846 }
1847 
1848 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1849 
1850 } // namespace Dns
1851 } // namespace ot
1852 
1853 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1854