1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43 
44 /**
45  * @file
46  *   This file implements the DNS client.
47  */
48 
49 namespace ot {
50 namespace Dns {
51 
// The big-endian 16-bit read/write helpers are only brought into scope
// when DNS-over-TCP support is compiled in.
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
using ot::Encoding::BigEndian::ReadUint16;
using ot::Encoding::BigEndian::WriteUint16;
#endif

// Presumably tags the log output of this file with the module name
// "DnsClient" (see `common/log.hpp` to confirm).
RegisterLogModule("DnsClient");

//---------------------------------------------------------------------------------------------------------------------
// Client::QueryConfig

// Default DNS server address in string form, taken from the build-time
// configuration. The `QueryConfig` constructor parses it using
// `FromString()` to populate the default server socket address.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
63 
QueryConfig(InitMode aMode)64 Client::QueryConfig::QueryConfig(InitMode aMode)
65 {
66     OT_UNUSED_VARIABLE(aMode);
67 
68     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
69     GetServerSockAddr().SetPort(kDefaultServerPort);
70     SetResponseTimeout(kDefaultResponseTimeout);
71     SetMaxTxAttempts(kDefaultMaxTxAttempts);
72     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
73     SetServiceMode(kDefaultServiceMode);
74 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
75     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
76 #endif
77     SetTransportProto(kDnsTransportUdp);
78 }
79 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)80 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
81 {
82     // This method sets the config from `aConfig` replacing any
83     // unspecified fields (value zero) with the fields from
84     // `aDefaultConfig`. If `aConfig` is `nullptr` then
85     // `aDefaultConfig` is used.
86 
87     if (aConfig == nullptr)
88     {
89         *this = aDefaultConfig;
90         ExitNow();
91     }
92 
93     *this = *aConfig;
94 
95     if (GetServerSockAddr().GetAddress().IsUnspecified())
96     {
97         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
98     }
99 
100     if (GetServerSockAddr().GetPort() == 0)
101     {
102         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
103     }
104 
105     if (GetResponseTimeout() == 0)
106     {
107         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
108     }
109 
110     if (GetMaxTxAttempts() == 0)
111     {
112         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
113     }
114 
115     if (GetRecursionFlag() == kFlagUnspecified)
116     {
117         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
118     }
119 
120 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
121     if (GetNat64Mode() == kNat64Unspecified)
122     {
123         SetNat64Mode(aDefaultConfig.GetNat64Mode());
124     }
125 #endif
126 
127     if (GetServiceMode() == kServiceModeUnspecified)
128     {
129         SetServiceMode(aDefaultConfig.GetServiceMode());
130     }
131 
132     if (GetTransportProto() == kDnsTransportUnspecified)
133     {
134         SetTransportProto(aDefaultConfig.GetTransportProto());
135     }
136 
137 exit:
138     return;
139 }
140 
141 //---------------------------------------------------------------------------------------------------------------------
142 // Client::Response
143 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const144 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
145 {
146     switch (aSection)
147     {
148     case kAnswerSection:
149         aOffset    = mAnswerOffset;
150         aNumRecord = mAnswerRecordCount;
151         break;
152     case kAdditionalDataSection:
153     default:
154         aOffset    = mAdditionalOffset;
155         aNumRecord = mAdditionalRecordCount;
156         break;
157     }
158 }
159 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const160 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
161 {
162     uint16_t offset = kNameOffsetInQuery;
163 
164     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
165 }
166 
CheckForHostNameAlias(Section aSection,Name & aHostName) const167 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
168 {
169     // If the response includes a CNAME record mapping the query host
170     // name to a canonical name, we update `aHostName` to the new alias
171     // name. Otherwise `aHostName` remains as before.
172 
173     Error       error;
174     uint16_t    offset;
175     uint16_t    numRecords;
176     CnameRecord cnameRecord;
177 
178     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
179 
180     SelectSection(aSection, offset, numRecords);
181     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
182 
183     switch (error)
184     {
185     case kErrorNone:
186         // A CNAME record was found. `offset` now points to after the
187         // last read byte within the `mMessage` into the `cnameRecord`
188         // (which is the start of the new canonical name).
189         aHostName.SetFromMessage(*mMessage, offset);
190         error = Name::ParseName(*mMessage, offset);
191         break;
192 
193     case kErrorNotFound:
194         error = kErrorNone;
195         break;
196 
197     default:
198         break;
199     }
200 
201 exit:
202     return error;
203 }
204 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const205 Error Client::Response::FindHostAddress(Section       aSection,
206                                         const Name   &aHostName,
207                                         uint16_t      aIndex,
208                                         Ip6::Address &aAddress,
209                                         uint32_t     &aTtl) const
210 {
211     Error      error;
212     uint16_t   offset;
213     uint16_t   numRecords;
214     Name       name = aHostName;
215     AaaaRecord aaaaRecord;
216 
217     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
218 
219     SelectSection(aSection, offset, numRecords);
220     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
221     aAddress = aaaaRecord.GetAddress();
222     aTtl     = aaaaRecord.GetTtl();
223 
224 exit:
225     return error;
226 }
227 
228 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
229 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const230 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
231 {
232     Error    error;
233     uint16_t offset;
234     uint16_t numRecords;
235     Name     name = aHostName;
236 
237     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
238 
239     SelectSection(aSection, offset, numRecords);
240     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
241 
242 exit:
243     return error;
244 }
245 
246 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
247 
248 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
249 
InitServiceInfo(ServiceInfo & aServiceInfo) const250 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
251 {
252     // This method initializes `aServiceInfo` setting all
253     // TTLs to zero and host name to empty string.
254 
255     aServiceInfo.mTtl              = 0;
256     aServiceInfo.mHostAddressTtl   = 0;
257     aServiceInfo.mTxtDataTtl       = 0;
258     aServiceInfo.mTxtDataTruncated = false;
259 
260     AsCoreType(&aServiceInfo.mHostAddress).Clear();
261 
262     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
263     {
264         aServiceInfo.mHostNameBuffer[0] = '\0';
265     }
266 }
267 
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const268 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
269 {
270     // This method searches for SRV record in the given `aSection`
271     // matching the record name against `aName`, and updates the
272     // `aServiceInfo` accordingly. It also searches for AAAA record
273     // for host name associated with the service (from SRV record).
274     // The search for AAAA record is always performed in Additional
275     // Data section (independent of the value given in `aSection`).
276 
277     Error     error = kErrorNone;
278     uint16_t  offset;
279     uint16_t  numRecords;
280     Name      hostName;
281     SrvRecord srvRecord;
282 
283     // A non-zero `mTtl` indicates that SRV record is already found
284     // and parsed from a previous response.
285     VerifyOrExit(aServiceInfo.mTtl == 0);
286 
287     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
288 
289     // Search for a matching SRV record
290     SelectSection(aSection, offset, numRecords);
291     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
292 
293     aServiceInfo.mTtl      = srvRecord.GetTtl();
294     aServiceInfo.mPort     = srvRecord.GetPort();
295     aServiceInfo.mPriority = srvRecord.GetPriority();
296     aServiceInfo.mWeight   = srvRecord.GetWeight();
297 
298     hostName.SetFromMessage(*mMessage, offset);
299 
300     if (aServiceInfo.mHostNameBuffer != nullptr)
301     {
302         SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
303                                                            aServiceInfo.mHostNameBufferSize));
304     }
305     else
306     {
307         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
308     }
309 
310     // Search in additional section for AAAA record for the host name.
311 
312     VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
313 
314     error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
315                             aServiceInfo.mHostAddressTtl);
316 
317     if (error == kErrorNotFound)
318     {
319         error = kErrorNone;
320     }
321 
322 exit:
323     return error;
324 }
325 
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const326 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
327 {
328     // This method searches a TXT record in the given `aSection`
329     // matching the record name against `aName` and updates the TXT
330     // related properties in `aServicesInfo`.
331     //
332     // If no match is found `mTxtDataTtl` (which is initialized to zero)
333     // remains unchanged to indicate this. In this case this method still
334     // returns `kErrorNone`.
335 
336     Error     error = kErrorNone;
337     uint16_t  offset;
338     uint16_t  numRecords;
339     TxtRecord txtRecord;
340 
341     // A non-zero `mTxtDataTtl` indicates that TXT record is already
342     // found and parsed from a previous response.
343     VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
344 
345     // A null `mTxtData` indicates that caller does not want to retrieve
346     // TXT data.
347     VerifyOrExit(aServiceInfo.mTxtData != nullptr);
348 
349     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
350 
351     SelectSection(aSection, offset, numRecords);
352 
353     aServiceInfo.mTxtDataTruncated = false;
354 
355     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
356 
357     error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
358 
359     if (error == kErrorNoBufs)
360     {
361         error = kErrorNone;
362 
363         // Mark `mTxtDataTruncated` to indicate that we could not read
364         // the full TXT record into the given `mTxtData` buffer.
365         aServiceInfo.mTxtDataTruncated = true;
366     }
367 
368     SuccessOrExit(error);
369     aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
370 
371 exit:
372     if (error == kErrorNotFound)
373     {
374         error = kErrorNone;
375     }
376 
377     return error;
378 }
379 
380 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
381 
PopulateFrom(const Message & aMessage)382 void Client::Response::PopulateFrom(const Message &aMessage)
383 {
384     // Populate `Response` with info from `aMessage`.
385 
386     uint16_t offset = aMessage.GetOffset();
387     Header   header;
388 
389     mMessage = &aMessage;
390 
391     IgnoreError(aMessage.Read(offset, header));
392     offset += sizeof(Header);
393 
394     for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
395     {
396         IgnoreError(Name::ParseName(aMessage, offset));
397         offset += sizeof(Question);
398     }
399 
400     mAnswerOffset = offset;
401     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
402     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
403     mAdditionalOffset = offset;
404     IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
405 
406     mAnswerRecordCount     = header.GetAnswerCount();
407     mAdditionalRecordCount = header.GetAdditionalRecordCount();
408 }
409 
410 //---------------------------------------------------------------------------------------------------------------------
411 // Client::AddressResponse
412 
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const413 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
414 {
415     Error error = kErrorNone;
416     Name  name(*mQuery, kNameOffsetInQuery);
417 
418 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
419 
420     // If the response is for an IPv4 address query or if it is an
421     // IPv6 address query response with no IPv6 address but with
422     // an IPv4 in its additional section, we read the IPv4 address
423     // and translate it to an IPv6 address.
424 
425     QueryInfo info;
426 
427     info.ReadFrom(*mQuery);
428 
429     if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
430     {
431         Section                          section;
432         ARecord                          aRecord;
433         NetworkData::ExternalRouteConfig nat64Prefix;
434 
435         VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
436                      error = kErrorInvalidState);
437 
438         section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
439         SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
440 
441         aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
442         aTtl = aRecord.GetTtl();
443 
444         ExitNow();
445     }
446 
447 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
448 
449     ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
450 
451 exit:
452     return error;
453 }
454 
455 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
456 
457 //---------------------------------------------------------------------------------------------------------------------
458 // Client::BrowseResponse
459 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const460 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
461 {
462     Error     error;
463     uint16_t  offset;
464     uint16_t  numRecords;
465     Name      serviceName(*mQuery, kNameOffsetInQuery);
466     PtrRecord ptrRecord;
467 
468     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
469 
470     SelectSection(kAnswerSection, offset, numRecords);
471     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
472     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
473 
474 exit:
475     return error;
476 }
477 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const478 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
479 {
480     Error error;
481     Name  instanceName;
482 
483     // Find a matching PTR record for the service instance label. Then
484     // search and read SRV, TXT and AAAA records in Additional Data
485     // section matching the same name to populate `aServiceInfo`.
486 
487     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
488 
489     InitServiceInfo(aServiceInfo);
490     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
491     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
492 
493     if (aServiceInfo.mTxtDataTtl == 0)
494     {
495         aServiceInfo.mTxtDataSize = 0;
496     }
497 
498 exit:
499     return error;
500 }
501 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const502 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
503                                              uint16_t      aIndex,
504                                              Ip6::Address &aAddress,
505                                              uint32_t     &aTtl) const
506 {
507     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
508 }
509 
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // whose PTR name starts with the given instance label
    // `aInstanceLabel`. If found, `aInstanceName` is updated to refer
    // to that PTR name within the message.
    //
    // Returns `kErrorNotFound` when there is no response message or
    // when no PTR record matches the label. Other errors from name
    // comparison or record reading are returned as is.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        // Every record in the section is expected to be named after
        // the queried service. Any error from `CompareName()` ends
        // the search and is returned to the caller.
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A copy of `offset` is used so that `offset`
        // itself stays at the start of the PTR name.

        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        // Anything other than "label does not match" is a failure.
        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record. Presumably
        // `ReadRecord()` already advanced `offset` by
        // `sizeof(ptrRecord)`, so only the rest of the record is
        // skipped here.
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    // All records were checked without finding a matching label.
    error = kErrorNotFound;

exit:
    return error;
}
565 
566 //---------------------------------------------------------------------------------------------------------------------
567 // Client::ServiceResponse
568 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const569 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
570                                               uint8_t  aLabelBufferSize,
571                                               char    *aNameBuffer,
572                                               uint16_t aNameBufferSize) const
573 {
574     Error    error;
575     uint16_t offset = kNameOffsetInQuery;
576 
577     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
578 
579     VerifyOrExit(aNameBuffer != nullptr);
580     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
581 
582 exit:
583     return error;
584 }
585 
Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
{
    // Search and read SRV, TXT records matching name from query.
    //
    // This method walks the chain of responses linked through `mNext`
    // (starting from this one) and collects the service info from all
    // of them into `aServiceInfo`. The handling of each response
    // depends on the type of the query it answers.
    //
    // `kErrorNotFound` is returned when none of the responses is for
    // a service (SRV/TXT) query, or when a required SRV record is
    // missing.

    Error error = kErrorNotFound;

    InitServiceInfo(aServiceInfo);

    for (const Response *response = this; response != nullptr; response = response->mNext)
    {
        Name      name(*response->mQuery, kNameOffsetInQuery);
        QueryInfo info;
        Section   srvSection;
        Section   txtSection;

        info.ReadFrom(*response->mQuery);

        switch (info.mQueryType)
        {
        case kIp6AddressQuery:
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        case kIp4AddressQuery:
#endif
            // Response to an address query: try to read the host
            // address from its answer section. A failure here is
            // ignored and does not change `error`.
            IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
                                                  AsCoreType(&aServiceInfo.mHostAddress),
                                                  aServiceInfo.mHostAddressTtl));

            continue; // to `for()` loop

        case kServiceQuerySrvTxt:
        case kServiceQuerySrv:
        case kServiceQueryTxt:
            break;

        default:
            // Responses to any other query type are skipped.
            continue;
        }

        // Determine from which section we should try to read the SRV and
        // TXT records based on the query type.
        //
        // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
        // only one record (SRV or TXT) in the answer section, but we
        // still try to read the other records from additional data
        // section in case server provided them.

        srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
        txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;

        error = response->ReadServiceInfo(srvSection, name, aServiceInfo);

        // A missing SRV record is only tolerated when it was looked
        // for in the additional data section (i.e. it was optional).
        if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
        {
            error = kErrorNone;
        }

        SuccessOrExit(error);

        SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
    }

    // A zero `mTxtDataTtl` indicates that no TXT record was read, so
    // the TXT data is reported as empty.
    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}
655 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const656 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
657                                               uint16_t      aIndex,
658                                               Ip6::Address &aAddress,
659                                               uint32_t     &aTtl) const
660 {
661     Error error = kErrorNotFound;
662 
663     for (const Response *response = this; response != nullptr; response = response->mNext)
664     {
665         Section   section = kAdditionalDataSection;
666         QueryInfo info;
667 
668         info.ReadFrom(*response->mQuery);
669 
670         switch (info.mQueryType)
671         {
672         case kIp6AddressQuery:
673 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
674         case kIp4AddressQuery:
675 #endif
676             section = kAnswerSection;
677             break;
678 
679         default:
680             break;
681         }
682 
683         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
684 
685         if (error == kErrorNone)
686         {
687             break;
688         }
689     }
690 
691     return error;
692 }
693 
694 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
695 
696 //---------------------------------------------------------------------------------------------------------------------
697 // Client
698 
// Record types placed in the question(s) of each kind of query.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions for each query type. The entries are listed in
// query type order (the `Client` constructor checks the query type
// values at compile time), so entries must be added or removed
// together with the matching query type and under the same
// configuration conditions.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
    /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
#endif
};

// Record type array for each query type, in the same order as
// `kQuestionCount`. The SRV-only and TXT-only service queries do not
// have their own arrays. They point at the first (SRV) and second
// (TXT) element of `kServiceQueryRecordTypes` and use a count of one.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],

#endif
};
734 
Client(Instance & aInstance)735 Client::Client(Instance &aInstance)
736     : InstanceLocator(aInstance)
737     , mSocket(aInstance)
738 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
739     , mTcpState(kTcpUninitialized)
740 #endif
741     , mTimer(aInstance)
742     , mDefaultConfig(QueryConfig::kInitFromDefaults)
743 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
744     , mUserDidSetDefaultAddress(false)
745 #endif
746 {
747     static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
748 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
749     static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
750 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
751     static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
752     static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
753     static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
754     static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
755 #endif
756 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
757     static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
758     static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
759     static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
760     static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
761 #endif
762 }
763 
Start(void)764 Error Client::Start(void)
765 {
766     Error error;
767 
768     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
769     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
770 
771 exit:
772     return error;
773 }
774 
Stop(void)775 void Client::Stop(void)
776 {
777     Query *query;
778 
779     while ((query = mMainQueries.GetHead()) != nullptr)
780     {
781         FinalizeQuery(*query, kErrorAbort);
782     }
783 
784     IgnoreError(mSocket.Close());
785 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
786     if (mTcpState != kTcpUninitialized)
787     {
788         IgnoreError(mEndpoint.Deinitialize());
789     }
790 #endif
791 }
792 
793 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)794 Error Client::InitTcpSocket(void)
795 {
796     Error                       error;
797     otTcpEndpointInitializeArgs endpointArgs;
798 
799     memset(&endpointArgs, 0x00, sizeof(endpointArgs));
800     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
801     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
802     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
803     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
804     endpointArgs.mContext                  = this;
805     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
806     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
807 
808     mSendLink.mNext   = nullptr;
809     mSendLink.mData   = mSendBufferBytes;
810     mSendLink.mLength = 0;
811 
812     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
813 exit:
814     return error;
815 }
816 #endif
817 
SetDefaultConfig(const QueryConfig & aQueryConfig)818 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
819 {
820     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
821 
822     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
823 
824 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
825     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
826     UpdateDefaultConfigAddress();
827 #endif
828 }
829 
ResetDefaultConfig(void)830 void Client::ResetDefaultConfig(void)
831 {
832     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
833 
834 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
835     mUserDidSetDefaultAddress = false;
836     UpdateDefaultConfigAddress();
837 #endif
838 }
839 
840 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)841 void Client::UpdateDefaultConfigAddress(void)
842 {
843     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
844 
845     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
846         !srpServerAddr.IsUnspecified())
847     {
848         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
849     }
850 }
851 #endif
852 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)853 Error Client::ResolveAddress(const char        *aHostName,
854                              AddressCallback    aCallback,
855                              void              *aContext,
856                              const QueryConfig *aConfig)
857 {
858     QueryInfo info;
859 
860     info.Clear();
861     info.mQueryType = kIp6AddressQuery;
862     info.mConfig.SetFrom(aConfig, mDefaultConfig);
863     info.mCallback.mAddressCallback = aCallback;
864     info.mCallbackContext           = aContext;
865 
866     return StartQuery(info, nullptr, aHostName);
867 }
868 
869 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)870 Error Client::ResolveIp4Address(const char        *aHostName,
871                                 AddressCallback    aCallback,
872                                 void              *aContext,
873                                 const QueryConfig *aConfig)
874 {
875     QueryInfo info;
876 
877     info.Clear();
878     info.mQueryType = kIp4AddressQuery;
879     info.mConfig.SetFrom(aConfig, mDefaultConfig);
880     info.mCallback.mAddressCallback = aCallback;
881     info.mCallbackContext           = aContext;
882 
883     return StartQuery(info, nullptr, aHostName);
884 }
885 #endif
886 
887 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
888 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)889 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
890 {
891     QueryInfo info;
892 
893     info.Clear();
894     info.mQueryType = kBrowseQuery;
895     info.mConfig.SetFrom(aConfig, mDefaultConfig);
896     info.mCallback.mBrowseCallback = aCallback;
897     info.mCallbackContext          = aContext;
898 
899     return StartQuery(info, nullptr, aServiceName);
900 }
901 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)902 Error Client::ResolveService(const char        *aInstanceLabel,
903                              const char        *aServiceName,
904                              ServiceCallback    aCallback,
905                              void              *aContext,
906                              const QueryConfig *aConfig)
907 {
908     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
909 }
910 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)911 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
912                                            const char        *aServiceName,
913                                            ServiceCallback    aCallback,
914                                            void              *aContext,
915                                            const QueryConfig *aConfig)
916 {
917     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
918 }
919 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)920 Error Client::Resolve(const char        *aInstanceLabel,
921                       const char        *aServiceName,
922                       ServiceCallback    aCallback,
923                       void              *aContext,
924                       const QueryConfig *aConfig,
925                       bool               aShouldResolveHostAddr)
926 {
927     QueryInfo info;
928     Error     error;
929     QueryType secondQueryType = kNoQuery;
930 
931     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
932 
933     info.Clear();
934 
935     info.mConfig.SetFrom(aConfig, mDefaultConfig);
936     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
937 
938     switch (info.mConfig.GetServiceMode())
939     {
940     case QueryConfig::kServiceModeSrvTxtSeparate:
941         secondQueryType = kServiceQueryTxt;
942 
943         OT_FALL_THROUGH;
944 
945     case QueryConfig::kServiceModeSrv:
946         info.mQueryType = kServiceQuerySrv;
947         break;
948 
949     case QueryConfig::kServiceModeTxt:
950         info.mQueryType = kServiceQueryTxt;
951         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
952         break;
953 
954     case QueryConfig::kServiceModeSrvTxt:
955     case QueryConfig::kServiceModeSrvTxtOptimize:
956     default:
957         info.mQueryType = kServiceQuerySrvTxt;
958         break;
959     }
960 
961     info.mCallback.mServiceCallback = aCallback;
962     info.mCallbackContext           = aContext;
963 
964     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
965 
966 exit:
967     return error;
968 }
969 
970 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
971 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)972 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
973 {
974     // The `aLabel` can be `nullptr` and then `aName` provides the
975     // full name, otherwise the name is appended as `{aLabel}.
976     // {aName}`.
977 
978     Error  error;
979     Query *query;
980 
981     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
982 
983 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
984     if (aInfo.mQueryType == kIp4AddressQuery)
985     {
986         NetworkData::ExternalRouteConfig nat64Prefix;
987 
988         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
989         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
990                      error = kErrorInvalidState);
991     }
992 #endif
993 
994     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
995 
996     mMainQueries.Enqueue(*query);
997 
998     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
999     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1000 
1001     if (aSecondType != kNoQuery)
1002     {
1003         Query *secondQuery;
1004 
1005         aInfo.mQueryType         = aSecondType;
1006         aInfo.mMessageId         = 0;
1007         aInfo.mTransmissionCount = 0;
1008         aInfo.mMainQuery         = query;
1009 
1010         // We intentionally do not use `error` here so in the unlikely
1011         // case where we cannot allocate the second query we can proceed
1012         // with the first one.
1013         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1014 
1015         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1016 
1017         // Update first query to link to second one by updating
1018         // its `mNextQuery`.
1019         aInfo.ReadFrom(*query);
1020         aInfo.mNextQuery = secondQuery;
1021         UpdateQuery(*query, aInfo);
1022     }
1023 
1024 exit:
1025     return error;
1026 }
1027 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1028 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1029 {
1030     Error error = kErrorNone;
1031 
1032     aQuery = nullptr;
1033 
1034     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1035 
1036     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1037     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1038 
1039     SuccessOrExit(error = aQuery->Append(aInfo));
1040 
1041     if (aLabel != nullptr)
1042     {
1043         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1044     }
1045 
1046     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1047 
1048 exit:
1049     FreeAndNullMessageOnError(aQuery, error);
1050     return error;
1051 }
1052 
FindMainQuery(Query & aQuery)1053 Client::Query &Client::FindMainQuery(Query &aQuery)
1054 {
1055     QueryInfo info;
1056 
1057     info.ReadFrom(aQuery);
1058 
1059     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1060 }
1061 
FreeQuery(Query & aQuery)1062 void Client::FreeQuery(Query &aQuery)
1063 {
1064     Query    &mainQuery = FindMainQuery(aQuery);
1065     QueryInfo info;
1066 
1067     mMainQueries.Dequeue(mainQuery);
1068 
1069     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1070     {
1071         info.ReadFrom(*query);
1072         FreeMessage(info.mSavedResponse);
1073         query->Free();
1074     }
1075 }
1076 
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // Note that `aInfo` is written back into `aQuery` and, when
    // `aUpdateTimer` is set, the timer is updated even when an error
    // is returned (both happen after the `exit` label).

    Error            error   = kErrorNone;
    Message         *message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;
    uint16_t         length = 0;

    // Count this attempt and set the next retransmission time to one
    // response timeout from now.
    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // No message ID selected yet: pick a random non-zero ID which
        // is not used by any of the existing queries.
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Re-use the message ID already stored in `aInfo`.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section.

    // One question is appended per entry in `kQuestionRecordTypes`
    // for this query type; each one uses the same name from `aQuery`.
    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    // Number of bytes in `message` past its offset (the header and
    // the questions appended above).
    length = message->GetLength() - message->GetOffset();

    // The braces of the TCP branch below are provided by whichever
    // side of the `#if` is compiled in; the trailing `else` handles
    // all other transports (sent over the UDP socket).
    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // Initialize the endpoint and start connecting. The query
            // is only copied into the send buffer here; it is handed
            // to the endpoint from `HandleTcpEstablished()`.
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Connected with no send in progress: copy the query into
            // the send buffer and start sending it.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Still connecting: only add the query to the send buffer.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send is in progress: write the length-prefixed query
            // after the data already in the send buffer and extend the
            // ongoing send by that many bytes. `mSendLink.mLength` is
            // not modified directly in this case.
            WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }
        // The query bytes were copied into `mSendBufferBytes`, so the
        // message itself is no longer needed.
        message->Free();
        message = nullptr;
    }
#else
    {
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    // Save the updated `aInfo` (transmission count, retransmission
    // time, message ID) back into `aQuery`.
    UpdateQuery(aQuery, aInfo);

    return error;
}
1203 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1204 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1205 {
1206     // The name is encoded and included after the `Info` in `aQuery`
1207     // starting at `kNameOffsetInQuery`.
1208 
1209     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1210 }
1211 
FinalizeQuery(Query & aQuery,Error aError)1212 void Client::FinalizeQuery(Query &aQuery, Error aError)
1213 {
1214     Response response;
1215     Query   &mainQuery = FindMainQuery(aQuery);
1216 
1217     response.mInstance = &Get<Instance>();
1218     response.mQuery    = &mainQuery;
1219 
1220     FinalizeQuery(response, aError);
1221 }
1222 
FinalizeQuery(Response & aResponse,Error aError)1223 void Client::FinalizeQuery(Response &aResponse, Error aError)
1224 {
1225     QueryType type;
1226     Callback  callback;
1227     void     *context;
1228 
1229     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1230 
1231     switch (type)
1232     {
1233     case kIp6AddressQuery:
1234 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1235     case kIp4AddressQuery:
1236 #endif
1237         if (callback.mAddressCallback != nullptr)
1238         {
1239             callback.mAddressCallback(aError, &aResponse, context);
1240         }
1241         break;
1242 
1243 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1244     case kBrowseQuery:
1245         if (callback.mBrowseCallback != nullptr)
1246         {
1247             callback.mBrowseCallback(aError, &aResponse, context);
1248         }
1249         break;
1250 
1251     case kServiceQuerySrvTxt:
1252     case kServiceQuerySrv:
1253     case kServiceQueryTxt:
1254         if (callback.mServiceCallback != nullptr)
1255         {
1256             callback.mServiceCallback(aError, &aResponse, context);
1257         }
1258         break;
1259 #endif
1260     case kNoQuery:
1261         break;
1262     }
1263 
1264     FreeQuery(*aResponse.mQuery);
1265 }
1266 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1267 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1268 {
1269     QueryInfo info;
1270 
1271     info.ReadFrom(aQuery);
1272 
1273     aType     = info.mQueryType;
1274     aCallback = info.mCallback;
1275     aContext  = info.mCallbackContext;
1276 }
1277 
FindQueryById(uint16_t aMessageId)1278 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1279 {
1280     Query    *matchedQuery = nullptr;
1281     QueryInfo info;
1282 
1283     for (Query &mainQuery : mMainQueries)
1284     {
1285         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1286         {
1287             info.ReadFrom(*query);
1288 
1289             if (info.mMessageId == aMessageId)
1290             {
1291                 matchedQuery = query;
1292                 ExitNow();
1293             }
1294         }
1295     }
1296 
1297 exit:
1298     return matchedQuery;
1299 }
1300 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1301 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1302 {
1303     OT_UNUSED_VARIABLE(aMsgInfo);
1304 
1305     static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1306 }
1307 
ProcessResponse(const Message & aResponseMessage)1308 void Client::ProcessResponse(const Message &aResponseMessage)
1309 {
1310     Error  responseError;
1311     Query *query;
1312 
1313     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1314 
1315     if (responseError != kErrorNone)
1316     {
1317         // Received an error from server, check if we can replace
1318         // the query.
1319 
1320 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1321         if (ReplaceWithIp4Query(*query) == kErrorNone)
1322         {
1323             ExitNow();
1324         }
1325 #endif
1326 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1327         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1328         {
1329             ExitNow();
1330         }
1331 #endif
1332 
1333         FinalizeQuery(*query, responseError);
1334         ExitNow();
1335     }
1336 
1337     // Received successful response from server.
1338 
1339 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1340     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1341 #endif
1342 
1343     if (!CanFinalizeQuery(*query))
1344     {
1345         SaveQueryResponse(*query, aResponseMessage);
1346         ExitNow();
1347     }
1348 
1349     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1350 
1351 exit:
1352     return;
1353 }
1354 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1355 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1356 {
1357     Error     error  = kErrorNone;
1358     uint16_t  offset = aResponseMessage.GetOffset();
1359     Header    header;
1360     QueryInfo info;
1361     Name      queryName;
1362 
1363     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1364     offset += sizeof(Header);
1365 
1366     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1367                      !header.IsTruncationFlagSet(),
1368                  error = kErrorDrop);
1369 
1370     aQuery = FindQueryById(header.GetMessageId());
1371     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1372 
1373     info.ReadFrom(*aQuery);
1374 
1375     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1376 
1377     // Check the Question Section
1378 
1379     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1380     {
1381         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1382         {
1383             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1384             offset += sizeof(Question);
1385         }
1386     }
1387     else
1388     {
1389         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1390                      error = kErrorParse);
1391     }
1392 
1393     // Check the answer, authority and additional record sections
1394 
1395     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1396     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1397     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1398 
1399     // Read the response code
1400 
1401     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1402 
1403 exit:
1404     return error;
1405 }
1406 
CanFinalizeQuery(Query & aQuery)1407 bool Client::CanFinalizeQuery(Query &aQuery)
1408 {
1409     // Determines whether we can finalize a main query by checking if
1410     // we have received and saved responses for all other related
1411     // queries associated with `aQuery`. Note that this method is
1412     // called when we receive a response for `aQuery`, so no need to
1413     // check for a saved response for `aQuery` itself.
1414 
1415     bool      canFinalize = true;
1416     QueryInfo info;
1417 
1418     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1419     {
1420         info.ReadFrom(*query);
1421 
1422         if (query == &aQuery)
1423         {
1424             continue;
1425         }
1426 
1427         if (info.mSavedResponse == nullptr)
1428         {
1429             canFinalize = false;
1430             ExitNow();
1431         }
1432     }
1433 
1434 exit:
1435     return canFinalize;
1436 }
1437 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1438 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1439 {
1440     QueryInfo info;
1441 
1442     info.ReadFrom(aQuery);
1443     VerifyOrExit(info.mSavedResponse == nullptr);
1444 
1445     // If `Clone()` fails we let retry or timeout handle the error.
1446     info.mSavedResponse = aResponseMessage.Clone();
1447 
1448     UpdateQuery(aQuery, info);
1449 
1450 exit:
1451     return;
1452 }
1453 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1454 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1455 {
1456     // Populate `aResponse` for `aQuery`. If there is a saved response
1457     // message for `aQuery` we use it, otherwise, we use
1458     // `aResponseMessage`.
1459 
1460     QueryInfo info;
1461 
1462     info.ReadFrom(aQuery);
1463 
1464     aResponse.mInstance = &Get<Instance>();
1465     aResponse.mQuery    = &aQuery;
1466     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1467 
1468     return info.mNextQuery;
1469 }
1470 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1471 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1472 {
1473     // This method prepares a list of chained `Response` instances
1474     // corresponding to all related (chained) queries. It uses
1475     // recursion to go through the queries and construct the
1476     // `Response` chain.
1477 
1478     Response response;
1479     Query   *nextQuery;
1480 
1481     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1482     response.mNext = aPrevResponse;
1483 
1484     if (nextQuery != nullptr)
1485     {
1486         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1487     }
1488     else
1489     {
1490         FinalizeQuery(response, kErrorNone);
1491     }
1492 }
1493 
HandleTimer(void)1494 void Client::HandleTimer(void)
1495 {
1496     TimeMilli now      = TimerMilli::GetNow();
1497     TimeMilli nextTime = now.GetDistantFuture();
1498     QueryInfo info;
1499 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1500     bool hasTcpQuery = false;
1501 #endif
1502 
1503     for (Query &mainQuery : mMainQueries)
1504     {
1505         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1506         {
1507             info.ReadFrom(*query);
1508 
1509             if (info.mSavedResponse != nullptr)
1510             {
1511                 continue;
1512             }
1513 
1514             if (now >= info.mRetransmissionTime)
1515             {
1516                 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1517                 {
1518                     FinalizeQuery(*query, kErrorResponseTimeout);
1519                     break;
1520                 }
1521 
1522                 IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1523             }
1524 
1525             if (nextTime > info.mRetransmissionTime)
1526             {
1527                 nextTime = info.mRetransmissionTime;
1528             }
1529 
1530 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1531             if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1532             {
1533                 hasTcpQuery = true;
1534             }
1535 #endif
1536         }
1537     }
1538 
1539     if (nextTime < now.GetDistantFuture())
1540     {
1541         mTimer.FireAt(nextTime);
1542     }
1543 
1544 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1545     if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1546     {
1547         IgnoreError(mEndpoint.SendEndOfStream());
1548     }
1549 #endif
1550 }
1551 
1552 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1553 
ReplaceWithIp4Query(Query & aQuery)1554 Error Client::ReplaceWithIp4Query(Query &aQuery)
1555 {
1556     Error     error = kErrorFailed;
1557     QueryInfo info;
1558 
1559     info.ReadFrom(aQuery);
1560 
1561     VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1562     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1563 
1564     // We send a new query for IPv4 address resolution
1565     // for the same host name. We reuse the existing `aQuery`
1566     // instance and keep all the info but clear `mTransmissionCount`
1567     // and `mMessageId` (so that a new random message ID is
1568     // selected). The new `info` will be saved in the query in
1569     // `SendQuery()`. Note that the current query is still in the
1570     // `mMainQueries` list when `SendQuery()` selects a new random
1571     // message ID, so the existing message ID for this query will
1572     // not be reused.
1573 
1574     info.mQueryType         = kIp4AddressQuery;
1575     info.mMessageId         = 0;
1576     info.mTransmissionCount = 0;
1577 
1578     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1579     error = kErrorNone;
1580 
1581 exit:
1582     return error;
1583 }
1584 
1585 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1586 
1587 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1588 
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1589 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1590 {
1591     Error     error = kErrorFailed;
1592     QueryInfo info;
1593     Query    *secondQuery;
1594 
1595     info.ReadFrom(aQuery);
1596 
1597     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1598     VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1599 
1600     secondQuery = aQuery.Clone();
1601     VerifyOrExit(secondQuery != nullptr);
1602 
1603     info.mQueryType         = kServiceQueryTxt;
1604     info.mMessageId         = 0;
1605     info.mTransmissionCount = 0;
1606     info.mMainQuery         = &aQuery;
1607     IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1608 
1609     info.mQueryType         = kServiceQuerySrv;
1610     info.mMessageId         = 0;
1611     info.mTransmissionCount = 0;
1612     info.mNextQuery         = secondQuery;
1613     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1614     error = kErrorNone;
1615 
1616 exit:
1617     return error;
1618 }
1619 
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1620 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1621 {
1622     QueryInfo   info;
1623     Response    response;
1624     ServiceInfo serviceInfo;
1625     char        hostName[Name::kMaxNameSize];
1626 
1627     info.ReadFrom(aQuery);
1628 
1629     VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1630     VerifyOrExit(info.mShouldResolveHostAddr);
1631 
1632     PopulateResponse(response, aQuery, aResponseMessage);
1633 
1634     memset(&serviceInfo, 0, sizeof(serviceInfo));
1635     serviceInfo.mHostNameBuffer     = hostName;
1636     serviceInfo.mHostNameBufferSize = sizeof(hostName);
1637     SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1638 
1639     // Check whether AAAA record for host address is provided in the SRV query response
1640 
1641     if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1642     {
1643         Query *newQuery;
1644 
1645         info.mQueryType         = kIp6AddressQuery;
1646         info.mMessageId         = 0;
1647         info.mTransmissionCount = 0;
1648         info.mMainQuery         = &FindMainQuery(aQuery);
1649 
1650         SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1651         IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1652 
1653         // Update `aQuery` to be linked with new query (inserting
1654         // the `newQuery` into the linked-list after `aQuery`).
1655 
1656         info.ReadFrom(aQuery);
1657         info.mNextQuery = newQuery;
1658         UpdateQuery(aQuery, info);
1659     }
1660 
1661 exit:
1662     return;
1663 }
1664 
1665 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1666 
1667 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1668 void Client::PrepareTcpMessage(Message &aMessage)
1669 {
1670     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1671 
1672     // Prepending the DNS query with length of the packet according to RFC1035.
1673     WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1674     SuccessOrAssert(
1675         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1676     mSendLink.mLength += length + sizeof(uint16_t);
1677 }
1678 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1679 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1680 {
1681     OT_UNUSED_VARIABLE(aEndpoint);
1682     OT_UNUSED_VARIABLE(aData);
1683     OT_ASSERT(mTcpState == kTcpConnectedSending);
1684 
1685     mSendLink.mLength = 0;
1686     mTcpState         = kTcpConnectedIdle;
1687 }
1688 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1689 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1690 {
1691     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1692 }
1693 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1694 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1695 {
1696     OT_UNUSED_VARIABLE(aEndpoint);
1697     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1698     mTcpState = kTcpConnectedSending;
1699 }
1700 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1701 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1702 {
1703     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1704 }
1705 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1706 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1707                                  size_t                &aOffset,
1708                                  Message               &aMessage,
1709                                  uint16_t               aLength)
1710 {
1711     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1712     // and copy the content into `aMessage`. As we read we can move
1713     // to the next `aLinkedBuffer` and update `aOffset`.
1714     // Returns:
1715     // - `kErrorNone` if `aLength` bytes are successfully read and
1716     //    `aOffset` and `aLinkedBuffer` are updated.
1717     // - `kErrorNotFound` is not enough bytes available to read
1718     //    from `aLinkedBuffer`.
1719     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1720 
1721     Error error = kErrorNone;
1722 
1723     while (aLength > 0)
1724     {
1725         uint16_t bytesToRead = aLength;
1726 
1727         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1728 
1729         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1730         {
1731             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1732         }
1733 
1734         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1735 
1736         aLength -= bytesToRead;
1737         aOffset += bytesToRead;
1738 
1739         if (aOffset == aLinkedBuffer->mLength)
1740         {
1741             aLinkedBuffer = aLinkedBuffer->mNext;
1742             aOffset       = 0;
1743         }
1744     }
1745 
1746 exit:
1747     return error;
1748 }
1749 
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1750 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1751                                        size_t         aBytesAvailable,
1752                                        bool           aEndOfStream,
1753                                        size_t         aBytesRemaining)
1754 {
1755     OT_UNUSED_VARIABLE(aEndpoint);
1756     OT_UNUSED_VARIABLE(aBytesRemaining);
1757 
1758     Message              *message   = nullptr;
1759     size_t                totalRead = 0;
1760     size_t                offset    = 0;
1761     const otLinkedBuffer *data;
1762 
1763     if (aEndOfStream)
1764     {
1765         // Cleanup is done in disconnected callback.
1766         IgnoreError(mEndpoint.SendEndOfStream());
1767     }
1768 
1769     SuccessOrExit(mEndpoint.ReceiveByReference(data));
1770     VerifyOrExit(data != nullptr);
1771 
1772     message = mSocket.NewMessage();
1773     VerifyOrExit(message != nullptr);
1774 
1775     while (aBytesAvailable > totalRead)
1776     {
1777         uint16_t length;
1778 
1779         // Read the `length` field.
1780         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1781 
1782         IgnoreError(message->Read(/* aOffset */ 0, length));
1783         length = HostSwap16(length);
1784 
1785         // Try to read `length` bytes.
1786         IgnoreError(message->SetLength(0));
1787         SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1788 
1789         totalRead += length + sizeof(uint16_t);
1790 
1791         // Now process the read message as query response.
1792         ProcessResponse(*message);
1793 
1794         IgnoreError(message->SetLength(0));
1795 
1796         // Loop again to see if we can read another response.
1797     }
1798 
1799 exit:
1800     // Inform `mEndPoint` about the total read and processed bytes
1801     IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1802     FreeMessage(message);
1803 }
1804 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1805 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1806                                                size_t         aBytesAvailable,
1807                                                bool           aEndOfStream,
1808                                                size_t         aBytesRemaining)
1809 {
1810     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1811         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1812 }
1813 
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1814 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1815 {
1816     OT_UNUSED_VARIABLE(aEndpoint);
1817     OT_UNUSED_VARIABLE(aReason);
1818     QueryInfo info;
1819 
1820     IgnoreError(mEndpoint.Deinitialize());
1821     mTcpState = kTcpUninitialized;
1822 
1823     // Abort queries in case of connection failures
1824     for (Query &mainQuery : mMainQueries)
1825     {
1826         info.ReadFrom(mainQuery);
1827 
1828         if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1829         {
1830             FinalizeQuery(mainQuery, kErrorAbort);
1831         }
1832     }
1833 }
1834 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1835 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1836 {
1837     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1838 }
1839 
1840 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1841 
1842 } // namespace Dns
1843 } // namespace ot
1844 
1845 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1846