1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
52 RegisterLogModule("DnsClient");
53 
54 //---------------------------------------------------------------------------------------------------------------------
55 // Client::QueryConfig
56 
57 const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58 
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61     OT_UNUSED_VARIABLE(aMode);
62 
63     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64     GetServerSockAddr().SetPort(kDefaultServerPort);
65     SetResponseTimeout(kDefaultResponseTimeout);
66     SetMaxTxAttempts(kDefaultMaxTxAttempts);
67     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68     SetServiceMode(kDefaultServiceMode);
69 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
70     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
71 #endif
72     SetTransportProto(kDnsTransportUdp);
73 }
74 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)75 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
76 {
77     // This method sets the config from `aConfig` replacing any
78     // unspecified fields (value zero) with the fields from
79     // `aDefaultConfig`. If `aConfig` is `nullptr` then
80     // `aDefaultConfig` is used.
81 
82     if (aConfig == nullptr)
83     {
84         *this = aDefaultConfig;
85         ExitNow();
86     }
87 
88     *this = *aConfig;
89 
90     if (GetServerSockAddr().GetAddress().IsUnspecified())
91     {
92         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
93     }
94 
95     if (GetServerSockAddr().GetPort() == 0)
96     {
97         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
98     }
99 
100     if (GetResponseTimeout() == 0)
101     {
102         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
103     }
104 
105     if (GetMaxTxAttempts() == 0)
106     {
107         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
108     }
109 
110     if (GetRecursionFlag() == kFlagUnspecified)
111     {
112         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
113     }
114 
115 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
116     if (GetNat64Mode() == kNat64Unspecified)
117     {
118         SetNat64Mode(aDefaultConfig.GetNat64Mode());
119     }
120 #endif
121 
122     if (GetServiceMode() == kServiceModeUnspecified)
123     {
124         SetServiceMode(aDefaultConfig.GetServiceMode());
125     }
126 
127     if (GetTransportProto() == kDnsTransportUnspecified)
128     {
129         SetTransportProto(aDefaultConfig.GetTransportProto());
130     }
131 
132 exit:
133     return;
134 }
135 
136 //---------------------------------------------------------------------------------------------------------------------
137 // Client::Response
138 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const139 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
140 {
141     switch (aSection)
142     {
143     case kAnswerSection:
144         aOffset    = mAnswerOffset;
145         aNumRecord = mAnswerRecordCount;
146         break;
147     case kAdditionalDataSection:
148     default:
149         aOffset    = mAdditionalOffset;
150         aNumRecord = mAdditionalRecordCount;
151         break;
152     }
153 }
154 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const155 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
156 {
157     uint16_t offset = kNameOffsetInQuery;
158 
159     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
160 }
161 
CheckForHostNameAlias(Section aSection,Name & aHostName) const162 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
163 {
164     // If the response includes a CNAME record mapping the query host
165     // name to a canonical name, we update `aHostName` to the new alias
166     // name. Otherwise `aHostName` remains as before. This method handles
167     // when there are multiple CNAME records mapping the host name multiple
168     // times. We limit number of changes to `kMaxCnameAliasNameChanges`
169     // to detect and handle if the response contains CNAME record loops.
170 
171     Error       error;
172     uint16_t    offset;
173     uint16_t    numRecords;
174     CnameRecord cnameRecord;
175 
176     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
177 
178     for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
179     {
180         SelectSection(aSection, offset, numRecords);
181         error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182 
183         if (error == kErrorNotFound)
184         {
185             error = kErrorNone;
186             ExitNow();
187         }
188 
189         SuccessOrExit(error);
190 
191         // A CNAME record was found. `offset` now points to after the
192         // last read byte within the `mMessage` into the `cnameRecord`
193         // (which is the start of the new canonical name).
194 
195         aHostName.SetFromMessage(*mMessage, offset);
196         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
197 
198         // Loop back to check if there may be a CNAME record for the
199         // new `aHostName`.
200     }
201 
202     error = kErrorParse;
203 
204 exit:
205     return error;
206 }
207 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const208 Error Client::Response::FindHostAddress(Section       aSection,
209                                         const Name   &aHostName,
210                                         uint16_t      aIndex,
211                                         Ip6::Address &aAddress,
212                                         uint32_t     &aTtl) const
213 {
214     Error      error;
215     uint16_t   offset;
216     uint16_t   numRecords;
217     Name       name = aHostName;
218     AaaaRecord aaaaRecord;
219 
220     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
221 
222     SelectSection(aSection, offset, numRecords);
223     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
224     aAddress = aaaaRecord.GetAddress();
225     aTtl     = aaaaRecord.GetTtl();
226 
227 exit:
228     return error;
229 }
230 
231 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
232 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const233 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
234 {
235     Error    error;
236     uint16_t offset;
237     uint16_t numRecords;
238     Name     name = aHostName;
239 
240     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
241 
242     SelectSection(aSection, offset, numRecords);
243     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
244 
245 exit:
246     return error;
247 }
248 
249 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
250 
251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
252 
InitServiceInfo(ServiceInfo & aServiceInfo) const253 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
254 {
255     // This method initializes `aServiceInfo` setting all
256     // TTLs to zero and host name to empty string.
257 
258     aServiceInfo.mTtl              = 0;
259     aServiceInfo.mHostAddressTtl   = 0;
260     aServiceInfo.mTxtDataTtl       = 0;
261     aServiceInfo.mTxtDataTruncated = false;
262 
263     AsCoreType(&aServiceInfo.mHostAddress).Clear();
264 
265     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
266     {
267         aServiceInfo.mHostNameBuffer[0] = '\0';
268     }
269 }
270 
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const271 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
272 {
273     // This method searches for SRV record in the given `aSection`
274     // matching the record name against `aName`, and updates the
275     // `aServiceInfo` accordingly. It also searches for AAAA record
276     // for host name associated with the service (from SRV record).
277     // The search for AAAA record is always performed in Additional
278     // Data section (independent of the value given in `aSection`).
279 
280     Error     error = kErrorNone;
281     uint16_t  offset;
282     uint16_t  numRecords;
283     Name      hostName;
284     SrvRecord srvRecord;
285 
286     // A non-zero `mTtl` indicates that SRV record is already found
287     // and parsed from a previous response.
288     VerifyOrExit(aServiceInfo.mTtl == 0);
289 
290     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
291 
292     // Search for a matching SRV record
293     SelectSection(aSection, offset, numRecords);
294     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
295 
296     aServiceInfo.mTtl      = srvRecord.GetTtl();
297     aServiceInfo.mPort     = srvRecord.GetPort();
298     aServiceInfo.mPriority = srvRecord.GetPriority();
299     aServiceInfo.mWeight   = srvRecord.GetWeight();
300 
301     hostName.SetFromMessage(*mMessage, offset);
302 
303     if (aServiceInfo.mHostNameBuffer != nullptr)
304     {
305         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
306                                                            aServiceInfo.mHostNameBufferSize));
307     }
308     else
309     {
310         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
311     }
312 
313     // Search in additional section for AAAA record for the host name.
314 
315     VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
316 
317     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
318                             aServiceInfo.mHostAddressTtl);
319 
320     if (error == kErrorNotFound)
321     {
322         error = kErrorNone;
323     }
324 
325 exit:
326     return error;
327 }
328 
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const329 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
330 {
331     // This method searches a TXT record in the given `aSection`
332     // matching the record name against `aName` and updates the TXT
333     // related properties in `aServicesInfo`.
334     //
335     // If no match is found `mTxtDataTtl` (which is initialized to zero)
336     // remains unchanged to indicate this. In this case this method still
337     // returns `kErrorNone`.
338 
339     Error     error = kErrorNone;
340     uint16_t  offset;
341     uint16_t  numRecords;
342     TxtRecord txtRecord;
343 
344     // A non-zero `mTxtDataTtl` indicates that TXT record is already
345     // found and parsed from a previous response.
346     VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
347 
348     // A null `mTxtData` indicates that caller does not want to retrieve
349     // TXT data.
350     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
351 
352     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
353 
354     SelectSection(aSection, offset, numRecords);
355 
356     aServiceInfo.mTxtDataTruncated = false;
357 
358     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
359 
360     error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
361 
362     if (error == kErrorNoBufs)
363     {
364         error = kErrorNone;
365 
366         // Mark `mTxtDataTruncated` to indicate that we could not read
367         // the full TXT record into the given `mTxtData` buffer.
368         aServiceInfo.mTxtDataTruncated = true;
369     }
370 
371     SuccessOrExit(error);
372     aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
373 
374 exit:
375     if (error == kErrorNotFound)
376     {
377         error = kErrorNone;
378     }
379 
380     return error;
381 }
382 
383 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
384 
PopulateFrom(const Message & aMessage)385 void Client::Response::PopulateFrom(const Message &aMessage)
386 {
387     // Populate `Response` with info from `aMessage`.
388 
389     uint16_t offset = aMessage.GetOffset();
390     Header   header;
391 
392     mMessage = &aMessage;
393 
394     IgnoreError(aMessage.Read(offset, header));
395     offset += sizeof(Header);
396 
397     for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
398     {
399         IgnoreError(Name::ParseName(aMessage, offset));
400         offset += sizeof(Question);
401     }
402 
403     mAnswerOffset = offset;
404     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
405     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
406     mAdditionalOffset = offset;
407     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
408 
409     mAnswerRecordCount     = header.GetAnswerCount();
410     mAdditionalRecordCount = header.GetAdditionalRecordCount();
411 }
412 
413 //---------------------------------------------------------------------------------------------------------------------
414 // Client::AddressResponse
415 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const416 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
417 {
418     Error error = kErrorNone;
419     Name  name(*mQuery, kNameOffsetInQuery);
420 
421 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
422 
423     // If the response is for an IPv4 address query or if it is an
424     // IPv6 address query response with no IPv6 address but with
425     // an IPv4 in its additional section, we read the IPv4 address
426     // and translate it to an IPv6 address.
427 
428     QueryInfo info;
429 
430     info.ReadFrom(*mQuery);
431 
432     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
433     {
434         Section                          section;
435         ARecord                          aRecord;
436         NetworkData::ExternalRouteConfig nat64Prefix;
437 
438         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
439                      error = kErrorInvalidState);
440 
441         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
442         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
443 
444         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
445         aTtl = aRecord.GetTtl();
446 
447         ExitNow();
448     }
449 
450 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
451 
452     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
453 
454 exit:
455     return error;
456 }
457 
458 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
459 
460 //---------------------------------------------------------------------------------------------------------------------
461 // Client::BrowseResponse
462 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const463 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
464 {
465     Error     error;
466     uint16_t  offset;
467     uint16_t  numRecords;
468     Name      serviceName(*mQuery, kNameOffsetInQuery);
469     PtrRecord ptrRecord;
470 
471     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
472 
473     SelectSection(kAnswerSection, offset, numRecords);
474     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
475     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
476 
477 exit:
478     return error;
479 }
480 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const481 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
482 {
483     Error error;
484     Name  instanceName;
485 
486     // Find a matching PTR record for the service instance label. Then
487     // search and read SRV, TXT and AAAA records in Additional Data
488     // section matching the same name to populate `aServiceInfo`.
489 
490     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
491 
492     InitServiceInfo(aServiceInfo);
493     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
494     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
495 
496     if (aServiceInfo.mTxtDataTtl == 0)
497     {
498         aServiceInfo.mTxtDataSize = 0;
499     }
500 
501 exit:
502     return error;
503 }
504 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const505 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
506                                              uint16_t      aIndex,
507                                              Ip6::Address &aAddress,
508                                              uint32_t     &aTtl) const
509 {
510     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
511 }
512 
FindPtrRecord(const char * aInstanceLabel,Name & aInstanceName) const513 Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
514 {
515     // This method searches within the Answer Section for a PTR record
516     // matching a given instance label @aInstanceLabel. If found, the
517     // `aName` is updated to return the name in the message.
518 
519     Error     error;
520     uint16_t  offset;
521     Name      serviceName(*mQuery, kNameOffsetInQuery);
522     uint16_t  numRecords;
523     uint16_t  labelOffset;
524     PtrRecord ptrRecord;
525 
526     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
527 
528     SelectSection(kAnswerSection, offset, numRecords);
529 
530     for (; numRecords > 0; numRecords--)
531     {
532         SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));
533 
534         error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);
535 
536         if (error == kErrorNotFound)
537         {
538             // `ReadRecord()` updates `offset` to skip over a
539             // non-matching record.
540             continue;
541         }
542 
543         SuccessOrExit(error);
544 
545         // It is a PTR record. Check the first label to match the
546         // instance label.
547 
548         labelOffset = offset;
549         error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);
550 
551         if (error == kErrorNone)
552         {
553             aInstanceName.SetFromMessage(*mMessage, offset);
554             ExitNow();
555         }
556 
557         VerifyOrExit(error == kErrorNotFound);
558 
559         // Update offset to skip over the PTR record.
560         offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
561     }
562 
563     error = kErrorNotFound;
564 
565 exit:
566     return error;
567 }
568 
569 //---------------------------------------------------------------------------------------------------------------------
570 // Client::ServiceResponse
571 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const572 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
573                                               uint8_t  aLabelBufferSize,
574                                               char    *aNameBuffer,
575                                               uint16_t aNameBufferSize) const
576 {
577     Error    error;
578     uint16_t offset = kNameOffsetInQuery;
579 
580     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
581 
582     VerifyOrExit(aNameBuffer != nullptr);
583     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
584 
585 exit:
586     return error;
587 }
588 
GetServiceInfo(ServiceInfo & aServiceInfo) const589 Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
590 {
591     // Search and read SRV, TXT records matching name from query.
592 
593     Error error = kErrorNotFound;
594 
595     InitServiceInfo(aServiceInfo);
596 
597     for (const Response *response = this; response != nullptr; response = response->mNext)
598     {
599         Name      name(*response->mQuery, kNameOffsetInQuery);
600         QueryInfo info;
601         Section   srvSection;
602         Section   txtSection;
603 
604         info.ReadFrom(*response->mQuery);
605 
606         switch (info.mQueryType)
607         {
608         case kIp6AddressQuery:
609 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
610         case kIp4AddressQuery:
611 #endif
612             IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
613                                                   AsCoreType(&aServiceInfo.mHostAddress),
614                                                   aServiceInfo.mHostAddressTtl));
615 
616             continue; // to `for()` loop
617 
618         case kServiceQuerySrvTxt:
619         case kServiceQuerySrv:
620         case kServiceQueryTxt:
621             break;
622 
623         default:
624             continue;
625         }
626 
627         // Determine from which section we should try to read the SRV and
628         // TXT records based on the query type.
629         //
630         // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
631         // only one record (SRV or TXT) in the answer section, but we
632         // still try to read the other records from additional data
633         // section in case server provided them.
634 
635         srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
636         txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;
637 
638         error = response->ReadServiceInfo(srvSection, name, aServiceInfo);
639 
640         if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
641         {
642             error = kErrorNone;
643         }
644 
645         SuccessOrExit(error);
646 
647         SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
648     }
649 
650     if (aServiceInfo.mTxtDataTtl == 0)
651     {
652         aServiceInfo.mTxtDataSize = 0;
653     }
654 
655 exit:
656     return error;
657 }
658 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const659 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
660                                               uint16_t      aIndex,
661                                               Ip6::Address &aAddress,
662                                               uint32_t     &aTtl) const
663 {
664     Error error = kErrorNotFound;
665 
666     for (const Response *response = this; response != nullptr; response = response->mNext)
667     {
668         Section   section = kAdditionalDataSection;
669         QueryInfo info;
670 
671         info.ReadFrom(*response->mQuery);
672 
673         switch (info.mQueryType)
674         {
675         case kIp6AddressQuery:
676 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
677         case kIp4AddressQuery:
678 #endif
679             section = kAnswerSection;
680             break;
681 
682         default:
683             break;
684         }
685 
686         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
687 
688         if (error == kErrorNone)
689         {
690             break;
691         }
692     }
693 
694     return error;
695 }
696 
697 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
698 
699 //---------------------------------------------------------------------------------------------------------------------
700 // Client
701 
702 const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
703 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
704 const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
705 #endif
706 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
707 const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
708 const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
709 #endif
710 
711 const uint8_t Client::kQuestionCount[] = {
712     /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
713 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
714     /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
715 #endif
716 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
717     /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
718     /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
719     /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
720     /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
721 #endif
722 };
723 
724 const uint16_t *const Client::kQuestionRecordTypes[] = {
725     /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
726 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
727     /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
728 #endif
729 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
730     /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
731     /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
732     /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
733     /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],
734 
735 #endif
736 };
737 
Client(Instance & aInstance)738 Client::Client(Instance &aInstance)
739     : InstanceLocator(aInstance)
740     , mSocket(aInstance)
741 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
742     , mTcpState(kTcpUninitialized)
743 #endif
744     , mTimer(aInstance)
745     , mDefaultConfig(QueryConfig::kInitFromDefaults)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
747     , mUserDidSetDefaultAddress(false)
748 #endif
749 {
750     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
752     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
753 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
754     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
755     static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
756     static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
757     static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
758 #endif
759 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
760     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
761     static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
762     static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
763     static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
764 #endif
765 }
766 
Start(void)767 Error Client::Start(void)
768 {
769     Error error;
770 
771     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
772     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
773 
774 exit:
775     return error;
776 }
777 
Stop(void)778 void Client::Stop(void)
779 {
780     Query *query;
781 
782     while ((query = mMainQueries.GetHead()) != nullptr)
783     {
784         FinalizeQuery(*query, kErrorAbort);
785     }
786 
787     IgnoreError(mSocket.Close());
788 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
789     if (mTcpState != kTcpUninitialized)
790     {
791         IgnoreError(mEndpoint.Deinitialize());
792     }
793 #endif
794 }
795 
796 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)797 Error Client::InitTcpSocket(void)
798 {
799     Error                       error;
800     otTcpEndpointInitializeArgs endpointArgs;
801 
802     ClearAllBytes(endpointArgs);
803     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
804     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
805     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
806     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
807     endpointArgs.mContext                  = this;
808     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
809     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
810 
811     mSendLink.mNext   = nullptr;
812     mSendLink.mData   = mSendBufferBytes;
813     mSendLink.mLength = 0;
814 
815     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
816 exit:
817     return error;
818 }
819 #endif
820 
SetDefaultConfig(const QueryConfig & aQueryConfig)821 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
822 {
823     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
824 
825     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
826 
827 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
828     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
829     UpdateDefaultConfigAddress();
830 #endif
831 }
832 
ResetDefaultConfig(void)833 void Client::ResetDefaultConfig(void)
834 {
835     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
836 
837 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
838     mUserDidSetDefaultAddress = false;
839     UpdateDefaultConfigAddress();
840 #endif
841 }
842 
843 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)844 void Client::UpdateDefaultConfigAddress(void)
845 {
846     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
847 
848     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
849         !srpServerAddr.IsUnspecified())
850     {
851         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
852     }
853 }
854 #endif
855 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)856 Error Client::ResolveAddress(const char        *aHostName,
857                              AddressCallback    aCallback,
858                              void              *aContext,
859                              const QueryConfig *aConfig)
860 {
861     QueryInfo info;
862 
863     info.Clear();
864     info.mQueryType = kIp6AddressQuery;
865     info.mConfig.SetFrom(aConfig, mDefaultConfig);
866     info.mCallback.mAddressCallback = aCallback;
867     info.mCallbackContext           = aContext;
868 
869     return StartQuery(info, nullptr, aHostName);
870 }
871 
872 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)873 Error Client::ResolveIp4Address(const char        *aHostName,
874                                 AddressCallback    aCallback,
875                                 void              *aContext,
876                                 const QueryConfig *aConfig)
877 {
878     QueryInfo info;
879 
880     info.Clear();
881     info.mQueryType = kIp4AddressQuery;
882     info.mConfig.SetFrom(aConfig, mDefaultConfig);
883     info.mCallback.mAddressCallback = aCallback;
884     info.mCallbackContext           = aContext;
885 
886     return StartQuery(info, nullptr, aHostName);
887 }
888 #endif
889 
890 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
891 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)892 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
893 {
894     QueryInfo info;
895 
896     info.Clear();
897     info.mQueryType = kBrowseQuery;
898     info.mConfig.SetFrom(aConfig, mDefaultConfig);
899     info.mCallback.mBrowseCallback = aCallback;
900     info.mCallbackContext          = aContext;
901 
902     return StartQuery(info, nullptr, aServiceName);
903 }
904 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)905 Error Client::ResolveService(const char        *aInstanceLabel,
906                              const char        *aServiceName,
907                              ServiceCallback    aCallback,
908                              void              *aContext,
909                              const QueryConfig *aConfig)
910 {
911     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
912 }
913 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)914 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
915                                            const char        *aServiceName,
916                                            ServiceCallback    aCallback,
917                                            void              *aContext,
918                                            const QueryConfig *aConfig)
919 {
920     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
921 }
922 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)923 Error Client::Resolve(const char        *aInstanceLabel,
924                       const char        *aServiceName,
925                       ServiceCallback    aCallback,
926                       void              *aContext,
927                       const QueryConfig *aConfig,
928                       bool               aShouldResolveHostAddr)
929 {
930     QueryInfo info;
931     Error     error;
932     QueryType secondQueryType = kNoQuery;
933 
934     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
935 
936     info.Clear();
937 
938     info.mConfig.SetFrom(aConfig, mDefaultConfig);
939     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
940 
941     switch (info.mConfig.GetServiceMode())
942     {
943     case QueryConfig::kServiceModeSrvTxtSeparate:
944         secondQueryType = kServiceQueryTxt;
945 
946         OT_FALL_THROUGH;
947 
948     case QueryConfig::kServiceModeSrv:
949         info.mQueryType = kServiceQuerySrv;
950         break;
951 
952     case QueryConfig::kServiceModeTxt:
953         info.mQueryType = kServiceQueryTxt;
954         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
955         break;
956 
957     case QueryConfig::kServiceModeSrvTxt:
958     case QueryConfig::kServiceModeSrvTxtOptimize:
959     default:
960         info.mQueryType = kServiceQuerySrvTxt;
961         break;
962     }
963 
964     info.mCallback.mServiceCallback = aCallback;
965     info.mCallbackContext           = aContext;
966 
967     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
968 
969 exit:
970     return error;
971 }
972 
973 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
974 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)975 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
976 {
977     // The `aLabel` can be `nullptr` and then `aName` provides the
978     // full name, otherwise the name is appended as `{aLabel}.
979     // {aName}`.
980 
981     Error  error;
982     Query *query;
983 
984     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
985 
986 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
987     if (aInfo.mQueryType == kIp4AddressQuery)
988     {
989         NetworkData::ExternalRouteConfig nat64Prefix;
990 
991         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
992         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
993                      error = kErrorInvalidState);
994     }
995 #endif
996 
997     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
998 
999     mMainQueries.Enqueue(*query);
1000 
1001     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1002     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1003 
1004     if (aSecondType != kNoQuery)
1005     {
1006         Query *secondQuery;
1007 
1008         aInfo.mQueryType         = aSecondType;
1009         aInfo.mMessageId         = 0;
1010         aInfo.mTransmissionCount = 0;
1011         aInfo.mMainQuery         = query;
1012 
1013         // We intentionally do not use `error` here so in the unlikely
1014         // case where we cannot allocate the second query we can proceed
1015         // with the first one.
1016         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1017 
1018         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1019 
1020         // Update first query to link to second one by updating
1021         // its `mNextQuery`.
1022         aInfo.ReadFrom(*query);
1023         aInfo.mNextQuery = secondQuery;
1024         UpdateQuery(*query, aInfo);
1025     }
1026 
1027 exit:
1028     return error;
1029 }
1030 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1031 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1032 {
1033     Error error = kErrorNone;
1034 
1035     aQuery = nullptr;
1036 
1037     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1038 
1039     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1040     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1041 
1042     SuccessOrExit(error = aQuery->Append(aInfo));
1043 
1044     if (aLabel != nullptr)
1045     {
1046         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1047     }
1048 
1049     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1050 
1051 exit:
1052     FreeAndNullMessageOnError(aQuery, error);
1053     return error;
1054 }
1055 
FindMainQuery(Query & aQuery)1056 Client::Query &Client::FindMainQuery(Query &aQuery)
1057 {
1058     QueryInfo info;
1059 
1060     info.ReadFrom(aQuery);
1061 
1062     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1063 }
1064 
FreeQuery(Query & aQuery)1065 void Client::FreeQuery(Query &aQuery)
1066 {
1067     Query    &mainQuery = FindMainQuery(aQuery);
1068     QueryInfo info;
1069 
1070     mMainQueries.Dequeue(mainQuery);
1071 
1072     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1073     {
1074         info.ReadFrom(*query);
1075         FreeMessage(info.mSavedResponse);
1076         query->Free();
1077     }
1078 }
1079 
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // Note that the transmission count, retransmission time and
    // (when newly selected) message ID are written back to `aQuery`
    // on the `exit` path even when sending fails.

    Error            error   = kErrorNone;
    Message         *message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;
    uint16_t         length = 0;

    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // A zero `mMessageId` requests a new ID: keep drawing random
        // IDs until one is non-zero and not used by any query that
        // `FindQueryById()` can find.
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Retransmission (or caller-chosen ID): reuse the same ID.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    // The number of questions is looked up by query type.
    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        // Every question uses the same name (copied from `aQuery`)
        // with the record type taken from `kQuestionRecordTypes`.
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    length = message->GetLength() - message->GetOffset();

    // The `if` below is followed by one of two brace blocks selected
    // at build time; the `else` (UDP) branch pairs with whichever
    // block is compiled in.
    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // No endpoint yet: initialize it, start connecting and
            // queue the query in the send buffer. The buffer is sent
            // from `HandleTcpEstablished()`.
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Connected with nothing in flight: queue the query and
            // send the buffer right away.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Still connecting: only queue the query in the buffer.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send is in progress: write the length-prefixed query
            // right after the current buffer content and extend the
            // ongoing send by that many bytes.
            // NOTE(review): `mSendLink.mLength` is not updated here;
            // presumably `SendByExtension()` grows it - confirm.
            BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }
        // The query content now lives in the TCP send buffer, so the
        // message is no longer needed. The pointer is cleared so the
        // `exit` path does not use it again.
        message->Free();
        message = nullptr;
    }
#else
    {
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        // UDP transport: the query must not exceed `kUdpQueryMaxSize`.
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    // Persist the (possibly modified) `aInfo` into the query.
    UpdateQuery(aQuery, aInfo);

    return error;
}
1206 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1207 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1208 {
1209     // The name is encoded and included after the `Info` in `aQuery`
1210     // starting at `kNameOffsetInQuery`.
1211 
1212     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1213 }
1214 
FinalizeQuery(Query & aQuery,Error aError)1215 void Client::FinalizeQuery(Query &aQuery, Error aError)
1216 {
1217     Response response;
1218     Query   &mainQuery = FindMainQuery(aQuery);
1219 
1220     response.mInstance = &Get<Instance>();
1221     response.mQuery    = &mainQuery;
1222 
1223     FinalizeQuery(response, aError);
1224 }
1225 
FinalizeQuery(Response & aResponse,Error aError)1226 void Client::FinalizeQuery(Response &aResponse, Error aError)
1227 {
1228     QueryType type;
1229     Callback  callback;
1230     void     *context;
1231 
1232     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1233 
1234     switch (type)
1235     {
1236     case kIp6AddressQuery:
1237 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1238     case kIp4AddressQuery:
1239 #endif
1240         if (callback.mAddressCallback != nullptr)
1241         {
1242             callback.mAddressCallback(aError, &aResponse, context);
1243         }
1244         break;
1245 
1246 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1247     case kBrowseQuery:
1248         if (callback.mBrowseCallback != nullptr)
1249         {
1250             callback.mBrowseCallback(aError, &aResponse, context);
1251         }
1252         break;
1253 
1254     case kServiceQuerySrvTxt:
1255     case kServiceQuerySrv:
1256     case kServiceQueryTxt:
1257         if (callback.mServiceCallback != nullptr)
1258         {
1259             callback.mServiceCallback(aError, &aResponse, context);
1260         }
1261         break;
1262 #endif
1263     case kNoQuery:
1264         break;
1265     }
1266 
1267     FreeQuery(*aResponse.mQuery);
1268 }
1269 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1270 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1271 {
1272     QueryInfo info;
1273 
1274     info.ReadFrom(aQuery);
1275 
1276     aType     = info.mQueryType;
1277     aCallback = info.mCallback;
1278     aContext  = info.mCallbackContext;
1279 }
1280 
FindQueryById(uint16_t aMessageId)1281 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1282 {
1283     Query    *matchedQuery = nullptr;
1284     QueryInfo info;
1285 
1286     for (Query &mainQuery : mMainQueries)
1287     {
1288         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1289         {
1290             info.ReadFrom(*query);
1291 
1292             if (info.mMessageId == aMessageId)
1293             {
1294                 matchedQuery = query;
1295                 ExitNow();
1296             }
1297         }
1298     }
1299 
1300 exit:
1301     return matchedQuery;
1302 }
1303 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1304 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1305 {
1306     OT_UNUSED_VARIABLE(aMsgInfo);
1307 
1308     static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1309 }
1310 
ProcessResponse(const Message & aResponseMessage)1311 void Client::ProcessResponse(const Message &aResponseMessage)
1312 {
1313     Error  responseError;
1314     Query *query;
1315 
1316     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1317 
1318     if (responseError != kErrorNone)
1319     {
1320         // Received an error from server, check if we can replace
1321         // the query.
1322 
1323 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1324         if (ReplaceWithIp4Query(*query) == kErrorNone)
1325         {
1326             ExitNow();
1327         }
1328 #endif
1329 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1330         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1331         {
1332             ExitNow();
1333         }
1334 #endif
1335 
1336         FinalizeQuery(*query, responseError);
1337         ExitNow();
1338     }
1339 
1340     // Received successful response from server.
1341 
1342 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1343     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1344 #endif
1345 
1346     if (!CanFinalizeQuery(*query))
1347     {
1348         SaveQueryResponse(*query, aResponseMessage);
1349         ExitNow();
1350     }
1351 
1352     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1353 
1354 exit:
1355     return;
1356 }
1357 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1358 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1359 {
1360     Error     error  = kErrorNone;
1361     uint16_t  offset = aResponseMessage.GetOffset();
1362     Header    header;
1363     QueryInfo info;
1364     Name      queryName;
1365 
1366     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1367     offset += sizeof(Header);
1368 
1369     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1370                      !header.IsTruncationFlagSet(),
1371                  error = kErrorDrop);
1372 
1373     aQuery = FindQueryById(header.GetMessageId());
1374     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1375 
1376     info.ReadFrom(*aQuery);
1377 
1378     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1379 
1380     // Check the Question Section
1381 
1382     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1383     {
1384         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1385         {
1386             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1387             offset += sizeof(Question);
1388         }
1389     }
1390     else
1391     {
1392         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1393                      error = kErrorParse);
1394     }
1395 
1396     // Check the answer, authority and additional record sections
1397 
1398     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1399     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1400     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1401 
1402     // Read the response code
1403 
1404     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1405 
1406 exit:
1407     return error;
1408 }
1409 
CanFinalizeQuery(Query & aQuery)1410 bool Client::CanFinalizeQuery(Query &aQuery)
1411 {
1412     // Determines whether we can finalize a main query by checking if
1413     // we have received and saved responses for all other related
1414     // queries associated with `aQuery`. Note that this method is
1415     // called when we receive a response for `aQuery`, so no need to
1416     // check for a saved response for `aQuery` itself.
1417 
1418     bool      canFinalize = true;
1419     QueryInfo info;
1420 
1421     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1422     {
1423         info.ReadFrom(*query);
1424 
1425         if (query == &aQuery)
1426         {
1427             continue;
1428         }
1429 
1430         if (info.mSavedResponse == nullptr)
1431         {
1432             canFinalize = false;
1433             ExitNow();
1434         }
1435     }
1436 
1437 exit:
1438     return canFinalize;
1439 }
1440 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1441 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1442 {
1443     QueryInfo info;
1444 
1445     info.ReadFrom(aQuery);
1446     VerifyOrExit(info.mSavedResponse == nullptr);
1447 
1448     // If `Clone()` fails we let retry or timeout handle the error.
1449     info.mSavedResponse = aResponseMessage.Clone();
1450 
1451     UpdateQuery(aQuery, info);
1452 
1453 exit:
1454     return;
1455 }
1456 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1457 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1458 {
1459     // Populate `aResponse` for `aQuery`. If there is a saved response
1460     // message for `aQuery` we use it, otherwise, we use
1461     // `aResponseMessage`.
1462 
1463     QueryInfo info;
1464 
1465     info.ReadFrom(aQuery);
1466 
1467     aResponse.mInstance = &Get<Instance>();
1468     aResponse.mQuery    = &aQuery;
1469     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1470 
1471     return info.mNextQuery;
1472 }
1473 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1474 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1475 {
1476     // This method prepares a list of chained `Response` instances
1477     // corresponding to all related (chained) queries. It uses
1478     // recursion to go through the queries and construct the
1479     // `Response` chain.
1480 
1481     Response response;
1482     Query   *nextQuery;
1483 
1484     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1485     response.mNext = aPrevResponse;
1486 
1487     if (nextQuery != nullptr)
1488     {
1489         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1490     }
1491     else
1492     {
1493         FinalizeQuery(response, kErrorNone);
1494     }
1495 }
1496 
void Client::HandleTimer(void)
{
    // Timer handler: walks all pending queries, retransmits the ones
    // whose retransmission time has been reached (or finalizes them
    // with `kErrorResponseTimeout` once the maximum number of
    // attempts is used up), and re-arms the timer for the earliest
    // remaining retransmission time.

    NextFireTime nextTime;
    QueryInfo    info;
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    bool hasTcpQuery = false;
#endif

    // NOTE(review): `FinalizeQuery()` below dequeues and frees the
    // current main query while `mMainQueries` is being iterated;
    // this assumes the iterator tolerates removal of the current
    // entry - confirm.
    for (Query &mainQuery : mMainQueries)
    {
        // Visit the main query and every query chained to it.
        for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
        {
            info.ReadFrom(*query);

            // A query with a saved response is not retransmitted.
            if (info.mSavedResponse != nullptr)
            {
                continue;
            }

            if (nextTime.GetNow() >= info.mRetransmissionTime)
            {
                if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
                {
                    // Finalizing frees the main query and its whole
                    // chain, so stop walking this chain.
                    FinalizeQuery(*query, kErrorResponseTimeout);
                    break;
                }

                // `SendQuery()` updates `info.mRetransmissionTime`.
                // The timer is re-armed once at the end of this
                // method, hence `aUpdateTimer` is false.
                IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
            }

            nextTime.UpdateIfEarlier(info.mRetransmissionTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
            if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
            {
                hasTcpQuery = true;
            }
#endif
        }
    }

    mTimer.FireAt(nextTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    // With no TCP query left waiting for a response, signal end of
    // stream on an initialized endpoint.
    if (!hasTcpQuery && mTcpState != kTcpUninitialized)
    {
        IgnoreError(mEndpoint.SendEndOfStream());
    }
#endif
}
1547 
1548 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1549 
ReplaceWithIp4Query(Query & aQuery)1550 Error Client::ReplaceWithIp4Query(Query &aQuery)
1551 {
1552     Error     error = kErrorFailed;
1553     QueryInfo info;
1554 
1555     info.ReadFrom(aQuery);
1556 
1557     VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1558     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1559 
1560     // We send a new query for IPv4 address resolution
1561     // for the same host name. We reuse the existing `aQuery`
1562     // instance and keep all the info but clear `mTransmissionCount`
1563     // and `mMessageId` (so that a new random message ID is
1564     // selected). The new `info` will be saved in the query in
1565     // `SendQuery()`. Note that the current query is still in the
1566     // `mMainQueries` list when `SendQuery()` selects a new random
1567     // message ID, so the existing message ID for this query will
1568     // not be reused.
1569 
1570     info.mQueryType         = kIp4AddressQuery;
1571     info.mMessageId         = 0;
1572     info.mTransmissionCount = 0;
1573 
1574     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1575     error = kErrorNone;
1576 
1577 exit:
1578     return error;
1579 }
1580 
1581 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1582 
1583 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1584 
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1585 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1586 {
1587     Error     error = kErrorFailed;
1588     QueryInfo info;
1589     Query    *secondQuery;
1590 
1591     info.ReadFrom(aQuery);
1592 
1593     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1594     VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1595 
1596     secondQuery = aQuery.Clone();
1597     VerifyOrExit(secondQuery != nullptr);
1598 
1599     info.mQueryType         = kServiceQueryTxt;
1600     info.mMessageId         = 0;
1601     info.mTransmissionCount = 0;
1602     info.mMainQuery         = &aQuery;
1603     IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1604 
1605     info.mQueryType         = kServiceQuerySrv;
1606     info.mMessageId         = 0;
1607     info.mTransmissionCount = 0;
1608     info.mNextQuery         = secondQuery;
1609     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1610     error = kErrorNone;
1611 
1612 exit:
1613     return error;
1614 }
1615 
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1616 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1617 {
1618     QueryInfo   info;
1619     Response    response;
1620     ServiceInfo serviceInfo;
1621     char        hostName[Name::kMaxNameSize];
1622 
1623     info.ReadFrom(aQuery);
1624 
1625     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1626     VerifyOrExit(info.mShouldResolveHostAddr);
1627 
1628     PopulateResponse(response, aQuery, aResponseMessage);
1629 
1630     ClearAllBytes(serviceInfo);
1631     serviceInfo.mHostNameBuffer     = hostName;
1632     serviceInfo.mHostNameBufferSize = sizeof(hostName);
1633     SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1634 
1635     // Check whether AAAA record for host address is provided in the SRV query response
1636 
1637     if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1638     {
1639         Query *newQuery;
1640 
1641         info.mQueryType         = kIp6AddressQuery;
1642         info.mMessageId         = 0;
1643         info.mTransmissionCount = 0;
1644         info.mMainQuery         = &FindMainQuery(aQuery);
1645 
1646         SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1647         IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1648 
1649         // Update `aQuery` to be linked with new query (inserting
1650         // the `newQuery` into the linked-list after `aQuery`).
1651 
1652         info.ReadFrom(aQuery);
1653         info.mNextQuery = newQuery;
1654         UpdateQuery(aQuery, info);
1655     }
1656 
1657 exit:
1658     return;
1659 }
1660 
1661 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1662 
1663 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1664 void Client::PrepareTcpMessage(Message &aMessage)
1665 {
1666     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1667 
1668     // Prepending the DNS query with length of the packet according to RFC1035.
1669     BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1670     SuccessOrAssert(
1671         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1672     mSendLink.mLength += length + sizeof(uint16_t);
1673 }
1674 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1675 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1676 {
1677     OT_UNUSED_VARIABLE(aEndpoint);
1678     OT_UNUSED_VARIABLE(aData);
1679     OT_ASSERT(mTcpState == kTcpConnectedSending);
1680 
1681     mSendLink.mLength = 0;
1682     mTcpState         = kTcpConnectedIdle;
1683 }
1684 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1685 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1686 {
1687     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1688 }
1689 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1690 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1691 {
1692     OT_UNUSED_VARIABLE(aEndpoint);
1693     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1694     mTcpState = kTcpConnectedSending;
1695 }
1696 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1697 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1698 {
1699     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1700 }
1701 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1702 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1703                                  size_t                &aOffset,
1704                                  Message               &aMessage,
1705                                  uint16_t               aLength)
1706 {
1707     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1708     // and copy the content into `aMessage`. As we read we can move
1709     // to the next `aLinkedBuffer` and update `aOffset`.
1710     // Returns:
1711     // - `kErrorNone` if `aLength` bytes are successfully read and
1712     //    `aOffset` and `aLinkedBuffer` are updated.
1713     // - `kErrorNotFound` is not enough bytes available to read
1714     //    from `aLinkedBuffer`.
1715     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1716 
1717     Error error = kErrorNone;
1718 
1719     while (aLength > 0)
1720     {
1721         uint16_t bytesToRead = aLength;
1722 
1723         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1724 
1725         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1726         {
1727             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1728         }
1729 
1730         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1731 
1732         aLength -= bytesToRead;
1733         aOffset += bytesToRead;
1734 
1735         if (aOffset == aLinkedBuffer->mLength)
1736         {
1737             aLinkedBuffer = aLinkedBuffer->mNext;
1738             aOffset       = 0;
1739         }
1740     }
1741 
1742 exit:
1743     return error;
1744 }
1745 
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1746 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1747                                        size_t         aBytesAvailable,
1748                                        bool           aEndOfStream,
1749                                        size_t         aBytesRemaining)
1750 {
1751     OT_UNUSED_VARIABLE(aEndpoint);
1752     OT_UNUSED_VARIABLE(aBytesRemaining);
1753 
1754     Message              *message   = nullptr;
1755     size_t                totalRead = 0;
1756     size_t                offset    = 0;
1757     const otLinkedBuffer *data;
1758 
1759     if (aEndOfStream)
1760     {
1761         // Cleanup is done in disconnected callback.
1762         IgnoreError(mEndpoint.SendEndOfStream());
1763     }
1764 
1765     SuccessOrExit(mEndpoint.ReceiveByReference(data));
1766     VerifyOrExit(data != nullptr);
1767 
1768     message = mSocket.NewMessage();
1769     VerifyOrExit(message != nullptr);
1770 
1771     while (aBytesAvailable > totalRead)
1772     {
1773         uint16_t length;
1774 
1775         // Read the `length` field.
1776         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1777 
1778         IgnoreError(message->Read(/* aOffset */ 0, length));
1779         length = BigEndian::HostSwap16(length);
1780 
1781         // Try to read `length` bytes.
1782         IgnoreError(message->SetLength(0));
1783         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1784 
1785         totalRead += length + sizeof(uint16_t);
1786 
1787         // Now process the read message as query response.
1788         ProcessResponse(*message);
1789 
1790         IgnoreError(message->SetLength(0));
1791 
1792         // Loop again to see if we can read another response.
1793     }
1794 
1795 exit:
1796     // Inform `mEndPoint` about the total read and processed bytes
1797     IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1798     FreeMessage(message);
1799 }
1800 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1801 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1802                                                size_t         aBytesAvailable,
1803                                                bool           aEndOfStream,
1804                                                size_t         aBytesRemaining)
1805 {
1806     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1807         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1808 }
1809 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1810 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1811 {
1812     OT_UNUSED_VARIABLE(aEndpoint);
1813     OT_UNUSED_VARIABLE(aReason);
1814     QueryInfo info;
1815 
1816     IgnoreError(mEndpoint.Deinitialize());
1817     mTcpState = kTcpUninitialized;
1818 
1819     // Abort queries in case of connection failures
1820     for (Query &mainQuery : mMainQueries)
1821     {
1822         info.ReadFrom(mainQuery);
1823 
1824         if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1825         {
1826             FinalizeQuery(mainQuery, kErrorAbort);
1827         }
1828     }
1829 }
1830 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1831 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1832 {
1833     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1834 }
1835 
1836 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1837 
1838 } // namespace Dns
1839 } // namespace ot
1840 
1841 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1842