1 /*
2  *  Copyright (c) 2017-2021, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "dns_client.hpp"
30 
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32 
33 #include "instance/instance.hpp"
34 #include "utils/static_counter.hpp"
35 
36 /**
37  * @file
38  *   This file implements the DNS client.
39  */
40 
41 namespace ot {
42 namespace Dns {
43 
// Log module registration for this file; log output from this file uses the module name "DnsClient".
RegisterLogModule("DnsClient");
45 
46 //---------------------------------------------------------------------------------------------------------------------
47 // Client::QueryConfig
48 
// Default DNS server IPv6 address (as a string) taken from the build-time config
// `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS`. It is parsed by the
// `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
50 
QueryConfig(InitMode aMode)51 Client::QueryConfig::QueryConfig(InitMode aMode)
52 {
53     OT_UNUSED_VARIABLE(aMode);
54 
55     IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
56     GetServerSockAddr().SetPort(kDefaultServerPort);
57     SetResponseTimeout(kDefaultResponseTimeout);
58     SetMaxTxAttempts(kDefaultMaxTxAttempts);
59     SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
60     SetServiceMode(kDefaultServiceMode);
61 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
62     SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
63 #endif
64     SetTransportProto(kDnsTransportUdp);
65 }
66 
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)67 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
68 {
69     // This method sets the config from `aConfig` replacing any
70     // unspecified fields (value zero) with the fields from
71     // `aDefaultConfig`. If `aConfig` is `nullptr` then
72     // `aDefaultConfig` is used.
73 
74     if (aConfig == nullptr)
75     {
76         *this = aDefaultConfig;
77         ExitNow();
78     }
79 
80     *this = *aConfig;
81 
82     if (GetServerSockAddr().GetAddress().IsUnspecified())
83     {
84         GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
85     }
86 
87     if (GetServerSockAddr().GetPort() == 0)
88     {
89         GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
90     }
91 
92     if (GetResponseTimeout() == 0)
93     {
94         SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
95     }
96 
97     if (GetMaxTxAttempts() == 0)
98     {
99         SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
100     }
101 
102     if (GetRecursionFlag() == kFlagUnspecified)
103     {
104         SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
105     }
106 
107 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
108     if (GetNat64Mode() == kNat64Unspecified)
109     {
110         SetNat64Mode(aDefaultConfig.GetNat64Mode());
111     }
112 #endif
113 
114     if (GetServiceMode() == kServiceModeUnspecified)
115     {
116         SetServiceMode(aDefaultConfig.GetServiceMode());
117     }
118 
119     if (GetTransportProto() == kDnsTransportUnspecified)
120     {
121         SetTransportProto(aDefaultConfig.GetTransportProto());
122     }
123 
124 exit:
125     return;
126 }
127 
128 //---------------------------------------------------------------------------------------------------------------------
129 // Client::Response
130 
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const131 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
132 {
133     switch (aSection)
134     {
135     case kAnswerSection:
136         aOffset    = mAnswerOffset;
137         aNumRecord = mAnswerRecordCount;
138         break;
139     case kAdditionalDataSection:
140     default:
141         aOffset    = mAdditionalOffset;
142         aNumRecord = mAdditionalRecordCount;
143         break;
144     }
145 }
146 
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const147 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
148 {
149     uint16_t offset = kNameOffsetInQuery;
150 
151     return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
152 }
153 
CheckForHostNameAlias(Section aSection,Name & aHostName) const154 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
155 {
156     // If the response includes a CNAME record mapping the query host
157     // name to a canonical name, we update `aHostName` to the new alias
158     // name. Otherwise `aHostName` remains as before. This method handles
159     // when there are multiple CNAME records mapping the host name multiple
160     // times. We limit number of changes to `kMaxCnameAliasNameChanges`
161     // to detect and handle if the response contains CNAME record loops.
162 
163     Error       error;
164     uint16_t    offset;
165     uint16_t    numRecords;
166     CnameRecord cnameRecord;
167 
168     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
169 
170     for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
171     {
172         SelectSection(aSection, offset, numRecords);
173         error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
174 
175         if (error == kErrorNotFound)
176         {
177             error = kErrorNone;
178             ExitNow();
179         }
180 
181         SuccessOrExit(error);
182 
183         // A CNAME record was found. `offset` now points to after the
184         // last read byte within the `mMessage` into the `cnameRecord`
185         // (which is the start of the new canonical name).
186 
187         aHostName.SetFromMessage(*mMessage, offset);
188         SuccessOrExit(error = Name::ParseName(*mMessage, offset));
189 
190         // Loop back to check if there may be a CNAME record for the
191         // new `aHostName`.
192     }
193 
194     error = kErrorParse;
195 
196 exit:
197     return error;
198 }
199 
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const200 Error Client::Response::FindHostAddress(Section       aSection,
201                                         const Name   &aHostName,
202                                         uint16_t      aIndex,
203                                         Ip6::Address &aAddress,
204                                         uint32_t     &aTtl) const
205 {
206     Error      error;
207     uint16_t   offset;
208     uint16_t   numRecords;
209     Name       name = aHostName;
210     AaaaRecord aaaaRecord;
211 
212     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
213 
214     SelectSection(aSection, offset, numRecords);
215     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
216     aAddress = aaaaRecord.GetAddress();
217     aTtl     = aaaaRecord.GetTtl();
218 
219 exit:
220     return error;
221 }
222 
223 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
224 
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const225 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
226 {
227     Error    error;
228     uint16_t offset;
229     uint16_t numRecords;
230     Name     name = aHostName;
231 
232     SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
233 
234     SelectSection(aSection, offset, numRecords);
235     error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
236 
237 exit:
238     return error;
239 }
240 
241 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
242 
243 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
244 
InitServiceInfo(ServiceInfo & aServiceInfo) const245 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
246 {
247     // This method initializes `aServiceInfo` setting all
248     // TTLs to zero and host name to empty string.
249 
250     aServiceInfo.mTtl              = 0;
251     aServiceInfo.mHostAddressTtl   = 0;
252     aServiceInfo.mTxtDataTtl       = 0;
253     aServiceInfo.mTxtDataTruncated = false;
254 
255     AsCoreType(&aServiceInfo.mHostAddress).Clear();
256 
257     if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
258     {
259         aServiceInfo.mHostNameBuffer[0] = '\0';
260     }
261 }
262 
Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches for SRV record in the given `aSection`
    // matching the record name against `aName`, and updates the
    // `aServiceInfo` accordingly. It also searches for AAAA record
    // for host name associated with the service (from SRV record).
    // The search for AAAA record is always performed in Additional
    // Data section (independent of the value given in `aSection`).
    //
    // Return values:
    //  - `kErrorNone` if `aServiceInfo.mTtl` is already non-zero (nothing
    //    is read), or if the SRV record was found and parsed. A missing
    //    AAAA record for the host is not treated as an error, and the AAAA
    //    search is skipped if `mHostAddress` is already set.
    //  - `kErrorNotFound` if there is no response message or no matching
    //    SRV record in `aSection`.
    //  - Any other error from `FindRecord()`, from reading/parsing the
    //    target host name, or from `FindHostAddress()`.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    Name      hostName;
    SrvRecord srvRecord;

    // A non-zero `mTtl` indicates that SRV record is already found
    // and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTtl == 0);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    // Search for a matching SRV record
    SelectSection(aSection, offset, numRecords);
    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));

    aServiceInfo.mTtl      = srvRecord.GetTtl();
    aServiceInfo.mPort     = srvRecord.GetPort();
    aServiceInfo.mPriority = srvRecord.GetPriority();
    aServiceInfo.mWeight   = srvRecord.GetWeight();

    // `offset` now points to the SRV target host name inside `mMessage`.
    // Remember it before `offset` is advanced past the name.
    hostName.SetFromMessage(*mMessage, offset);

    if (aServiceInfo.mHostNameBuffer != nullptr)
    {
        SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
                                                           aServiceInfo.mHostNameBufferSize));
    }
    else
    {
        // Caller did not ask for the host name string; just validate and
        // skip over the encoded name.
        SuccessOrExit(error = Name::ParseName(*mMessage, offset));
    }

    // Search in additional section for AAAA record for the host name.

    VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());

    error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
                            aServiceInfo.mHostAddressTtl);

    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

exit:
    return error;
}
320 
Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches a TXT record in the given `aSection`
    // matching the record name against `aName` and updates the TXT
    // related properties in `aServiceInfo`.
    //
    // If no match is found `mTxtDataTtl` (which is initialized to zero)
    // remains unchanged to indicate this. In this case this method still
    // returns `kErrorNone`. A missing response message (`mMessage` is
    // null) is handled the same way: `kErrorNotFound` is converted to
    // `kErrorNone` at `exit`.
    //
    // If the TXT data does not fit in `mTxtData`, `mTxtDataTruncated` is
    // set to `true` and the method still succeeds.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    TxtRecord txtRecord;

    // A non-zero `mTxtDataTtl` indicates that TXT record is already
    // found and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);

    // A null `mTxtData` indicates that caller does not want to retrieve
    // TXT data.
    VerifyOrExit(aServiceInfo.mTxtData != nullptr);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(aSection, offset, numRecords);

    aServiceInfo.mTxtDataTruncated = false;

    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));

    error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);

    if (error == kErrorNoBufs)
    {
        error = kErrorNone;

        // Mark `mTxtDataTruncated` to indicate that we could not read
        // the full TXT record into the given `mTxtData` buffer.
        aServiceInfo.mTxtDataTruncated = true;
    }

    SuccessOrExit(error);
    aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();

exit:
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

    return error;
}
374 
375 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
376 
void Client::Response::PopulateFrom(const Message &aMessage)
{
    // Populate `Response` with info from `aMessage`.
    //
    // Walks the message once (header, questions, then the answer,
    // authority and additional record sections) and records the start
    // offset and record count of the Answer and Additional Data sections.
    // The Authority section is only skipped over. Read/parse errors are
    // ignored here. NOTE(review): this presumably relies on `aMessage`
    // having been validated before this is called — confirm at call site.

    uint16_t offset = aMessage.GetOffset();
    Header   header;

    mMessage = &aMessage;

    IgnoreError(aMessage.Read(offset, header));
    offset += sizeof(Header);

    // Skip each question: its encoded name followed by the fixed-size
    // `Question` part.
    for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
    {
        IgnoreError(Name::ParseName(aMessage, offset));
        offset += sizeof(Question);
    }

    mAnswerOffset = offset;
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
    mAdditionalOffset = offset;
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));

    mAnswerRecordCount     = header.GetAnswerCount();
    mAdditionalRecordCount = header.GetAdditionalRecordCount();
}
404 
405 //---------------------------------------------------------------------------------------------------------------------
406 // Client::AddressResponse
407 
Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
{
    // Returns the `aIndex`-th IPv6 address (and its TTL) for the queried
    // host name. Normally this is read from an AAAA record in the Answer
    // section. When NAT64 support is enabled and the response is for an IPv4
    // query (or needs NAT64 as flagged by `mIp6QueryResponseRequiresNat64`),
    // an `A` record is read instead and synthesized into an IPv6 address
    // using the preferred NAT64 prefix from Network Data. If no such prefix
    // is available, `kErrorInvalidState` is returned.

    Error error = kErrorNone;
    Name  name(*mQuery, kNameOffsetInQuery);

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE

    // If the response is for an IPv4 address query or if it is an
    // IPv6 address query response with no IPv6 address but with
    // an IPv4 in its additional section, we read the IPv4 address
    // and translate it to an IPv6 address.

    QueryInfo info;

    info.ReadFrom(*mQuery);

    if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
    {
        Section                          section;
        ARecord                          aRecord;
        NetworkData::ExternalRouteConfig nat64Prefix;

        VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
                     error = kErrorInvalidState);

        // IPv4 queries carry `A` records in the Answer section; for an IPv6
        // query needing NAT64 they are looked up in Additional Data.
        section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
        SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));

        aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
        aTtl = aRecord.GetTtl();

        ExitNow();
    }

#endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE

    ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));

exit:
    return error;
}
449 
450 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
451 
452 //---------------------------------------------------------------------------------------------------------------------
453 // Client::BrowseResponse
454 
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const455 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
456 {
457     Error     error;
458     uint16_t  offset;
459     uint16_t  numRecords;
460     Name      serviceName(*mQuery, kNameOffsetInQuery);
461     PtrRecord ptrRecord;
462 
463     VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
464 
465     SelectSection(kAnswerSection, offset, numRecords);
466     SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
467     error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
468 
469 exit:
470     return error;
471 }
472 
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const473 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
474 {
475     Error error;
476     Name  instanceName;
477 
478     // Find a matching PTR record for the service instance label. Then
479     // search and read SRV, TXT and AAAA records in Additional Data
480     // section matching the same name to populate `aServiceInfo`.
481 
482     SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
483 
484     InitServiceInfo(aServiceInfo);
485     SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
486     SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
487 
488     if (aServiceInfo.mTxtDataTtl == 0)
489     {
490         aServiceInfo.mTxtDataSize = 0;
491     }
492 
493 exit:
494     return error;
495 }
496 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const497 Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
498                                              uint16_t      aIndex,
499                                              Ip6::Address &aAddress,
500                                              uint32_t     &aTtl) const
501 {
502     return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
503 }
504 
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // matching a given instance label `aInstanceLabel`. If found,
    // `aInstanceName` is updated to refer to the PTR target name within
    // the message (whose first label is the instance label).
    //
    // Returns `kErrorNotFound` if there is no response message or no PTR
    // record matches. If a record in the Answer section has an owner name
    // different from the browsed service name, the error from
    // `Name::CompareName()` is returned and the search stops.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A copy of `offset` is used so that `offset`
        // still points to the start of the PTR target name.

        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record: advance by the part of
        // the record not already consumed by `ReadRecord()`.
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
560 
561 //---------------------------------------------------------------------------------------------------------------------
562 // Client::ServiceResponse
563 
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const564 Error Client::ServiceResponse::GetServiceName(char    *aLabelBuffer,
565                                               uint8_t  aLabelBufferSize,
566                                               char    *aNameBuffer,
567                                               uint16_t aNameBufferSize) const
568 {
569     Error    error;
570     uint16_t offset = kNameOffsetInQuery;
571 
572     SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
573 
574     VerifyOrExit(aNameBuffer != nullptr);
575     SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
576 
577 exit:
578     return error;
579 }
580 
Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
{
    // Search and read SRV, TXT records matching name from query.
    //
    // Walks this response and every response chained after it via `mNext`:
    //  - Address-query responses only contribute the host address (errors
    //    from that lookup are ignored).
    //  - Service-query responses (SRV+TXT, SRV-only, TXT-only) contribute
    //    SRV and TXT info. The first failing read aborts and its error is
    //    returned.
    //  - Other query types are skipped.
    // If no service-query response is in the chain, `kErrorNotFound` is
    // returned. `mTxtDataSize` is set to zero when no TXT record was read.

    Error error = kErrorNotFound;

    InitServiceInfo(aServiceInfo);

    for (const Response *response = this; response != nullptr; response = response->mNext)
    {
        Name      name(*response->mQuery, kNameOffsetInQuery);
        QueryInfo info;
        Section   srvSection;
        Section   txtSection;

        info.ReadFrom(*response->mQuery);

        switch (info.mQueryType)
        {
        case kIp6AddressQuery:
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        case kIp4AddressQuery:
#endif
            IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
                                                  AsCoreType(&aServiceInfo.mHostAddress),
                                                  aServiceInfo.mHostAddressTtl));

            continue; // to `for()` loop

        case kServiceQuerySrvTxt:
        case kServiceQuerySrv:
        case kServiceQueryTxt:
            break;

        default:
            continue;
        }

        // Determine from which section we should try to read the SRV and
        // TXT records based on the query type.
        //
        // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
        // only one record (SRV or TXT) in the answer section, but we
        // still try to read the other records from additional data
        // section in case server provided them.

        srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
        txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;

        error = response->ReadServiceInfo(srvSection, name, aServiceInfo);

        // A missing SRV record is only an error when SRV was actually
        // queried (i.e. expected in the Answer section).
        if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
        {
            error = kErrorNone;
        }

        SuccessOrExit(error);

        SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
    }

    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}
650 
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const651 Error Client::ServiceResponse::GetHostAddress(const char   *aHostName,
652                                               uint16_t      aIndex,
653                                               Ip6::Address &aAddress,
654                                               uint32_t     &aTtl) const
655 {
656     Error error = kErrorNotFound;
657 
658     for (const Response *response = this; response != nullptr; response = response->mNext)
659     {
660         Section   section = kAdditionalDataSection;
661         QueryInfo info;
662 
663         info.ReadFrom(*response->mQuery);
664 
665         switch (info.mQueryType)
666         {
667         case kIp6AddressQuery:
668 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
669         case kIp4AddressQuery:
670 #endif
671             section = kAnswerSection;
672             break;
673 
674         default:
675             break;
676         }
677 
678         error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
679 
680         if (error == kErrorNone)
681         {
682             break;
683         }
684     }
685 
686     return error;
687 }
688 
689 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
690 
691 //---------------------------------------------------------------------------------------------------------------------
692 // Client
693 
// Record types asked for by each query type.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query type. Entries are indexed by query type, so
// their order must match the query-type enumerators (the same order is
// listed in the `QueryTypeChecker` in the `Client` constructor).
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
    /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
#endif
};

// Record-type array per query type, indexed like `kQuestionCount`. The
// SRV-only and TXT-only queries point at a single element inside
// `kServiceQueryRecordTypes` (paired with a question count of 1 above).
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery  -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],

#endif
};
729 
// Constructs the DNS client with the compile-time default query configuration
// (`QueryConfig::kInitFromDefaults`). The UDP socket is not opened here;
// that happens in `Start()`.
Client::Client(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mSocket(aInstance, *this)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    , mTcpState(kTcpUninitialized)
#endif
    , mTimer(aInstance)
    , mDefaultConfig(QueryConfig::kInitFromDefaults)
#if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
    , mUserDidSetDefaultAddress(false)
#endif
{
    // `QueryTypeChecker` is never instantiated; it only exists so that the
    // `ValidateNextEnum()` checks are compiled. NOTE(review): these presumably
    // verify at compile time that the query-type enumerators are sequential
    // in the listed order, which the per-query-type tables
    // (`kQuestionCount`, `kQuestionRecordTypes`) rely on for indexing —
    // confirm against `utils/static_counter.hpp`.
    struct QueryTypeChecker
    {
        InitEnumValidatorCounter();

        ValidateNextEnum(kIp6AddressQuery);
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        ValidateNextEnum(kIp4AddressQuery);
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
        ValidateNextEnum(kBrowseQuery);
        ValidateNextEnum(kServiceQuerySrvTxt);
        ValidateNextEnum(kServiceQuerySrv);
        ValidateNextEnum(kServiceQueryTxt);
#endif
    };

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    // Start with an all-zero TCP send link; `InitTcpSocket()` sets it up.
    ClearAllBytes(mSendLink);
#endif
}
762 
Start(void)763 Error Client::Start(void)
764 {
765     Error error;
766 
767     SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified));
768     SuccessOrExit(error = mSocket.Bind(0));
769 
770 exit:
771     return error;
772 }
773 
Stop(void)774 void Client::Stop(void)
775 {
776     Query *query;
777 
778     while ((query = mMainQueries.GetHead()) != nullptr)
779     {
780         FinalizeQuery(*query, kErrorAbort);
781     }
782 
783     IgnoreError(mSocket.Close());
784 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
785     if (mTcpState != kTcpUninitialized)
786     {
787         IgnoreError(mEndpoint.Deinitialize());
788     }
789 #endif
790 
791     mLimitedQueryServers.Clear();
792 }
793 
794 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)795 Error Client::InitTcpSocket(void)
796 {
797     Error                       error;
798     otTcpEndpointInitializeArgs endpointArgs;
799 
800     ClearAllBytes(endpointArgs);
801     endpointArgs.mSendDoneCallback         = HandleTcpSendDoneCallback;
802     endpointArgs.mEstablishedCallback      = HandleTcpEstablishedCallback;
803     endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
804     endpointArgs.mDisconnectedCallback     = HandleTcpDisconnectedCallback;
805     endpointArgs.mContext                  = this;
806     endpointArgs.mReceiveBuffer            = mReceiveBufferBytes;
807     endpointArgs.mReceiveBufferSize        = sizeof(mReceiveBufferBytes);
808 
809     mSendLink.mNext   = nullptr;
810     mSendLink.mData   = mSendBufferBytes;
811     mSendLink.mLength = 0;
812 
813     SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
814 exit:
815     return error;
816 }
817 #endif
818 
SetDefaultConfig(const QueryConfig & aQueryConfig)819 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
820 {
821     QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
822 
823     mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
824 
825 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
826     mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
827     UpdateDefaultConfigAddress();
828 #endif
829 }
830 
ResetDefaultConfig(void)831 void Client::ResetDefaultConfig(void)
832 {
833     mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
834 
835 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
836     mUserDidSetDefaultAddress = false;
837     UpdateDefaultConfigAddress();
838 #endif
839 }
840 
841 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)842 void Client::UpdateDefaultConfigAddress(void)
843 {
844     const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
845 
846     if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
847         !srpServerAddr.IsUnspecified())
848     {
849         mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
850     }
851 }
852 #endif
853 
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)854 Error Client::ResolveAddress(const char        *aHostName,
855                              AddressCallback    aCallback,
856                              void              *aContext,
857                              const QueryConfig *aConfig)
858 {
859     QueryInfo info;
860 
861     info.Clear();
862     info.mQueryType = kIp6AddressQuery;
863     info.mConfig.SetFrom(aConfig, mDefaultConfig);
864     info.mCallback.mAddressCallback = aCallback;
865     info.mCallbackContext           = aContext;
866 
867     return StartQuery(info, nullptr, aHostName);
868 }
869 
870 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)871 Error Client::ResolveIp4Address(const char        *aHostName,
872                                 AddressCallback    aCallback,
873                                 void              *aContext,
874                                 const QueryConfig *aConfig)
875 {
876     QueryInfo info;
877 
878     info.Clear();
879     info.mQueryType = kIp4AddressQuery;
880     info.mConfig.SetFrom(aConfig, mDefaultConfig);
881     info.mCallback.mAddressCallback = aCallback;
882     info.mCallbackContext           = aContext;
883 
884     return StartQuery(info, nullptr, aHostName);
885 }
886 #endif
887 
888 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
889 
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)890 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
891 {
892     QueryInfo info;
893 
894     info.Clear();
895     info.mQueryType = kBrowseQuery;
896     info.mConfig.SetFrom(aConfig, mDefaultConfig);
897     info.mCallback.mBrowseCallback = aCallback;
898     info.mCallbackContext          = aContext;
899 
900     return StartQuery(info, nullptr, aServiceName);
901 }
902 
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)903 Error Client::ResolveService(const char        *aInstanceLabel,
904                              const char        *aServiceName,
905                              ServiceCallback    aCallback,
906                              void              *aContext,
907                              const QueryConfig *aConfig)
908 {
909     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
910 }
911 
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)912 Error Client::ResolveServiceAndHostAddress(const char        *aInstanceLabel,
913                                            const char        *aServiceName,
914                                            ServiceCallback    aCallback,
915                                            void              *aContext,
916                                            const QueryConfig *aConfig)
917 {
918     return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
919 }
920 
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)921 Error Client::Resolve(const char        *aInstanceLabel,
922                       const char        *aServiceName,
923                       ServiceCallback    aCallback,
924                       void              *aContext,
925                       const QueryConfig *aConfig,
926                       bool               aShouldResolveHostAddr)
927 {
928     QueryInfo info;
929     Error     error;
930     QueryType secondQueryType = kNoQuery;
931 
932     VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
933 
934     info.Clear();
935 
936     info.mConfig.SetFrom(aConfig, mDefaultConfig);
937     info.mShouldResolveHostAddr = aShouldResolveHostAddr;
938 
939     CheckAndUpdateServiceMode(info.mConfig, aConfig);
940 
941     switch (info.mConfig.GetServiceMode())
942     {
943     case QueryConfig::kServiceModeSrvTxtSeparate:
944         secondQueryType = kServiceQueryTxt;
945 
946         OT_FALL_THROUGH;
947 
948     case QueryConfig::kServiceModeSrv:
949         info.mQueryType = kServiceQuerySrv;
950         break;
951 
952     case QueryConfig::kServiceModeTxt:
953         info.mQueryType = kServiceQueryTxt;
954         VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
955         break;
956 
957     case QueryConfig::kServiceModeSrvTxt:
958     case QueryConfig::kServiceModeSrvTxtOptimize:
959     default:
960         info.mQueryType = kServiceQuerySrvTxt;
961         break;
962     }
963 
964     info.mCallback.mServiceCallback = aCallback;
965     info.mCallbackContext           = aContext;
966 
967     error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
968 
969 exit:
970     return error;
971 }
972 
973 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
974 
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)975 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
976 {
977     // The `aLabel` can be `nullptr` and then `aName` provides the
978     // full name, otherwise the name is appended as `{aLabel}.
979     // {aName}`.
980 
981     Error  error;
982     Query *query;
983 
984     VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
985 
986 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
987     if (aInfo.mQueryType == kIp4AddressQuery)
988     {
989         NetworkData::ExternalRouteConfig nat64Prefix;
990 
991         VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
992         VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
993                      error = kErrorInvalidState);
994     }
995 #endif
996 
997     SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
998 
999     mMainQueries.Enqueue(*query);
1000 
1001     error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1002     VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1003 
1004     if (aSecondType != kNoQuery)
1005     {
1006         Query *secondQuery;
1007 
1008         aInfo.mQueryType         = aSecondType;
1009         aInfo.mMessageId         = 0;
1010         aInfo.mTransmissionCount = 0;
1011         aInfo.mMainQuery         = query;
1012 
1013         // We intentionally do not use `error` here so in the unlikely
1014         // case where we cannot allocate the second query we can proceed
1015         // with the first one.
1016         SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1017 
1018         IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1019 
1020         // Update first query to link to second one by updating
1021         // its `mNextQuery`.
1022         aInfo.ReadFrom(*query);
1023         aInfo.mNextQuery = secondQuery;
1024         UpdateQuery(*query, aInfo);
1025     }
1026 
1027 exit:
1028     return error;
1029 }
1030 
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1031 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1032 {
1033     Error error = kErrorNone;
1034 
1035     aQuery = nullptr;
1036 
1037     VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1038 
1039     aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1040     VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1041 
1042     SuccessOrExit(error = aQuery->Append(aInfo));
1043 
1044     if (aLabel != nullptr)
1045     {
1046         SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1047     }
1048 
1049     SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1050 
1051 exit:
1052     FreeAndNullMessageOnError(aQuery, error);
1053     return error;
1054 }
1055 
FindMainQuery(Query & aQuery)1056 Client::Query &Client::FindMainQuery(Query &aQuery)
1057 {
1058     QueryInfo info;
1059 
1060     info.ReadFrom(aQuery);
1061 
1062     return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1063 }
1064 
FreeQuery(Query & aQuery)1065 void Client::FreeQuery(Query &aQuery)
1066 {
1067     Query    &mainQuery = FindMainQuery(aQuery);
1068     QueryInfo info;
1069 
1070     mMainQueries.Dequeue(mainQuery);
1071 
1072     for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1073     {
1074         info.ReadFrom(*query);
1075         FreeMessage(info.mSavedResponse);
1076         query->Free();
1077     }
1078 }
1079 
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)1080 Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
1081 {
1082     // This method prepares and sends a query message represented by
1083     // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
1084     // the new `mRetransmissionTime`) and updates it in `aQuery` as
1085     // well. `aUpdateTimer` indicates whether the timer should be
1086     // updated when query is sent or not (used in the case where timer
1087     // is handled by caller).
1088 
1089     Error            error   = kErrorNone;
1090     Message         *message = nullptr;
1091     Header           header;
1092     Ip6::MessageInfo messageInfo;
1093     uint16_t         length = 0;
1094 
1095     aInfo.mTransmissionCount++;
1096     aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
1097 
1098     if (aInfo.mMessageId == 0)
1099     {
1100         do
1101         {
1102             SuccessOrExit(error = header.SetRandomMessageId());
1103         } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
1104 
1105         aInfo.mMessageId = header.GetMessageId();
1106     }
1107     else
1108     {
1109         header.SetMessageId(aInfo.mMessageId);
1110     }
1111 
1112     header.SetType(Header::kTypeQuery);
1113     header.SetQueryType(Header::kQueryTypeStandard);
1114 
1115     if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
1116     {
1117         header.SetRecursionDesiredFlag();
1118     }
1119 
1120     header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
1121 
1122     message = mSocket.NewMessage();
1123     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
1124 
1125     SuccessOrExit(error = message->Append(header));
1126 
1127     // Prepare the question section.
1128 
1129     for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
1130     {
1131         SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
1132         SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
1133     }
1134 
1135     length = message->GetLength() - message->GetOffset();
1136 
1137     if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1138 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1139     {
1140         // Check if query will fit into tcp buffer if not return error.
1141         VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
1142                          OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
1143                      error = kErrorNoBufs);
1144 
1145         // In case of initialized connection check if connected peer and new query have the same address.
1146         if (mTcpState != kTcpUninitialized)
1147         {
1148             VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
1149                          error = kErrorFailed);
1150         }
1151 
1152         switch (mTcpState)
1153         {
1154         case kTcpUninitialized:
1155             SuccessOrExit(error = InitTcpSocket());
1156             SuccessOrExit(
1157                 error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
1158             mTcpState = kTcpConnecting;
1159             PrepareTcpMessage(*message);
1160             break;
1161         case kTcpConnectedIdle:
1162             PrepareTcpMessage(*message);
1163             SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1164             mTcpState = kTcpConnectedSending;
1165             break;
1166         case kTcpConnecting:
1167             PrepareTcpMessage(*message);
1168             break;
1169         case kTcpConnectedSending:
1170             BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1171             SuccessOrAssert(error = message->Read(message->GetOffset(),
1172                                                   (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1173             IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
1174             break;
1175         }
1176         message->Free();
1177         message = nullptr;
1178     }
1179 #else
1180     {
1181         error = kErrorInvalidArgs;
1182         LogWarn("DNS query over TCP not supported.");
1183         ExitNow();
1184     }
1185 #endif
1186     else
1187     {
1188         VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
1189         messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
1190         messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
1191         SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
1192     }
1193 
1194 exit:
1195 
1196     FreeMessageOnError(message, error);
1197     if (aUpdateTimer)
1198     {
1199         mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
1200     }
1201 
1202     UpdateQuery(aQuery, aInfo);
1203 
1204     return error;
1205 }
1206 
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1207 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1208 {
1209     // The name is encoded and included after the `Info` in `aQuery`
1210     // starting at `kNameOffsetInQuery`.
1211 
1212     return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1213 }
1214 
FinalizeQuery(Query & aQuery,Error aError)1215 void Client::FinalizeQuery(Query &aQuery, Error aError)
1216 {
1217     Response response;
1218     Query   &mainQuery = FindMainQuery(aQuery);
1219 
1220     response.mInstance = &Get<Instance>();
1221     response.mQuery    = &mainQuery;
1222 
1223     FinalizeQuery(response, aError);
1224 }
1225 
FinalizeQuery(Response & aResponse,Error aError)1226 void Client::FinalizeQuery(Response &aResponse, Error aError)
1227 {
1228     QueryType type;
1229     Callback  callback;
1230     void     *context;
1231 
1232     GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1233 
1234     switch (type)
1235     {
1236     case kIp6AddressQuery:
1237 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1238     case kIp4AddressQuery:
1239 #endif
1240         if (callback.mAddressCallback != nullptr)
1241         {
1242             callback.mAddressCallback(aError, &aResponse, context);
1243         }
1244         break;
1245 
1246 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1247     case kBrowseQuery:
1248         if (callback.mBrowseCallback != nullptr)
1249         {
1250             callback.mBrowseCallback(aError, &aResponse, context);
1251         }
1252         break;
1253 
1254     case kServiceQuerySrvTxt:
1255     case kServiceQuerySrv:
1256     case kServiceQueryTxt:
1257         if (callback.mServiceCallback != nullptr)
1258         {
1259             callback.mServiceCallback(aError, &aResponse, context);
1260         }
1261         break;
1262 #endif
1263     case kNoQuery:
1264         break;
1265     }
1266 
1267     FreeQuery(*aResponse.mQuery);
1268 }
1269 
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1270 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1271 {
1272     QueryInfo info;
1273 
1274     info.ReadFrom(aQuery);
1275 
1276     aType     = info.mQueryType;
1277     aCallback = info.mCallback;
1278     aContext  = info.mCallbackContext;
1279 }
1280 
FindQueryById(uint16_t aMessageId)1281 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1282 {
1283     Query    *matchedQuery = nullptr;
1284     QueryInfo info;
1285 
1286     for (Query &mainQuery : mMainQueries)
1287     {
1288         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1289         {
1290             info.ReadFrom(*query);
1291 
1292             if (info.mMessageId == aMessageId)
1293             {
1294                 matchedQuery = query;
1295                 ExitNow();
1296             }
1297         }
1298     }
1299 
1300 exit:
1301     return matchedQuery;
1302 }
1303 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMsgInfo)1304 void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMsgInfo)
1305 {
1306     OT_UNUSED_VARIABLE(aMsgInfo);
1307     ProcessResponse(aMessage);
1308 }
1309 
ProcessResponse(const Message & aResponseMessage)1310 void Client::ProcessResponse(const Message &aResponseMessage)
1311 {
1312     Error  responseError;
1313     Query *query;
1314 
1315     SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1316 
1317 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1318     if (ReplaceWithIp4Query(*query, aResponseMessage) == kErrorNone)
1319     {
1320         ExitNow();
1321     }
1322 #endif
1323 
1324     if (responseError != kErrorNone)
1325     {
1326         // Received an error from server, check if we can replace
1327         // the query.
1328 
1329 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1330         if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1331         {
1332             ExitNow();
1333         }
1334 #endif
1335 
1336         FinalizeQuery(*query, responseError);
1337         ExitNow();
1338     }
1339 
1340     // Received successful response from server.
1341 
1342 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1343     ResolveHostAddressIfNeeded(*query, aResponseMessage);
1344 #endif
1345 
1346     if (!CanFinalizeQuery(*query))
1347     {
1348         SaveQueryResponse(*query, aResponseMessage);
1349         ExitNow();
1350     }
1351 
1352     PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1353 
1354 exit:
1355     return;
1356 }
1357 
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1358 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1359 {
1360     Error     error  = kErrorNone;
1361     uint16_t  offset = aResponseMessage.GetOffset();
1362     Header    header;
1363     QueryInfo info;
1364     Name      queryName;
1365 
1366     SuccessOrExit(error = aResponseMessage.Read(offset, header));
1367     offset += sizeof(Header);
1368 
1369     VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1370                      !header.IsTruncationFlagSet(),
1371                  error = kErrorDrop);
1372 
1373     aQuery = FindQueryById(header.GetMessageId());
1374     VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1375 
1376     info.ReadFrom(*aQuery);
1377 
1378     queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1379 
1380     // Check the Question Section
1381 
1382     if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1383     {
1384         for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1385         {
1386             SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1387             offset += sizeof(Question);
1388         }
1389     }
1390     else
1391     {
1392         VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1393                      error = kErrorParse);
1394     }
1395 
1396     // Check the answer, authority and additional record sections
1397 
1398     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1399     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1400     SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1401 
1402     // Read the response code
1403 
1404     aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1405 
1406     if ((aResponseError == kErrorNone) && (info.mQueryType == kServiceQuerySrvTxt))
1407     {
1408         RecordServerAsCapableOfMultiQuestions(info.mConfig.GetServerSockAddr().GetAddress());
1409     }
1410 
1411 exit:
1412     return error;
1413 }
1414 
CanFinalizeQuery(Query & aQuery)1415 bool Client::CanFinalizeQuery(Query &aQuery)
1416 {
1417     // Determines whether we can finalize a main query by checking if
1418     // we have received and saved responses for all other related
1419     // queries associated with `aQuery`. Note that this method is
1420     // called when we receive a response for `aQuery`, so no need to
1421     // check for a saved response for `aQuery` itself.
1422 
1423     bool      canFinalize = true;
1424     QueryInfo info;
1425 
1426     for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1427     {
1428         info.ReadFrom(*query);
1429 
1430         if (query == &aQuery)
1431         {
1432             continue;
1433         }
1434 
1435         if (info.mSavedResponse == nullptr)
1436         {
1437             canFinalize = false;
1438             ExitNow();
1439         }
1440     }
1441 
1442 exit:
1443     return canFinalize;
1444 }
1445 
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1446 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1447 {
1448     QueryInfo info;
1449 
1450     info.ReadFrom(aQuery);
1451     VerifyOrExit(info.mSavedResponse == nullptr);
1452 
1453     // If `Clone()` fails we let retry or timeout handle the error.
1454     info.mSavedResponse = aResponseMessage.Clone();
1455 
1456     UpdateQuery(aQuery, info);
1457 
1458 exit:
1459     return;
1460 }
1461 
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1462 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1463 {
1464     // Populate `aResponse` for `aQuery`. If there is a saved response
1465     // message for `aQuery` we use it, otherwise, we use
1466     // `aResponseMessage`.
1467 
1468     QueryInfo info;
1469 
1470     info.ReadFrom(aQuery);
1471 
1472     aResponse.mInstance = &Get<Instance>();
1473     aResponse.mQuery    = &aQuery;
1474     aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1475 
1476     return info.mNextQuery;
1477 }
1478 
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1479 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1480 {
1481     // This method prepares a list of chained `Response` instances
1482     // corresponding to all related (chained) queries. It uses
1483     // recursion to go through the queries and construct the
1484     // `Response` chain.
1485 
1486     Response response;
1487     Query   *nextQuery;
1488 
1489     nextQuery      = PopulateResponse(response, aQuery, aResponseMessage);
1490     response.mNext = aPrevResponse;
1491 
1492     if (nextQuery != nullptr)
1493     {
1494         PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1495     }
1496     else
1497     {
1498         FinalizeQuery(response, kErrorNone);
1499     }
1500 }
1501 
HandleTimer(void)1502 void Client::HandleTimer(void)
1503 {
1504     NextFireTime nextTime;
1505     QueryInfo    info;
1506 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1507     bool hasTcpQuery = false;
1508 #endif
1509 
1510     for (Query &mainQuery : mMainQueries)
1511     {
1512         for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1513         {
1514             info.ReadFrom(*query);
1515 
1516             if (info.mSavedResponse != nullptr)
1517             {
1518                 continue;
1519             }
1520 
1521             if (nextTime.GetNow() >= info.mRetransmissionTime)
1522             {
1523                 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1524                 {
1525                     FinalizeQuery(*query, kErrorResponseTimeout);
1526                     break;
1527                 }
1528 
1529 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1530                 if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1531                 {
1532                     LogInfo("Switching to separate SRV/TXT on response timeout");
1533                     info.ReadFrom(*query);
1534                 }
1535                 else
1536 #endif
1537                 {
1538                     IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1539                 }
1540             }
1541 
1542             nextTime.UpdateIfEarlier(info.mRetransmissionTime);
1543 
1544 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1545             if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1546             {
1547                 hasTcpQuery = true;
1548             }
1549 #endif
1550         }
1551     }
1552 
1553     mTimer.FireAtIfEarlier(nextTime);
1554 
1555 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1556     if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1557     {
1558         IgnoreError(mEndpoint.SendEndOfStream());
1559     }
1560 #endif
1561 }
1562 
1563 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1564 
ReplaceWithIp4Query(Query & aQuery,const Message & aResponseMessage)1565 Error Client::ReplaceWithIp4Query(Query &aQuery, const Message &aResponseMessage)
1566 {
1567     Error     error = kErrorFailed;
1568     QueryInfo info;
1569     Header    header;
1570 
1571     info.ReadFrom(aQuery);
1572 
1573     VerifyOrExit(info.mQueryType == kIp6AddressQuery);
1574     VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1575 
1576     // Check the response to the IPv6 query from the server. If the
1577     // response code is success but the answer section is empty
1578     // (indicating the name exists but has no IPv6 address), or the
1579     // response code indicates an error other than `NameError`, we
1580     // replace the query with an IPv4 address resolution query for
1581     // the same name. If the server responded with `NameError`
1582     // (RCode=3), it indicates that the name doesn't exist, so there
1583     // is no need to try an IPv4 query.
1584 
1585     SuccessOrExit(aResponseMessage.Read(aResponseMessage.GetOffset(), header));
1586 
1587     switch (header.GetResponseCode())
1588     {
1589     case Header::kResponseSuccess:
1590         VerifyOrExit(header.GetAnswerCount() == 0);
1591         OT_FALL_THROUGH;
1592 
1593     default:
1594         break;
1595 
1596     case Header::kResponseNameError:
1597         ExitNow();
1598     }
1599 
1600     // We send a new query for IPv4 address resolution
1601     // for the same host name. We reuse the existing `aQuery`
1602     // instance and keep all the info but clear `mTransmissionCount`
1603     // and `mMessageId` (so that a new random message ID is
1604     // selected). The new `info` will be saved in the query in
1605     // `SendQuery()`. Note that the current query is still in the
1606     // `mMainQueries` list when `SendQuery()` selects a new random
1607     // message ID, so the existing message ID for this query will
1608     // not be reused.
1609 
1610     info.mQueryType         = kIp4AddressQuery;
1611     info.mMessageId         = 0;
1612     info.mTransmissionCount = 0;
1613 
1614     IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1615     error = kErrorNone;
1616 
1617 exit:
1618     return error;
1619 }
1620 
1621 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1622 
1623 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1624 
CheckAndUpdateServiceMode(QueryConfig & aConfig,const QueryConfig * aRequestConfig) const1625 void Client::CheckAndUpdateServiceMode(QueryConfig &aConfig, const QueryConfig *aRequestConfig) const
1626 {
1627     // If the user explicitly requested "optimize" mode, we honor that
1628     // request. Otherwise, if "optimize" is chosen from the default
1629     // config, we check if the DNS server is known to have trouble
1630     // with multiple-question queries. If so, we switch to "separate"
1631     // mode.
1632 
1633     if ((aRequestConfig != nullptr) && (aRequestConfig->GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize))
1634     {
1635         ExitNow();
1636     }
1637 
1638     VerifyOrExit(aConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1639 
1640     if (mLimitedQueryServers.Contains(aConfig.GetServerSockAddr().GetAddress()))
1641     {
1642         aConfig.SetServiceMode(QueryConfig::kServiceModeSrvTxtSeparate);
1643     }
1644 
1645 exit:
1646     return;
1647 }
1648 
RecordServerAsLimitedToSingleQuestion(const Ip6::Address & aServerAddress)1649 void Client::RecordServerAsLimitedToSingleQuestion(const Ip6::Address &aServerAddress)
1650 {
1651     VerifyOrExit(!aServerAddress.IsUnspecified());
1652 
1653     VerifyOrExit(!mLimitedQueryServers.Contains(aServerAddress));
1654 
1655     if (mLimitedQueryServers.IsFull())
1656     {
1657         uint8_t randomIndex = Random::NonCrypto::GetUint8InRange(0, mLimitedQueryServers.GetMaxSize());
1658 
1659         mLimitedQueryServers.Remove(mLimitedQueryServers[randomIndex]);
1660     }
1661 
1662     IgnoreError(mLimitedQueryServers.PushBack(aServerAddress));
1663 
1664 exit:
1665     return;
1666 }
1667 
RecordServerAsCapableOfMultiQuestions(const Ip6::Address & aServerAddress)1668 void Client::RecordServerAsCapableOfMultiQuestions(const Ip6::Address &aServerAddress)
1669 {
1670     Ip6::Address *entry = mLimitedQueryServers.Find(aServerAddress);
1671 
1672     VerifyOrExit(entry != nullptr);
1673     mLimitedQueryServers.Remove(*entry);
1674 
1675 exit:
1676     return;
1677 }
1678 
/**
 * Splits an optimized combined SRV+TXT service query into two single-question queries.
 *
 * Applies only when `aQuery` is a `kServiceQuerySrvTxt` query whose service mode is
 * `kServiceModeSrvTxtOptimize`. In that case the server is recorded as limited to
 * single-question queries, `aQuery` is cloned, the clone is sent as a TXT query, and
 * `aQuery` itself is re-sent as an SRV query linked to the clone via `mNextQuery`.
 *
 * @param[in] aQuery  The SRV+TXT query to replace.
 *
 * @retval kErrorNone    The separate SRV and TXT queries were set up and sent (send errors
 *                       themselves are ignored).
 * @retval kErrorFailed  `aQuery` is not an optimized SRV+TXT query, or cloning it failed
 *                       (in the latter case the server has already been recorded as limited).
 */
Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
{
    Error     error = kErrorFailed;
    QueryInfo info;
    Query    *secondQuery;

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
    VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);

    // Remember this server so later optimized SRV+TXT queries to it can be switched to
    // separate mode up front.
    RecordServerAsLimitedToSingleQuestion(info.mConfig.GetServerSockAddr().GetAddress());

    secondQuery = aQuery.Clone();
    VerifyOrExit(secondQuery != nullptr);

    // The clone becomes the TXT query, owned by `aQuery` as its main query. Message id and
    // transmission count are reset so it is sent as a new query.
    info.mQueryType         = kServiceQueryTxt;
    info.mMessageId         = 0;
    info.mTransmissionCount = 0;
    info.mMainQuery         = &aQuery;
    IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));

    // `info` is reused: `mMainQuery` still points at `aQuery` from above. `aQuery` becomes
    // the SRV query and is linked to the TXT query through `mNextQuery`.
    info.mQueryType         = kServiceQuerySrv;
    info.mMessageId         = 0;
    info.mTransmissionCount = 0;
    info.mNextQuery         = secondQuery;
    IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
    error = kErrorNone;

exit:
    return error;
}
1711 
/**
 * Starts an IPv6 address query for a service's host when the SRV response lacks it.
 *
 * Applies only to `kServiceQuerySrvTxt` or `kServiceQuerySrv` queries that have
 * `mShouldResolveHostAddr` set. The service info is read from the answer section of
 * `aResponseMessage`. If it carries no host address (the address is unspecified), a new
 * `kIp6AddressQuery` for the host name is allocated, sent, and inserted into the query
 * chain immediately after `aQuery`. If reading the service info or allocating the new
 * query fails, the function returns without any further action.
 *
 * @param[in] aQuery            The SRV or SRV+TXT query that received a response.
 * @param[in] aResponseMessage  The response message received for `aQuery`.
 */
void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
{
    QueryInfo   info;
    Response    response;
    ServiceInfo serviceInfo;
    char        hostName[Name::kMaxNameSize];

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
    VerifyOrExit(info.mShouldResolveHostAddr);

    PopulateResponse(response, aQuery, aResponseMessage);

    // The host name is written into the local `hostName` buffer and then used as the
    // name for the follow-up address query.
    ClearAllBytes(serviceInfo);
    serviceInfo.mHostNameBuffer     = hostName;
    serviceInfo.mHostNameBufferSize = sizeof(hostName);
    SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));

    // Check whether AAAA record for host address is provided in the SRV query response

    if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
    {
        Query *newQuery;

        // `info` still holds `aQuery`'s settings, including its `mNextQuery`, so the new
        // query inherits the config and points to whatever followed `aQuery`.
        info.mQueryType         = kIp6AddressQuery;
        info.mMessageId         = 0;
        info.mTransmissionCount = 0;
        info.mMainQuery         = &FindMainQuery(aQuery);

        SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
        IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));

        // Update `aQuery` to be linked with new query (inserting
        // the `newQuery` into the linked-list after `aQuery`).

        info.ReadFrom(aQuery);
        info.mNextQuery = newQuery;
        UpdateQuery(aQuery, info);
    }

exit:
    return;
}
1756 
1757 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1758 
1759 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1760 void Client::PrepareTcpMessage(Message &aMessage)
1761 {
1762     uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1763 
1764     // Prepending the DNS query with length of the packet according to RFC1035.
1765     BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1766     SuccessOrAssert(
1767         aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1768     mSendLink.mLength += length + sizeof(uint16_t);
1769 }
1770 
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1771 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1772 {
1773     OT_UNUSED_VARIABLE(aEndpoint);
1774     OT_UNUSED_VARIABLE(aData);
1775     OT_ASSERT(mTcpState == kTcpConnectedSending);
1776 
1777     mSendLink.mLength = 0;
1778     mTcpState         = kTcpConnectedIdle;
1779 }
1780 
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1781 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1782 {
1783     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1784 }
1785 
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1786 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1787 {
1788     OT_UNUSED_VARIABLE(aEndpoint);
1789     IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1790     mTcpState = kTcpConnectedSending;
1791 }
1792 
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1793 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1794 {
1795     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1796 }
1797 
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1798 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1799                                  size_t                &aOffset,
1800                                  Message               &aMessage,
1801                                  uint16_t               aLength)
1802 {
1803     // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1804     // and copy the content into `aMessage`. As we read we can move
1805     // to the next `aLinkedBuffer` and update `aOffset`.
1806     // Returns:
1807     // - `kErrorNone` if `aLength` bytes are successfully read and
1808     //    `aOffset` and `aLinkedBuffer` are updated.
1809     // - `kErrorNotFound` is not enough bytes available to read
1810     //    from `aLinkedBuffer`.
1811     // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1812 
1813     Error error = kErrorNone;
1814 
1815     while (aLength > 0)
1816     {
1817         uint16_t bytesToRead = aLength;
1818 
1819         VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1820 
1821         if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1822         {
1823             bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1824         }
1825 
1826         SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1827 
1828         aLength -= bytesToRead;
1829         aOffset += bytesToRead;
1830 
1831         if (aOffset == aLinkedBuffer->mLength)
1832         {
1833             aLinkedBuffer = aLinkedBuffer->mNext;
1834             aOffset       = 0;
1835         }
1836     }
1837 
1838 exit:
1839     return error;
1840 }
1841 
/**
 * Extracts and processes the DNS responses available on the TCP endpoint.
 *
 * Each response on the stream is framed by a two-byte big-endian length prefix. Every
 * complete response is copied into a temporary message and passed to
 * `ProcessResponse()`. Only the bytes of fully processed responses (prefix included)
 * are committed via `CommitReceive()`. A trailing, incomplete response is therefore not
 * consumed and presumably remains available to a later callback (confirm against the
 * TCP endpoint semantics).
 *
 * @param[in] aEndpoint        The TCP endpoint (unused; `mEndpoint` is used directly).
 * @param[in] aBytesAvailable  Number of bytes available to read from the endpoint.
 * @param[in] aEndOfStream     Whether the peer ended the stream; if so, end-of-stream is
 *                             sent in return.
 * @param[in] aBytesRemaining  Unused.
 */
void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
                                       size_t         aBytesAvailable,
                                       bool           aEndOfStream,
                                       size_t         aBytesRemaining)
{
    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aBytesRemaining);

    Message              *message   = nullptr;
    size_t                totalRead = 0;
    size_t                offset    = 0;
    const otLinkedBuffer *data;

    if (aEndOfStream)
    {
        // Cleanup is done in disconnected callback.
        IgnoreError(mEndpoint.SendEndOfStream());
    }

    SuccessOrExit(mEndpoint.ReceiveByReference(data));
    VerifyOrExit(data != nullptr);

    // Scratch message reused for every response extracted below.
    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr);

    while (aBytesAvailable > totalRead)
    {
        uint16_t length;

        // Read the `length` field.
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));

        IgnoreError(message->Read(/* aOffset */ 0, length));
        length = BigEndian::HostSwap16(length);

        // Try to read `length` bytes. On failure (incomplete response or no buffers) we
        // exit without counting this response in `totalRead`.
        IgnoreError(message->SetLength(0));
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));

        totalRead += length + sizeof(uint16_t);

        // Now process the read message as query response.
        ProcessResponse(*message);

        IgnoreError(message->SetLength(0));

        // Loop again to see if we can read another response.
    }

exit:
    // Inform `mEndPoint` about the total read and processed bytes
    IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
    FreeMessage(message);
}
1896 
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1897 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1898                                                size_t         aBytesAvailable,
1899                                                bool           aEndOfStream,
1900                                                size_t         aBytesRemaining)
1901 {
1902     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1903         ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1904 }
1905 
/**
 * Handles the TCP connection being closed or failing.
 *
 * Deinitializes `mEndpoint`, marks the TCP state as uninitialized, and finalizes every
 * main query that uses TCP transport with `kErrorAbort`.
 *
 * @param[in] aEndpoint  The TCP endpoint (unused; `mEndpoint` is used directly).
 * @param[in] aReason    The disconnect reason (unused).
 */
void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
{
    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aReason);
    QueryInfo info;

    IgnoreError(mEndpoint.Deinitialize());
    mTcpState = kTcpUninitialized;

    // Abort queries in case of connection failures
    // NOTE(review): `FinalizeQuery()` is called while range-iterating `mMainQueries`. This is
    // only safe if the list iterator tolerates removal/freeing of the current entry (and
    // if finalizing one query cannot free other main queries) -- confirm against the
    // `mMainQueries` iterator implementation.
    for (Query &mainQuery : mMainQueries)
    {
        info.ReadFrom(mainQuery);

        if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
        {
            FinalizeQuery(mainQuery, kErrorAbort);
        }
    }
}
1926 
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1927 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1928 {
1929     static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1930 }
1931 
1932 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1933 
1934 } // namespace Dns
1935 } // namespace ot
1936 
1937 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1938