1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
// Register "DnsClient" as this file's log module name (presumably the tag
// attached to log output emitted from this file).
RegisterLogModule("DnsClient");
53 
54 //---------------------------------------------------------------------------------------------------------------------
55 // Client::QueryConfig
56 
// Default DNS server IPv6 address in string form, taken from the
// `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS` config. It is
// parsed into the server socket address by `QueryConfig(InitMode)`.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58 
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61     OT_UNUSED_VARIABLE(aMode);
62 
63     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64     GetServerSockAddr().SetPort(kDefaultServerPort);
65     SetResponseTimeout(kDefaultResponseTimeout);
66     SetMaxTxAttempts(kDefaultMaxTxAttempts);
67     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68     SetServiceMode(kDefaultServiceMode);
69 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
70     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
71 #endif
72     SetTransportProto(kDnsTransportUdp);
73 }
74 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)75 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
76 {
77     // This method sets the config from `aConfig` replacing any
78     // unspecified fields (value zero) with the fields from
79     // `aDefaultConfig`. If `aConfig` is `nullptr` then
80     // `aDefaultConfig` is used.
81 
82     if (aConfig == nullptr)
83     {
84         *this = aDefaultConfig;
85         ExitNow();
86     }
87 
88     *this = *aConfig;
89 
90     if (GetServerSockAddr().GetAddress().IsUnspecified())
91     {
92         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
93     }
94 
95     if (GetServerSockAddr().GetPort() == 0)
96     {
97         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
98     }
99 
100     if (GetResponseTimeout() == 0)
101     {
102         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
103     }
104 
105     if (GetMaxTxAttempts() == 0)
106     {
107         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
108     }
109 
110     if (GetRecursionFlag() == kFlagUnspecified)
111     {
112         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
113     }
114 
115 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
116     if (GetNat64Mode() == kNat64Unspecified)
117     {
118         SetNat64Mode(aDefaultConfig.GetNat64Mode());
119     }
120 #endif
121 
122     if (GetServiceMode() == kServiceModeUnspecified)
123     {
124         SetServiceMode(aDefaultConfig.GetServiceMode());
125     }
126 
127     if (GetTransportProto() == kDnsTransportUnspecified)
128     {
129         SetTransportProto(aDefaultConfig.GetTransportProto());
130     }
131 
132 exit:
133     return;
134 }
135 
136 //---------------------------------------------------------------------------------------------------------------------
137 // Client::Response
138 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const139 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
140 {
141     switch (aSection)
142     {
143     case kAnswerSection:
144         aOffset    = mAnswerOffset;
145         aNumRecord = mAnswerRecordCount;
146         break;
147     case kAdditionalDataSection:
148     default:
149         aOffset    = mAdditionalOffset;
150         aNumRecord = mAdditionalRecordCount;
151         break;
152     }
153 }
154 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const155 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
156 {
157     uint16_t offset = kNameOffsetInQuery;
158 
159     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
160 }
161 
CheckForHostNameAlias(Section aSection,Name & aHostName) const162 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
163 {
164     // If the response includes a CNAME record mapping the query host
165     // name to a canonical name, we update `aHostName` to the new alias
166     // name. Otherwise `aHostName` remains as before. This method handles
167     // when there are multiple CNAME records mapping the host name multiple
168     // times. We limit number of changes to `kMaxCnameAliasNameChanges`
169     // to detect and handle if the response contains CNAME record loops.
170 
171     Error       error;
172     uint16_t    offset;
173     uint16_t    numRecords;
174     CnameRecord cnameRecord;
175 
176     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
177 
178     for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
179     {
180         SelectSection(aSection, offset, numRecords);
181         error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182 
183         if (error == kErrorNotFound)
184         {
185             error = kErrorNone;
186             ExitNow();
187         }
188 
189         SuccessOrExit(error);
190 
191         // A CNAME record was found. `offset` now points to after the
192         // last read byte within the `mMessage` into the `cnameRecord`
193         // (which is the start of the new canonical name).
194 
195         aHostName.SetFromMessage(*mMessage, offset);
196         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
197 
198         // Loop back to check if there may be a CNAME record for the
199         // new `aHostName`.
200     }
201 
202     error = kErrorParse;
203 
204 exit:
205     return error;
206 }
207 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const208 Error Client::Response::FindHostAddress(Section       aSection,
209                                         const Name   &aHostName,
210                                         uint16_t      aIndex,
211                                         Ip6::Address &aAddress,
212                                         uint32_t     &aTtl) const
213 {
214     Error      error;
215     uint16_t   offset;
216     uint16_t   numRecords;
217     Name       name = aHostName;
218     AaaaRecord aaaaRecord;
219 
220     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
221 
222     SelectSection(aSection, offset, numRecords);
223     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
224     aAddress = aaaaRecord.GetAddress();
225     aTtl     = aaaaRecord.GetTtl();
226 
227 exit:
228     return error;
229 }
230 
231 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
232 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const233 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
234 {
235     Error    error;
236     uint16_t offset;
237     uint16_t numRecords;
238     Name     name = aHostName;
239 
240     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
241 
242     SelectSection(aSection, offset, numRecords);
243     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
244 
245 exit:
246     return error;
247 }
248 
249 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
250 
251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
252 
InitServiceInfo(ServiceInfo & aServiceInfo) const253 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
254 {
255     // This method initializes `aServiceInfo` setting all
256     // TTLs to zero and host name to empty string.
257 
258     aServiceInfo.mTtl              = 0;
259     aServiceInfo.mHostAddressTtl   = 0;
260     aServiceInfo.mTxtDataTtl       = 0;
261     aServiceInfo.mTxtDataTruncated = false;
262 
263     AsCoreType(&aServiceInfo.mHostAddress).Clear();
264 
265     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
266     {
267         aServiceInfo.mHostNameBuffer[0] = '\0';
268     }
269 }
270 
Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // Fill the SRV-derived fields of `aServiceInfo` from this response.
    //
    // Searches `aSection` for an SRV record whose name matches `aName`. If
    // one is found, this sets `mTtl`, `mPort`, `mPriority` and `mWeight`.
    // When `mHostNameBuffer` is non-null, it also copies the SRV target host
    // name. It then looks for a AAAA record of that target host, always in
    // the Additional Data section regardless of `aSection`; CNAME aliases
    // are followed via `FindHostAddress()`. The AAAA lookup is skipped if
    // `mHostAddress` is already set.
    //
    // Returns:
    //   kErrorNone     The SRV record was already read by an earlier call
    //                  (non-zero `mTtl`), or it was read now. A missing AAAA
    //                  record is not treated as an error.
    //   kErrorNotFound `mMessage` is null.
    //   other          Any error from the SRV lookup (presumably
    //                  `kErrorNotFound` when no SRV record matches), from
    //                  reading the target host name, or from the AAAA lookup
    //                  (other than not-found) is returned as-is. Note that
    //                  the SRV fields may already have been written when a
    //                  later step fails.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    Name      hostName;
    SrvRecord srvRecord;

    // A non-zero `mTtl` indicates that SRV record is already found
    // and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTtl == 0);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    // Search for a matching SRV record
    SelectSection(aSection, offset, numRecords);
    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));

    aServiceInfo.mTtl      = srvRecord.GetTtl();
    aServiceInfo.mPort     = srvRecord.GetPort();
    aServiceInfo.mPriority = srvRecord.GetPriority();
    aServiceInfo.mWeight   = srvRecord.GetWeight();

    // `offset` now points at the SRV target host name. Capture it as a
    // message-backed `Name` before `offset` is advanced past it below.
    hostName.SetFromMessage(*mMessage, offset);

    if (aServiceInfo.mHostNameBuffer != nullptr)
    {
        SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
                                                           aServiceInfo.mHostNameBufferSize));
    }
    else
    {
        // Caller does not want the host name string, but the name is still
        // parsed to validate it.
        SuccessOrExit(error = Name::ParseName(*mMessage, offset));
    }

    // Search in additional section for AAAA record for the host name.

    VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());

    error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
                            aServiceInfo.mHostAddressTtl);

    // The host address is optional in the response.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

exit:
    return error;
}
328 
Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // Search the given `aSection` for a TXT record whose name matches
    // `aName`, and update the TXT-related fields of `aServiceInfo`
    // (`mTxtData`, `mTxtDataTruncated`, `mTxtDataTtl`).
    //
    // If no match is found, `mTxtDataTtl` (initialized to zero) is left
    // unchanged to signal this, and the method still returns `kErrorNone`.
    // A TXT record larger than the `mTxtData` buffer is not an error: the
    // data that fits is kept and `mTxtDataTruncated` is set. Any other
    // error from reading the record is returned.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    TxtRecord txtRecord;

    // A non-zero `mTxtDataTtl` indicates that TXT record is already
    // found and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);

    // A null `mTxtData` indicates that caller does not want to retrieve
    // TXT data.
    VerifyOrExit(aServiceInfo.mTxtData != nullptr);

    // Converted to `kErrorNone` at `exit` like any other not-found case.
    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(aSection, offset, numRecords);

    aServiceInfo.mTxtDataTruncated = false;

    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));

    error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);

    if (error == kErrorNoBufs)
    {
        error = kErrorNone;

        // Mark `mTxtDataTruncated` to indicate that we could not read
        // the full TXT record into the given `mTxtData` buffer.
        aServiceInfo.mTxtDataTruncated = true;
    }

    SuccessOrExit(error);
    aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();

exit:
    // A missing TXT record (or missing message) is not an error; the
    // caller detects it through `mTxtDataTtl` remaining zero.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

    return error;
}
382 
383 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
384 
PopulateFrom(const Message & aMessage)385 void Client::Response::PopulateFrom(const Message &aMessage)
386 {
387     // Populate `Response` with info from `aMessage`.
388 
389     uint16_t offset = aMessage.GetOffset();
390     Header   header;
391 
392     mMessage = &aMessage;
393 
394     IgnoreError(aMessage.Read(offset, header));
395     offset += sizeof(Header);
396 
397     for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
398     {
399         IgnoreError(Name::ParseName(aMessage, offset));
400         offset += sizeof(Question);
401     }
402 
403     mAnswerOffset = offset;
404     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
405     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
406     mAdditionalOffset = offset;
407     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
408 
409     mAnswerRecordCount     = header.GetAnswerCount();
410     mAdditionalRecordCount = header.GetAdditionalRecordCount();
411 }
412 
413 //---------------------------------------------------------------------------------------------------------------------
414 // Client::AddressResponse
415 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const416 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
417 {
418     Error error = kErrorNone;
419     Name  name(*mQuery, kNameOffsetInQuery);
420 
421 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
422 
423     // If the response is for an IPv4 address query or if it is an
424     // IPv6 address query response with no IPv6 address but with
425     // an IPv4 in its additional section, we read the IPv4 address
426     // and translate it to an IPv6 address.
427 
428     QueryInfo info;
429 
430     info.ReadFrom(*mQuery);
431 
432     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
433     {
434         Section                          section;
435         ARecord                          aRecord;
436         NetworkData::ExternalRouteConfig nat64Prefix;
437 
438         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
439                      error = kErrorInvalidState);
440 
441         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
442         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
443 
444         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
445         aTtl = aRecord.GetTtl();
446 
447         ExitNow();
448     }
449 
450 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
451 
452     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
453 
454 exit:
455     return error;
456 }
457 
458 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
459 
460 //---------------------------------------------------------------------------------------------------------------------
461 // Client::BrowseResponse
462 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const463 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
464 {
465     Error     error;
466     uint16_t  offset;
467     uint16_t  numRecords;
468     Name      serviceName(*mQuery, kNameOffsetInQuery);
469     PtrRecord ptrRecord;
470 
471     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
472 
473     SelectSection(kAnswerSection, offset, numRecords);
474     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
475     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
476 
477 exit:
478     return error;
479 }
480 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const481 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
482 {
483     Error error;
484     Name  instanceName;
485 
486     // Find a matching PTR record for the service instance label. Then
487     // search and read SRV, TXT and AAAA records in Additional Data
488     // section matching the same name to populate `aServiceInfo`.
489 
490     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
491 
492     InitServiceInfo(aServiceInfo);
493     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
494     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
495 
496     if (aServiceInfo.mTxtDataTtl == 0)
497     {
498         aServiceInfo.mTxtDataSize = 0;
499     }
500 
501 exit:
502     return error;
503 }
504 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const505 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
506                                              uint16_t      aIndex,
507                                              Ip6::Address &aAddress,
508                                              uint32_t     &aTtl) const
509 {
510     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
511 }
512 
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // Search the Answer section for a PTR record (named after the browsed
    // service from the query) whose target's first label matches
    // `aInstanceLabel`. If one is found, `aInstanceName` is set to refer to
    // that full target name within the message.
    //
    // Returns `kErrorNone` on a match and `kErrorNotFound` when no PTR
    // record matches or `mMessage` is null; other errors are propagated.
    //
    // NOTE(review): any non-`kErrorNone` result from `CompareName()` ends
    // the search immediately. If `CompareName()` reports a name mismatch as
    // an error (e.g. `kErrorNotFound`), an answer record for a different
    // name would stop the scan rather than be skipped. Confirm this is
    // intended.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label.

        // Compare using a scratch copy so that `offset` keeps pointing at
        // the start of the PTR target name.
        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record.
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
568 
569 //---------------------------------------------------------------------------------------------------------------------
570 // Client::ServiceResponse
571 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const572 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
573                                               uint8_t  aLabelBufferSize,
574                                               char    *aNameBuffer,
575                                               uint16_t aNameBufferSize) const
576 {
577     Error    error;
578     uint16_t offset = kNameOffsetInQuery;
579 
580     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
581 
582     VerifyOrExit(aNameBuffer != nullptr);
583     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
584 
585 exit:
586     return error;
587 }
588 
Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
{
    // Populate `aServiceInfo` by walking this response and every response
    // linked after it through `mNext`:
    //  - Address query responses (`kIp6AddressQuery`, plus `kIp4AddressQuery`
    //    when NAT64 is enabled) only contribute the host address, looked up
    //    in their Answer section. Lookup errors are ignored.
    //  - Service query responses contribute the SRV and TXT data through
    //    `ReadServiceInfo()` and `ReadTxtRecord()`.
    //  - Responses to any other query type are skipped.
    // If no TXT data was read, `mTxtDataSize` is set to zero.
    //
    // Returns `kErrorNotFound` when the chain holds no service query
    // response. Otherwise it returns the first error encountered, which
    // stops the walk, or `kErrorNone`.

    Error error = kErrorNotFound;

    InitServiceInfo(aServiceInfo);

    for (const Response *response = this; response != nullptr; response = response->mNext)
    {
        Name      name(*response->mQuery, kNameOffsetInQuery);
        QueryInfo info;
        Section   srvSection;
        Section   txtSection;

        info.ReadFrom(*response->mQuery);

        switch (info.mQueryType)
        {
        case kIp6AddressQuery:
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        case kIp4AddressQuery:
#endif
            // NOTE(review): `FindHostAddress()` searches for AAAA records
            // only. For a `kIp4AddressQuery` response (which requests A
            // records), this lookup presumably finds nothing. Confirm
            // whether an A-record/NAT64 path is intended here.
            IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
                                                  AsCoreType(&aServiceInfo.mHostAddress),
                                                  aServiceInfo.mHostAddressTtl));

            continue; // to `for()` loop

        case kServiceQuerySrvTxt:
        case kServiceQuerySrv:
        case kServiceQueryTxt:
            break;

        default:
            continue;
        }

        // Determine from which section we should try to read the SRV and
        // TXT records based on the query type.
        //
        // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
        // only one record (SRV or TXT) in the answer section, but we
        // still try to read the other records from additional data
        // section in case server provided them.

        srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
        txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;

        error = response->ReadServiceInfo(srvSection, name, aServiceInfo);

        // An SRV record that was only opportunistically looked for in the
        // Additional Data section is optional.
        if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
        {
            error = kErrorNone;
        }

        SuccessOrExit(error);

        SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
    }

    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}
658 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const659 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
660                                               uint16_t      aIndex,
661                                               Ip6::Address &aAddress,
662                                               uint32_t     &aTtl) const
663 {
664     Error error = kErrorNotFound;
665 
666     for (const Response *response = this; response != nullptr; response = response->mNext)
667     {
668         Section   section = kAdditionalDataSection;
669         QueryInfo info;
670 
671         info.ReadFrom(*response->mQuery);
672 
673         switch (info.mQueryType)
674         {
675         case kIp6AddressQuery:
676 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
677         case kIp4AddressQuery:
678 #endif
679             section = kAnswerSection;
680             break;
681 
682         default:
683             break;
684         }
685 
686         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
687 
688         if (error == kErrorNone)
689         {
690             break;
691         }
692     }
693 
694     return error;
695 }
696 
697 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
698 
699 //---------------------------------------------------------------------------------------------------------------------
700 // Client
701 
// Resource record types asked for by each query type. The length of each
// array is that query's question count (see `kQuestionCount`).
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
// SRV must stay at index 0 and TXT at index 1: `kQuestionRecordTypes`
// points at these individual entries for the SRV-only/TXT-only queries.
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif
710 
// Number of questions sent for each query type, indexed by the query type
// enum value. The entry order must match those values, which the `Client`
// constructor checks with `static_assert`s. The SRV-only and TXT-only
// service queries each send a single question.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
    /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
#endif
};
723 
// For each query type (indexed by the query type enum value), a pointer to
// the first record type of its question list. This is presumably read
// together with `kQuestionCount`. The SRV-only and TXT-only queries point
// into `kServiceQueryRecordTypes`, at its SRV entry and its TXT entry
// respectively.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],

#endif
};
737 
Client(Instance & aInstance)738 Client::Client(Instance &aInstance)
739     : InstanceLocator(aInstance)
740     , mSocket(aInstance, *this)
741 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
742     , mTcpState(kTcpUninitialized)
743 #endif
744     , mTimer(aInstance)
745     , mDefaultConfig(QueryConfig::kInitFromDefaults)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
747     , mUserDidSetDefaultAddress(false)
748 #endif
749 {
750     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
752     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
753 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
754     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
755     static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
756     static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
757     static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
758 #endif
759 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
760     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
761     static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
762     static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
763     static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
764 #endif
765 }
766 
Start(void)767 Error Client::Start(void)
768 {
769     Error error;
770 
771     SuccessOrExit(error = mSocket.Open());
772     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
773 
774 exit:
775     return error;
776 }
777 
Stop(void)778 void Client::Stop(void)
779 {
780     Query *query;
781 
782     while ((query = mMainQueries.GetHead()) != nullptr)
783     {
784         FinalizeQuery(*query, kErrorAbort);
785     }
786 
787     IgnoreError(mSocket.Close());
788 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
789     if (mTcpState != kTcpUninitialized)
790     {
791         IgnoreError(mEndpoint.Deinitialize());
792     }
793 #endif
794 
795     mLimitedQueryServers.Clear();
796 }
797 
798 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)799 Error Client::InitTcpSocket(void)
800 {
801     Error                       error;
802     otTcpEndpointInitializeArgs endpointArgs;
803 
804     ClearAllBytes(endpointArgs);
805     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
806     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
807     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
808     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
809     endpointArgs.mContext                  = this;
810     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
811     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
812 
813     mSendLink.mNext   = nullptr;
814     mSendLink.mData   = mSendBufferBytes;
815     mSendLink.mLength = 0;
816 
817     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
818 exit:
819     return error;
820 }
821 #endif
822 
SetDefaultConfig(const QueryConfig & aQueryConfig)823 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
824 {
825     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
826 
827     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
828 
829 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
830     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
831     UpdateDefaultConfigAddress();
832 #endif
833 }
834 
ResetDefaultConfig(void)835 void Client::ResetDefaultConfig(void)
836 {
837     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
838 
839 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
840     mUserDidSetDefaultAddress = false;
841     UpdateDefaultConfigAddress();
842 #endif
843 }
844 
845 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)846 void Client::UpdateDefaultConfigAddress(void)
847 {
848     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
849 
850     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
851         !srpServerAddr.IsUnspecified())
852     {
853         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
854     }
855 }
856 #endif
857 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)858 Error Client::ResolveAddress(const char        *aHostName,
859                              AddressCallback    aCallback,
860                              void              *aContext,
861                              const QueryConfig *aConfig)
862 {
863     QueryInfo info;
864 
865     info.Clear();
866     info.mQueryType = kIp6AddressQuery;
867     info.mConfig.SetFrom(aConfig, mDefaultConfig);
868     info.mCallback.mAddressCallback = aCallback;
869     info.mCallbackContext           = aContext;
870 
871     return StartQuery(info, nullptr, aHostName);
872 }
873 
874 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)875 Error Client::ResolveIp4Address(const char        *aHostName,
876                                 AddressCallback    aCallback,
877                                 void              *aContext,
878                                 const QueryConfig *aConfig)
879 {
880     QueryInfo info;
881 
882     info.Clear();
883     info.mQueryType = kIp4AddressQuery;
884     info.mConfig.SetFrom(aConfig, mDefaultConfig);
885     info.mCallback.mAddressCallback = aCallback;
886     info.mCallbackContext           = aContext;
887 
888     return StartQuery(info, nullptr, aHostName);
889 }
890 #endif
891 
892 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
893 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)894 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
895 {
896     QueryInfo info;
897 
898     info.Clear();
899     info.mQueryType = kBrowseQuery;
900     info.mConfig.SetFrom(aConfig, mDefaultConfig);
901     info.mCallback.mBrowseCallback = aCallback;
902     info.mCallbackContext          = aContext;
903 
904     return StartQuery(info, nullptr, aServiceName);
905 }
906 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)907 Error Client::ResolveService(const char        *aInstanceLabel,
908                              const char        *aServiceName,
909                              ServiceCallback    aCallback,
910                              void              *aContext,
911                              const QueryConfig *aConfig)
912 {
913     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
914 }
915 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)916 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
917                                            const char        *aServiceName,
918                                            ServiceCallback    aCallback,
919                                            void              *aContext,
920                                            const QueryConfig *aConfig)
921 {
922     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
923 }
924 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)925 Error Client::Resolve(const char        *aInstanceLabel,
926                       const char        *aServiceName,
927                       ServiceCallback    aCallback,
928                       void              *aContext,
929                       const QueryConfig *aConfig,
930                       bool               aShouldResolveHostAddr)
931 {
932     QueryInfo info;
933     Error     error;
934     QueryType secondQueryType = kNoQuery;
935 
936     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
937 
938     info.Clear();
939 
940     info.mConfig.SetFrom(aConfig, mDefaultConfig);
941     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
942 
943     CheckAndUpdateServiceMode(info.mConfig, aConfig);
944 
945     switch (info.mConfig.GetServiceMode())
946     {
947     case QueryConfig::kServiceModeSrvTxtSeparate:
948         secondQueryType = kServiceQueryTxt;
949 
950         OT_FALL_THROUGH;
951 
952     case QueryConfig::kServiceModeSrv:
953         info.mQueryType = kServiceQuerySrv;
954         break;
955 
956     case QueryConfig::kServiceModeTxt:
957         info.mQueryType = kServiceQueryTxt;
958         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
959         break;
960 
961     case QueryConfig::kServiceModeSrvTxt:
962     case QueryConfig::kServiceModeSrvTxtOptimize:
963     default:
964         info.mQueryType = kServiceQuerySrvTxt;
965         break;
966     }
967 
968     info.mCallback.mServiceCallback = aCallback;
969     info.mCallbackContext           = aContext;
970 
971     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
972 
973 exit:
974     return error;
975 }
976 
977 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
978 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)979 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
980 {
981     // The `aLabel` can be `nullptr` and then `aName` provides the
982     // full name, otherwise the name is appended as `{aLabel}.
983     // {aName}`.
984 
985     Error  error;
986     Query *query;
987 
988     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
989 
990 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
991     if (aInfo.mQueryType == kIp4AddressQuery)
992     {
993         NetworkData::ExternalRouteConfig nat64Prefix;
994 
995         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
996         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
997                      error = kErrorInvalidState);
998     }
999 #endif
1000 
1001     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
1002 
1003     mMainQueries.Enqueue(*query);
1004 
1005     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1006     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1007 
1008     if (aSecondType != kNoQuery)
1009     {
1010         Query *secondQuery;
1011 
1012         aInfo.mQueryType         = aSecondType;
1013         aInfo.mMessageId         = 0;
1014         aInfo.mTransmissionCount = 0;
1015         aInfo.mMainQuery         = query;
1016 
1017         // We intentionally do not use `error` here so in the unlikely
1018         // case where we cannot allocate the second query we can proceed
1019         // with the first one.
1020         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1021 
1022         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1023 
1024         // Update first query to link to second one by updating
1025         // its `mNextQuery`.
1026         aInfo.ReadFrom(*query);
1027         aInfo.mNextQuery = secondQuery;
1028         UpdateQuery(*query, aInfo);
1029     }
1030 
1031 exit:
1032     return error;
1033 }
1034 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1035 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1036 {
1037     Error error = kErrorNone;
1038 
1039     aQuery = nullptr;
1040 
1041     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1042 
1043     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1044     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1045 
1046     SuccessOrExit(error = aQuery->Append(aInfo));
1047 
1048     if (aLabel != nullptr)
1049     {
1050         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1051     }
1052 
1053     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1054 
1055 exit:
1056     FreeAndNullMessageOnError(aQuery, error);
1057     return error;
1058 }
1059 
FindMainQuery(Query & aQuery)1060 Client::Query &Client::FindMainQuery(Query &aQuery)
1061 {
1062     QueryInfo info;
1063 
1064     info.ReadFrom(aQuery);
1065 
1066     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1067 }
1068 
FreeQuery(Query & aQuery)1069 void Client::FreeQuery(Query &aQuery)
1070 {
1071     Query    &mainQuery = FindMainQuery(aQuery);
1072     QueryInfo info;
1073 
1074     mMainQueries.Dequeue(mainQuery);
1075 
1076     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1077     {
1078         info.ReadFrom(*query);
1079         FreeMessage(info.mSavedResponse);
1080         query->Free();
1081     }
1082 }
1083 
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)1084 Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
1085 {
1086     // This method prepares and sends a query message represented by
1087     // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
1088     // the new `mRetransmissionTime`) and updates it in `aQuery` as
1089     // well. `aUpdateTimer` indicates whether the timer should be
1090     // updated when query is sent or not (used in the case where timer
1091     // is handled by caller).
1092 
1093     Error            error   = kErrorNone;
1094     Message         *message = nullptr;
1095     Header           header;
1096     Ip6::MessageInfo messageInfo;
1097     uint16_t         length = 0;
1098 
1099     aInfo.mTransmissionCount++;
1100     aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
1101 
1102     if (aInfo.mMessageId == 0)
1103     {
1104         do
1105         {
1106             SuccessOrExit(error = header.SetRandomMessageId());
1107         } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
1108 
1109         aInfo.mMessageId = header.GetMessageId();
1110     }
1111     else
1112     {
1113         header.SetMessageId(aInfo.mMessageId);
1114     }
1115 
1116     header.SetType(Header::kTypeQuery);
1117     header.SetQueryType(Header::kQueryTypeStandard);
1118 
1119     if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
1120     {
1121         header.SetRecursionDesiredFlag();
1122     }
1123 
1124     header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
1125 
1126     message = mSocket.NewMessage();
1127     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
1128 
1129     SuccessOrExit(error = message->Append(header));
1130 
1131     // Prepare the question section.
1132 
1133     for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
1134     {
1135         SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
1136         SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
1137     }
1138 
1139     length = message->GetLength() - message->GetOffset();
1140 
1141     if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1142 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1143     {
1144         // Check if query will fit into tcp buffer if not return error.
1145         VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
1146                          OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
1147                      error = kErrorNoBufs);
1148 
1149         // In case of initialized connection check if connected peer and new query have the same address.
1150         if (mTcpState != kTcpUninitialized)
1151         {
1152             VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
1153                          error = kErrorFailed);
1154         }
1155 
1156         switch (mTcpState)
1157         {
1158         case kTcpUninitialized:
1159             SuccessOrExit(error = InitTcpSocket());
1160             SuccessOrExit(
1161                 error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
1162             mTcpState = kTcpConnecting;
1163             PrepareTcpMessage(*message);
1164             break;
1165         case kTcpConnectedIdle:
1166             PrepareTcpMessage(*message);
1167             SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1168             mTcpState = kTcpConnectedSending;
1169             break;
1170         case kTcpConnecting:
1171             PrepareTcpMessage(*message);
1172             break;
1173         case kTcpConnectedSending:
1174             BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1175             SuccessOrAssert(error = message->Read(message->GetOffset(),
1176                                                   (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1177             IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
1178             break;
1179         }
1180         message->Free();
1181         message = nullptr;
1182     }
1183 #else
1184     {
1185         error = kErrorInvalidArgs;
1186         LogWarn("DNS query over TCP not supported.");
1187         ExitNow();
1188     }
1189 #endif
1190     else
1191     {
1192         VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
1193         messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
1194         messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
1195         SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
1196     }
1197 
1198 exit:
1199 
1200     FreeMessageOnError(message, error);
1201     if (aUpdateTimer)
1202     {
1203         mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
1204     }
1205 
1206     UpdateQuery(aQuery, aInfo);
1207 
1208     return error;
1209 }
1210 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1211 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1212 {
1213     // The name is encoded and included after the `Info` in `aQuery`
1214     // starting at `kNameOffsetInQuery`.
1215 
1216     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1217 }
1218 
FinalizeQuery(Query & aQuery,Error aError)1219 void Client::FinalizeQuery(Query &aQuery, Error aError)
1220 {
1221     Response response;
1222     Query   &mainQuery = FindMainQuery(aQuery);
1223 
1224     response.mInstance = &Get<Instance>();
1225     response.mQuery    = &mainQuery;
1226 
1227     FinalizeQuery(response, aError);
1228 }
1229 
FinalizeQuery(Response & aResponse,Error aError)1230 void Client::FinalizeQuery(Response &aResponse, Error aError)
1231 {
1232     QueryType type;
1233     Callback  callback;
1234     void     *context;
1235 
1236     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1237 
1238     switch (type)
1239     {
1240     case kIp6AddressQuery:
1241 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1242     case kIp4AddressQuery:
1243 #endif
1244         if (callback.mAddressCallback != nullptr)
1245         {
1246             callback.mAddressCallback(aError, &aResponse, context);
1247         }
1248         break;
1249 
1250 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1251     case kBrowseQuery:
1252         if (callback.mBrowseCallback != nullptr)
1253         {
1254             callback.mBrowseCallback(aError, &aResponse, context);
1255         }
1256         break;
1257 
1258     case kServiceQuerySrvTxt:
1259     case kServiceQuerySrv:
1260     case kServiceQueryTxt:
1261         if (callback.mServiceCallback != nullptr)
1262         {
1263             callback.mServiceCallback(aError, &aResponse, context);
1264         }
1265         break;
1266 #endif
1267     case kNoQuery:
1268         break;
1269     }
1270 
1271     FreeQuery(*aResponse.mQuery);
1272 }
1273 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1274 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1275 {
1276     QueryInfo info;
1277 
1278     info.ReadFrom(aQuery);
1279 
1280     aType     = info.mQueryType;
1281     aCallback = info.mCallback;
1282     aContext  = info.mCallbackContext;
1283 }
1284 
FindQueryById(uint16_t aMessageId)1285 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1286 {
1287     Query    *matchedQuery = nullptr;
1288     QueryInfo info;
1289 
1290     for (Query &mainQuery : mMainQueries)
1291     {
1292         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1293         {
1294             info.ReadFrom(*query);
1295 
1296             if (info.mMessageId == aMessageId)
1297             {
1298                 matchedQuery = query;
1299                 ExitNow();
1300             }
1301         }
1302     }
1303 
1304 exit:
1305     return matchedQuery;
1306 }
1307 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMsgInfo)1308 void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMsgInfo)
1309 {
1310     OT_UNUSED_VARIABLE(aMsgInfo);
1311     ProcessResponse(aMessage);
1312 }
1313 
ProcessResponse(const Message & aResponseMessage)1314 void Client::ProcessResponse(const Message &aResponseMessage)
1315 {
1316     Error  responseError;
1317     Query *query;
1318 
1319     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1320 
1321     if (responseError != kErrorNone)
1322     {
1323         // Received an error from server, check if we can replace
1324         // the query.
1325 
1326 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1327         if (ReplaceWithIp4Query(*query) == kErrorNone)
1328         {
1329             ExitNow();
1330         }
1331 #endif
1332 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1333         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1334         {
1335             ExitNow();
1336         }
1337 #endif
1338 
1339         FinalizeQuery(*query, responseError);
1340         ExitNow();
1341     }
1342 
1343     // Received successful response from server.
1344 
1345 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1346     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1347 #endif
1348 
1349     if (!CanFinalizeQuery(*query))
1350     {
1351         SaveQueryResponse(*query, aResponseMessage);
1352         ExitNow();
1353     }
1354 
1355     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1356 
1357 exit:
1358     return;
1359 }
1360 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1361 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1362 {
1363     Error     error  = kErrorNone;
1364     uint16_t  offset = aResponseMessage.GetOffset();
1365     Header    header;
1366     QueryInfo info;
1367     Name      queryName;
1368 
1369     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1370     offset += sizeof(Header);
1371 
1372     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1373                      !header.IsTruncationFlagSet(),
1374                  error = kErrorDrop);
1375 
1376     aQuery = FindQueryById(header.GetMessageId());
1377     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1378 
1379     info.ReadFrom(*aQuery);
1380 
1381     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1382 
1383     // Check the Question Section
1384 
1385     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1386     {
1387         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1388         {
1389             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1390             offset += sizeof(Question);
1391         }
1392     }
1393     else
1394     {
1395         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1396                      error = kErrorParse);
1397     }
1398 
1399     // Check the answer, authority and additional record sections
1400 
1401     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1402     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1403     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1404 
1405     // Read the response code
1406 
1407     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1408 
1409     if ((aResponseError == kErrorNone) && (info.mQueryType == kServiceQuerySrvTxt))
1410     {
1411         RecordServerAsCapableOfMultiQuestions(info.mConfig.GetServerSockAddr().GetAddress());
1412     }
1413 
1414 exit:
1415     return error;
1416 }
1417 
CanFinalizeQuery(Query & aQuery)1418 bool Client::CanFinalizeQuery(Query &aQuery)
1419 {
1420     // Determines whether we can finalize a main query by checking if
1421     // we have received and saved responses for all other related
1422     // queries associated with `aQuery`. Note that this method is
1423     // called when we receive a response for `aQuery`, so no need to
1424     // check for a saved response for `aQuery` itself.
1425 
1426     bool      canFinalize = true;
1427     QueryInfo info;
1428 
1429     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1430     {
1431         info.ReadFrom(*query);
1432 
1433         if (query == &aQuery)
1434         {
1435             continue;
1436         }
1437 
1438         if (info.mSavedResponse == nullptr)
1439         {
1440             canFinalize = false;
1441             ExitNow();
1442         }
1443     }
1444 
1445 exit:
1446     return canFinalize;
1447 }
1448 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1449 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1450 {
1451     QueryInfo info;
1452 
1453     info.ReadFrom(aQuery);
1454     VerifyOrExit(info.mSavedResponse == nullptr);
1455 
1456     // If `Clone()` fails we let retry or timeout handle the error.
1457     info.mSavedResponse = aResponseMessage.Clone();
1458 
1459     UpdateQuery(aQuery, info);
1460 
1461 exit:
1462     return;
1463 }
1464 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1465 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1466 {
1467     // Populate `aResponse` for `aQuery`. If there is a saved response
1468     // message for `aQuery` we use it, otherwise, we use
1469     // `aResponseMessage`.
1470 
1471     QueryInfo info;
1472 
1473     info.ReadFrom(aQuery);
1474 
1475     aResponse.mInstance = &Get<Instance>();
1476     aResponse.mQuery    = &aQuery;
1477     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1478 
1479     return info.mNextQuery;
1480 }
1481 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1482 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1483 {
1484     // This method prepares a list of chained `Response` instances
1485     // corresponding to all related (chained) queries. It uses
1486     // recursion to go through the queries and construct the
1487     // `Response` chain.
1488 
1489     Response response;
1490     Query   *nextQuery;
1491 
1492     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1493     response.mNext = aPrevResponse;
1494 
1495     if (nextQuery != nullptr)
1496     {
1497         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1498     }
1499     else
1500     {
1501         FinalizeQuery(response, kErrorNone);
1502     }
1503 }
1504 
HandleTimer(void)1505 void Client::HandleTimer(void)
1506 {
1507     NextFireTime nextTime;
1508     QueryInfo    info;
1509 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1510     bool hasTcpQuery = false;
1511 #endif
1512 
1513     for (Query &mainQuery : mMainQueries)
1514     {
1515         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1516         {
1517             info.ReadFrom(*query);
1518 
1519             if (info.mSavedResponse != nullptr)
1520             {
1521                 continue;
1522             }
1523 
1524             if (nextTime.GetNow() >= info.mRetransmissionTime)
1525             {
1526                 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1527                 {
1528                     FinalizeQuery(*query, kErrorResponseTimeout);
1529                     break;
1530                 }
1531 
1532 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1533                 if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1534                 {
1535                     LogInfo("Switching to separate SRV/TXT on response timeout");
1536                     info.ReadFrom(*query);
1537                 }
1538                 else
1539 #endif
1540                 {
1541                     IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1542                 }
1543             }
1544 
1545             nextTime.UpdateIfEarlier(info.mRetransmissionTime);
1546 
1547 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1548             if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1549             {
1550                 hasTcpQuery = true;
1551             }
1552 #endif
1553         }
1554     }
1555 
1556     mTimer.FireAtIfEarlier(nextTime);
1557 
1558 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1559     if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1560     {
1561         IgnoreError(mEndpoint.SendEndOfStream());
1562     }
1563 #endif
1564 }
1565 
1566 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1567 
ReplaceWithIp4Query(Query & aQuery)1568 Error Client::ReplaceWithIp4Query(Query &aQuery)
1569 {
1570     Error     error = kErrorFailed;
1571     QueryInfo info;
1572 
1573     info.ReadFrom(aQuery);
1574 
1575     VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1576     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1577 
1578     // We send a new query for IPv4 address resolution
1579     // for the same host name. We reuse the existing `aQuery`
1580     // instance and keep all the info but clear `mTransmissionCount`
1581     // and `mMessageId` (so that a new random message ID is
1582     // selected). The new `info` will be saved in the query in
1583     // `SendQuery()`. Note that the current query is still in the
1584     // `mMainQueries` list when `SendQuery()` selects a new random
1585     // message ID, so the existing message ID for this query will
1586     // not be reused.
1587 
1588     info.mQueryType         = kIp4AddressQuery;
1589     info.mMessageId         = 0;
1590     info.mTransmissionCount = 0;
1591 
1592     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1593     error = kErrorNone;
1594 
1595 exit:
1596     return error;
1597 }
1598 
1599 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1600 
1601 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1602 
CheckAndUpdateServiceMode(QueryConfig & aConfig,const QueryConfig * aRequestConfig) const1603 void Client::CheckAndUpdateServiceMode(QueryConfig &aConfig, const QueryConfig *aRequestConfig) const
1604 {
1605     // If the user explicitly requested "optimize" mode, we honor that
1606     // request. Otherwise, if "optimize" is chosen from the default
1607     // config, we check if the DNS server is known to have trouble
1608     // with multiple-question queries. If so, we switch to "separate"
1609     // mode.
1610 
1611     if ((aRequestConfig != nullptr) && (aRequestConfig->GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize))
1612     {
1613         ExitNow();
1614     }
1615 
1616     VerifyOrExit(aConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1617 
1618     if (mLimitedQueryServers.Contains(aConfig.GetServerSockAddr().GetAddress()))
1619     {
1620         aConfig.SetServiceMode(QueryConfig::kServiceModeSrvTxtSeparate);
1621     }
1622 
1623 exit:
1624     return;
1625 }
1626 
RecordServerAsLimitedToSingleQuestion(const Ip6::Address & aServerAddress)1627 void Client::RecordServerAsLimitedToSingleQuestion(const Ip6::Address &aServerAddress)
1628 {
1629     VerifyOrExit(!aServerAddress.IsUnspecified());
1630 
1631     VerifyOrExit(!mLimitedQueryServers.Contains(aServerAddress));
1632 
1633     if (mLimitedQueryServers.IsFull())
1634     {
1635         uint8_t randomIndex = Random::NonCrypto::GetUint8InRange(0, mLimitedQueryServers.GetMaxSize());
1636 
1637         mLimitedQueryServers.Remove(mLimitedQueryServers[randomIndex]);
1638     }
1639 
1640     IgnoreError(mLimitedQueryServers.PushBack(aServerAddress));
1641 
1642 exit:
1643     return;
1644 }
1645 
RecordServerAsCapableOfMultiQuestions(const Ip6::Address & aServerAddress)1646 void Client::RecordServerAsCapableOfMultiQuestions(const Ip6::Address &aServerAddress)
1647 {
1648     Ip6::Address *entry = mLimitedQueryServers.Find(aServerAddress);
1649 
1650     VerifyOrExit(entry != nullptr);
1651     mLimitedQueryServers.Remove(*entry);
1652 
1653 exit:
1654     return;
1655 }
1656 
// Replaces a combined SRV+TXT service query with two separate queries.
//
// This applies only when `aQuery` is a `kServiceQuerySrvTxt` query whose
// config uses `kServiceModeSrvTxtOptimize`; otherwise nothing is changed.
// The query's server address is first recorded as limited to
// single-question queries. This happens even if the clone below fails.
// `aQuery` is then cloned:
// - the clone is sent as a TXT-only query (a sub-query of `aQuery`);
// - `aQuery` itself is re-sent as an SRV-only query, with the clone linked
//   as its next query.
// Errors returned by `SendQuery()` are ignored.
//
// Returns `kErrorNone` once both queries are set up, or `kErrorFailed` if
// `aQuery` does not qualify or cannot be cloned.
Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
{
    Error     error = kErrorFailed;
    QueryInfo info;
    Query    *secondQuery;

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
    VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);

    // Remember that this server cannot handle the multi-question query. This
    // lets later optimized SRV+TXT queries to it be switched to the
    // separate-query mode up front.
    RecordServerAsLimitedToSingleQuestion(info.mConfig.GetServerSockAddr().GetAddress());

    secondQuery = aQuery.Clone();
    VerifyOrExit(secondQuery != nullptr);

    // Send the clone as the TXT query, with `aQuery` as its main query.
    // Message ID and transmission count are reset, presumably so that
    // `SendQuery()` treats it as a new transmission. `mNextQuery` is
    // whatever `aQuery` currently stores.
    info.mQueryType         = kServiceQueryTxt;
    info.mMessageId         = 0;
    info.mTransmissionCount = 0;
    info.mMainQuery         = &aQuery;
    IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));

    // Re-send `aQuery` as the SRV query. `info` still carries the
    // `mMainQuery` set above. The TXT clone is linked as `aQuery`'s next query.
    info.mQueryType         = kServiceQuerySrv;
    info.mMessageId         = 0;
    info.mTransmissionCount = 0;
    info.mNextQuery         = secondQuery;
    IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
    error = kErrorNone;

exit:
    return error;
}
1689 
// Starts a host-address query for a service's host, if one is needed.
//
// This applies only to SRV or SRV+TXT queries whose `mShouldResolveHostAddr`
// flag is set. The service info is read from the answer section of
// `aResponseMessage`. If it carries no host address (the address is
// unspecified), a new `kIp6AddressQuery` for the host name is allocated,
// sent, and linked into the query list directly after `aQuery`.
//
// The function returns silently if reading the service info fails or if
// the new query cannot be allocated.
void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
{
    QueryInfo   info;
    Response    response;
    ServiceInfo serviceInfo;
    char        hostName[Name::kMaxNameSize];

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
    VerifyOrExit(info.mShouldResolveHostAddr);

    PopulateResponse(response, aQuery, aResponseMessage);

    // Only the host name buffer is provided; the host name read from the
    // response is used as the name for the follow-up address query.
    ClearAllBytes(serviceInfo);
    serviceInfo.mHostNameBuffer     = hostName;
    serviceInfo.mHostNameBufferSize = sizeof(hostName);
    SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));

    // Check whether AAAA record for host address is provided in the SRV query response

    if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
    {
        Query *newQuery;

        // `info` still holds `aQuery`'s data, so the new query inherits
        // `aQuery`'s `mNextQuery`. This is what splices it into the list
        // after `aQuery`.
        info.mQueryType         = kIp6AddressQuery;
        info.mMessageId         = 0;
        info.mTransmissionCount = 0;
        info.mMainQuery         = &FindMainQuery(aQuery);

        SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
        IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));

        // Update `aQuery` to be linked with new query (inserting
        // the `newQuery` into the linked-list after `aQuery`).

        info.ReadFrom(aQuery);
        info.mNextQuery = newQuery;
        UpdateQuery(aQuery, info);
    }

exit:
    return;
}
1734 
1735 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1736 
1737 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1738 void Client::PrepareTcpMessage(Message &aMessage)
1739 {
1740     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1741 
1742     // Prepending the DNS query with length of the packet according to RFC1035.
1743     BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1744     SuccessOrAssert(
1745         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1746     mSendLink.mLength += length + sizeof(uint16_t);
1747 }
1748 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1749 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1750 {
1751     OT_UNUSED_VARIABLE(aEndpoint);
1752     OT_UNUSED_VARIABLE(aData);
1753     OT_ASSERT(mTcpState == kTcpConnectedSending);
1754 
1755     mSendLink.mLength = 0;
1756     mTcpState         = kTcpConnectedIdle;
1757 }
1758 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1759 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1760 {
1761     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1762 }
1763 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1764 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1765 {
1766     OT_UNUSED_VARIABLE(aEndpoint);
1767     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1768     mTcpState = kTcpConnectedSending;
1769 }
1770 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1771 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1772 {
1773     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1774 }
1775 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1776 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1777                                  size_t                &aOffset,
1778                                  Message               &aMessage,
1779                                  uint16_t               aLength)
1780 {
1781     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1782     // and copy the content into `aMessage`. As we read we can move
1783     // to the next `aLinkedBuffer` and update `aOffset`.
1784     // Returns:
1785     // - `kErrorNone` if `aLength` bytes are successfully read and
1786     //    `aOffset` and `aLinkedBuffer` are updated.
1787     // - `kErrorNotFound` is not enough bytes available to read
1788     //    from `aLinkedBuffer`.
1789     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1790 
1791     Error error = kErrorNone;
1792 
1793     while (aLength > 0)
1794     {
1795         uint16_t bytesToRead = aLength;
1796 
1797         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1798 
1799         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1800         {
1801             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1802         }
1803 
1804         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1805 
1806         aLength -= bytesToRead;
1807         aOffset += bytesToRead;
1808 
1809         if (aOffset == aLinkedBuffer->mLength)
1810         {
1811             aLinkedBuffer = aLinkedBuffer->mNext;
1812             aOffset       = 0;
1813         }
1814     }
1815 
1816 exit:
1817     return error;
1818 }
1819 
// Handles newly available data on the DNS-over-TCP connection.
//
// On a TCP stream, each DNS message is preceded by a 2-byte big-endian
// length field (RFC 1035, section 4.2.2). This function parses as many
// complete length-prefixed messages as it can from the endpoint's receive
// buffer and passes each one to `ProcessResponse()`.
//
// Only the bytes of fully read messages (`totalRead`) are passed to
// `CommitReceive()`. A trailing partial message is not committed; it
// presumably remains in the endpoint's receive buffer for a later callback.
//
// If `aEndOfStream` is set, our own end-of-stream is sent as well. Data that
// is already available is still processed.
void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
                                       size_t         aBytesAvailable,
                                       bool           aEndOfStream,
                                       size_t         aBytesRemaining)
{
    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aBytesRemaining);

    Message              *message   = nullptr;
    size_t                totalRead = 0;
    size_t                offset    = 0;
    const otLinkedBuffer *data;

    if (aEndOfStream)
    {
        // Cleanup is done in disconnected callback.
        IgnoreError(mEndpoint.SendEndOfStream());
    }

    SuccessOrExit(mEndpoint.ReceiveByReference(data));
    VerifyOrExit(data != nullptr);

    // `message` is a scratch buffer reused for every frame. It first holds
    // the 2-byte length field, then (after being cleared) the message body.
    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr);

    while (aBytesAvailable > totalRead)
    {
        uint16_t length;

        // Read the `length` field.
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));

        IgnoreError(message->Read(/* aOffset */ 0, length));
        length = BigEndian::HostSwap16(length);

        // Try to read `length` bytes.
        // If fewer bytes are available, `ReadFromLinkBuffer()` fails and we
        // exit without counting this frame in `totalRead`.
        IgnoreError(message->SetLength(0));
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));

        totalRead += length + sizeof(uint16_t);

        // Now process the read message as query response.
        ProcessResponse(*message);

        IgnoreError(message->SetLength(0));

        // Loop again to see if we can read another response.
    }

exit:
    // Inform `mEndpoint` about the total read and processed bytes
    IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
    // `message` may still be null here if allocation failed or we exited
    // early; presumably `FreeMessage()` accepts null. Confirm.
    FreeMessage(message);
}
1874 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1875 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1876                                                size_t         aBytesAvailable,
1877                                                bool           aEndOfStream,
1878                                                size_t         aBytesRemaining)
1879 {
1880     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1881         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1882 }
1883 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1884 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1885 {
1886     OT_UNUSED_VARIABLE(aEndpoint);
1887     OT_UNUSED_VARIABLE(aReason);
1888     QueryInfo info;
1889 
1890     IgnoreError(mEndpoint.Deinitialize());
1891     mTcpState = kTcpUninitialized;
1892 
1893     // Abort queries in case of connection failures
1894     for (Query &mainQuery : mMainQueries)
1895     {
1896         info.ReadFrom(mainQuery);
1897 
1898         if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1899         {
1900             FinalizeQuery(mainQuery, kErrorAbort);
1901         }
1902     }
1903 }
1904 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1905 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1906 {
1907     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1908 }
1909 
1910 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1911 
1912 } // namespace Dns
1913 } // namespace ot
1914 
1915 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1916