1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43
44 /**
45 * @file
46 * This file implements the DNS client.
47 */
48
49 namespace ot {
50 namespace Dns {
51
// Declares "DnsClient" as the log module name used by this file
// (presumably tags log output emitted below).
RegisterLogModule("DnsClient");

//---------------------------------------------------------------------------------------------------------------------
// Client::QueryConfig

// Default DNS server IPv6 address in textual form, taken from the build
// configuration `OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS`.
// It is parsed by the `QueryConfig(InitMode)` constructor.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
58
QueryConfig(InitMode aMode)59 Client::QueryConfig::QueryConfig(InitMode aMode)
60 {
61 OT_UNUSED_VARIABLE(aMode);
62
63 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
64 GetServerSockAddr().SetPort(kDefaultServerPort);
65 SetResponseTimeout(kDefaultResponseTimeout);
66 SetMaxTxAttempts(kDefaultMaxTxAttempts);
67 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
68 SetServiceMode(kDefaultServiceMode);
69 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
70 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
71 #endif
72 SetTransportProto(kDnsTransportUdp);
73 }
74
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)75 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
76 {
77 // This method sets the config from `aConfig` replacing any
78 // unspecified fields (value zero) with the fields from
79 // `aDefaultConfig`. If `aConfig` is `nullptr` then
80 // `aDefaultConfig` is used.
81
82 if (aConfig == nullptr)
83 {
84 *this = aDefaultConfig;
85 ExitNow();
86 }
87
88 *this = *aConfig;
89
90 if (GetServerSockAddr().GetAddress().IsUnspecified())
91 {
92 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
93 }
94
95 if (GetServerSockAddr().GetPort() == 0)
96 {
97 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
98 }
99
100 if (GetResponseTimeout() == 0)
101 {
102 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
103 }
104
105 if (GetMaxTxAttempts() == 0)
106 {
107 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
108 }
109
110 if (GetRecursionFlag() == kFlagUnspecified)
111 {
112 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
113 }
114
115 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
116 if (GetNat64Mode() == kNat64Unspecified)
117 {
118 SetNat64Mode(aDefaultConfig.GetNat64Mode());
119 }
120 #endif
121
122 if (GetServiceMode() == kServiceModeUnspecified)
123 {
124 SetServiceMode(aDefaultConfig.GetServiceMode());
125 }
126
127 if (GetTransportProto() == kDnsTransportUnspecified)
128 {
129 SetTransportProto(aDefaultConfig.GetTransportProto());
130 }
131
132 exit:
133 return;
134 }
135
136 //---------------------------------------------------------------------------------------------------------------------
137 // Client::Response
138
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const139 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
140 {
141 switch (aSection)
142 {
143 case kAnswerSection:
144 aOffset = mAnswerOffset;
145 aNumRecord = mAnswerRecordCount;
146 break;
147 case kAdditionalDataSection:
148 default:
149 aOffset = mAdditionalOffset;
150 aNumRecord = mAdditionalRecordCount;
151 break;
152 }
153 }
154
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const155 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
156 {
157 uint16_t offset = kNameOffsetInQuery;
158
159 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
160 }
161
Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
{
    // If the response includes a CNAME record mapping the query host
    // name to a canonical name, we update `aHostName` to the new alias
    // name. Otherwise `aHostName` remains as before. This method handles
    // when there are multiple CNAME records mapping the host name multiple
    // times. We limit number of changes to `kMaxCnameAliasNameChanges`
    // to detect and handle if the response contains CNAME record loops.
    //
    // Returns:
    // - `kErrorNone` once no further CNAME matches the current
    //   `aHostName`. This includes the case where there was no CNAME
    //   at all.
    // - `kErrorNotFound` if there is no response message (`mMessage` is
    //   null).
    // - `kErrorParse` if `kMaxCnameAliasNameChanges` consecutive CNAME
    //   hops were followed. The chain is then treated as a loop.
    // - Any other error from `FindRecord()` / `ParseName()` as-is.
    //
    // Note: on an error return, `aHostName` may already have been updated
    // by earlier hops. `SetFromMessage()` also runs before `ParseName()`
    // validates the new name.

    Error       error;
    uint16_t    offset;
    uint16_t    numRecords;
    CnameRecord cnameRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
    {
        // Each hop restarts the search from the beginning of `aSection`.
        SelectSection(aSection, offset, numRecords);
        error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);

        if (error == kErrorNotFound)
        {
            // No (further) alias: `aHostName` is the final name.
            error = kErrorNone;
            ExitNow();
        }

        SuccessOrExit(error);

        // A CNAME record was found. `offset` now points to after the
        // last read byte within the `mMessage` into the `cnameRecord`
        // (which is the start of the new canonical name).

        aHostName.SetFromMessage(*mMessage, offset);
        SuccessOrExit(error = Name::ParseName(*mMessage, offset));

        // Loop back to check if there may be a CNAME record for the
        // new `aHostName`.
    }

    // Too many alias hops: treat as a malformed response (CNAME loop).
    error = kErrorParse;

exit:
    return error;
}
207
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const208 Error Client::Response::FindHostAddress(Section aSection,
209 const Name &aHostName,
210 uint16_t aIndex,
211 Ip6::Address &aAddress,
212 uint32_t &aTtl) const
213 {
214 Error error;
215 uint16_t offset;
216 uint16_t numRecords;
217 Name name = aHostName;
218 AaaaRecord aaaaRecord;
219
220 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
221
222 SelectSection(aSection, offset, numRecords);
223 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
224 aAddress = aaaaRecord.GetAddress();
225 aTtl = aaaaRecord.GetTtl();
226
227 exit:
228 return error;
229 }
230
231 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
232
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const233 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
234 {
235 Error error;
236 uint16_t offset;
237 uint16_t numRecords;
238 Name name = aHostName;
239
240 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
241
242 SelectSection(aSection, offset, numRecords);
243 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
244
245 exit:
246 return error;
247 }
248
249 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
250
251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
252
InitServiceInfo(ServiceInfo & aServiceInfo) const253 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
254 {
255 // This method initializes `aServiceInfo` setting all
256 // TTLs to zero and host name to empty string.
257
258 aServiceInfo.mTtl = 0;
259 aServiceInfo.mHostAddressTtl = 0;
260 aServiceInfo.mTxtDataTtl = 0;
261 aServiceInfo.mTxtDataTruncated = false;
262
263 AsCoreType(&aServiceInfo.mHostAddress).Clear();
264
265 if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
266 {
267 aServiceInfo.mHostNameBuffer[0] = '\0';
268 }
269 }
270
Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches for SRV record in the given `aSection`
    // matching the record name against `aName`, and updates the
    // `aServiceInfo` accordingly. It also searches for AAAA record
    // for host name associated with the service (from SRV record).
    // The search for AAAA record is always performed in Additional
    // Data section (independent of the value given in `aSection`).
    //
    // Returns:
    // - `kErrorNone` if the SRV record was already read earlier
    //   (`mTtl != 0`). Nothing is changed in that case, and the AAAA
    //   search is skipped as well.
    // - `kErrorNotFound` if there is no response message.
    // - The `FindRecord()` error if no matching SRV record exists.
    // - `kErrorNone` after a successful SRV read, whether or not a
    //   matching AAAA record was found. A missing AAAA record is not an
    //   error.
    // - Any other error from reading the target host name or the AAAA
    //   lookup.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    Name      hostName;
    SrvRecord srvRecord;

    // A non-zero `mTtl` indicates that SRV record is already found
    // and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTtl == 0);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    // Search for a matching SRV record
    SelectSection(aSection, offset, numRecords);
    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));

    aServiceInfo.mTtl      = srvRecord.GetTtl();
    aServiceInfo.mPort     = srvRecord.GetPort();
    aServiceInfo.mPriority = srvRecord.GetPriority();
    aServiceInfo.mWeight   = srvRecord.GetWeight();

    // `offset` now points at the SRV target host name. Remember where it
    // is (for the AAAA lookup below) before `offset` is advanced past it.
    hostName.SetFromMessage(*mMessage, offset);

    if (aServiceInfo.mHostNameBuffer != nullptr)
    {
        // Caller wants the host name: copy it out (this also advances `offset`).
        SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
                                                           aServiceInfo.mHostNameBufferSize));
    }
    else
    {
        // No buffer: just validate and skip over the host name.
        SuccessOrExit(error = Name::ParseName(*mMessage, offset));
    }

    // Search in additional section for AAAA record for the host name.
    // Skipped if a host address is already known (e.g. from another
    // response).

    VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());

    error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
                            aServiceInfo.mHostAddressTtl);

    // A missing AAAA record is acceptable; the address simply stays unspecified.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

exit:
    return error;
}
328
Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches a TXT record in the given `aSection`
    // matching the record name against `aName` and updates the TXT
    // related properties in `aServiceInfo`.
    //
    // If no match is found `mTxtDataTtl` (which is initialized to zero)
    // remains unchanged to indicate this. In this case this method still
    // returns `kErrorNone`.
    //
    // A missing response message is treated the same way: the
    // `kErrorNotFound` is mapped to `kErrorNone` at `exit`. TXT data that
    // does not fit in `mTxtData` is not an error either. It is copied
    // partially (as far as `ReadTxtData()` gets), `mTxtDataTruncated` is
    // set, and `mTxtDataTtl` is still recorded. Any other read error is
    // returned.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    TxtRecord txtRecord;

    // A non-zero `mTxtDataTtl` indicates that TXT record is already
    // found and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);

    // A null `mTxtData` indicates that caller does not want to retrieve
    // TXT data.
    VerifyOrExit(aServiceInfo.mTxtData != nullptr);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(aSection, offset, numRecords);

    aServiceInfo.mTxtDataTruncated = false;

    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));

    error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);

    if (error == kErrorNoBufs)
    {
        error = kErrorNone;

        // Mark `mTxtDataTruncated` to indicate that we could not read
        // the full TXT record into the given `mTxtData` buffer.
        aServiceInfo.mTxtDataTruncated = true;
    }

    SuccessOrExit(error);
    aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();

exit:
    // "Not found" (no message or no matching TXT record) is not an error
    // for the caller; `mTxtDataTtl == 0` signals the absence instead.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

    return error;
}
382
383 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
384
void Client::Response::PopulateFrom(const Message &aMessage)
{
    // Populate `Response` with info from `aMessage`.
    //
    // Walks the DNS message once, starting at `aMessage.GetOffset()`. It
    // records where the Answer and Additional Data sections begin and how
    // many records each holds; `SelectSection()` later reads these values.
    // The Authority section is skipped over but not recorded.
    //
    // All read/parse errors are ignored here (`IgnoreError()`).
    // NOTE(review): for a malformed message the stored offsets may not
    // point at valid records; presumably the later record lookups detect
    // this. Confirm.

    uint16_t offset = aMessage.GetOffset();
    Header   header;

    mMessage = &aMessage;

    IgnoreError(aMessage.Read(offset, header));
    offset += sizeof(Header);

    // Skip each Question entry: a variable-length name followed by a
    // fixed-size `Question` structure.
    for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
    {
        IgnoreError(Name::ParseName(aMessage, offset));
        offset += sizeof(Question);
    }

    mAnswerOffset = offset;
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
    mAdditionalOffset = offset;
    // NOTE(review): neither the result of this parse nor the advanced
    // local `offset` is used afterwards.
    IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));

    mAnswerRecordCount     = header.GetAnswerCount();
    mAdditionalRecordCount = header.GetAdditionalRecordCount();
}
412
413 //---------------------------------------------------------------------------------------------------------------------
414 // Client::AddressResponse
415
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const416 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
417 {
418 Error error = kErrorNone;
419 Name name(*mQuery, kNameOffsetInQuery);
420
421 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
422
423 // If the response is for an IPv4 address query or if it is an
424 // IPv6 address query response with no IPv6 address but with
425 // an IPv4 in its additional section, we read the IPv4 address
426 // and translate it to an IPv6 address.
427
428 QueryInfo info;
429
430 info.ReadFrom(*mQuery);
431
432 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
433 {
434 Section section;
435 ARecord aRecord;
436 NetworkData::ExternalRouteConfig nat64Prefix;
437
438 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
439 error = kErrorInvalidState);
440
441 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
442 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
443
444 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
445 aTtl = aRecord.GetTtl();
446
447 ExitNow();
448 }
449
450 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
451
452 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
453
454 exit:
455 return error;
456 }
457
458 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
459
460 //---------------------------------------------------------------------------------------------------------------------
461 // Client::BrowseResponse
462
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const463 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
464 {
465 Error error;
466 uint16_t offset;
467 uint16_t numRecords;
468 Name serviceName(*mQuery, kNameOffsetInQuery);
469 PtrRecord ptrRecord;
470
471 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
472
473 SelectSection(kAnswerSection, offset, numRecords);
474 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
475 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
476
477 exit:
478 return error;
479 }
480
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const481 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
482 {
483 Error error;
484 Name instanceName;
485
486 // Find a matching PTR record for the service instance label. Then
487 // search and read SRV, TXT and AAAA records in Additional Data
488 // section matching the same name to populate `aServiceInfo`.
489
490 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
491
492 InitServiceInfo(aServiceInfo);
493 SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
494 SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
495
496 if (aServiceInfo.mTxtDataTtl == 0)
497 {
498 aServiceInfo.mTxtDataSize = 0;
499 }
500
501 exit:
502 return error;
503 }
504
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const505 Error Client::BrowseResponse::GetHostAddress(const char *aHostName,
506 uint16_t aIndex,
507 Ip6::Address &aAddress,
508 uint32_t &aTtl) const
509 {
510 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
511 }
512
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // matching a given instance label `aInstanceLabel`. If found,
    // `aInstanceName` is updated to refer to the full instance name
    // (the PTR target) within the message, and `kErrorNone` is returned.
    //
    // Returns `kErrorNotFound` if there is no response message or no PTR
    // record matches. Any failure from `Name::CompareName()` ends the
    // search and is returned as-is. This includes a record whose owner
    // name differs from the queried service name, presumably reported as
    // `kErrorNotFound`; confirm. Any error other than `kErrorNotFound`
    // from `ReadRecord()` or `CompareLabel()` is also returned.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        // Every record visited must be owned by the queried service name.
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A copy of `offset` is used so that `offset`
        // keeps pointing at the start of the PTR target name.

        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record: advance by the
        // record's total size minus the fixed-size part already read.
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
568
569 //---------------------------------------------------------------------------------------------------------------------
570 // Client::ServiceResponse
571
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const572 Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer,
573 uint8_t aLabelBufferSize,
574 char *aNameBuffer,
575 uint16_t aNameBufferSize) const
576 {
577 Error error;
578 uint16_t offset = kNameOffsetInQuery;
579
580 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
581
582 VerifyOrExit(aNameBuffer != nullptr);
583 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
584
585 exit:
586 return error;
587 }
588
Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
{
    // Search and read SRV, TXT records matching name from query.
    //
    // This walks this response and every response chained after it via
    // `mNext`:
    // - Address-query responses (IPv6, and IPv4 when NAT64 is enabled)
    //   only contribute a host address from their Answer section. Their
    //   errors are ignored.
    // - Service-query responses contribute SRV and TXT data. Because
    //   `ReadServiceInfo()` / `ReadTxtRecord()` skip work once a TTL is
    //   non-zero, the first response providing a record wins.
    //
    // Returns `kErrorNotFound` if no service-query response is in the
    // chain. Otherwise it returns the last service-query result; the first
    // failure stops the walk. On such an early failure the `mTxtDataSize`
    // fix-up below is skipped.

    Error error = kErrorNotFound;

    InitServiceInfo(aServiceInfo);

    for (const Response *response = this; response != nullptr; response = response->mNext)
    {
        Name      name(*response->mQuery, kNameOffsetInQuery);
        QueryInfo info;
        Section   srvSection;
        Section   txtSection;

        info.ReadFrom(*response->mQuery);

        switch (info.mQueryType)
        {
        case kIp6AddressQuery:
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        case kIp4AddressQuery:
#endif
            IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
                                                  AsCoreType(&aServiceInfo.mHostAddress),
                                                  aServiceInfo.mHostAddressTtl));

            continue; // to `for()` loop

        case kServiceQuerySrvTxt:
        case kServiceQuerySrv:
        case kServiceQueryTxt:
            break;

        default:
            continue;
        }

        // Determine from which section we should try to read the SRV and
        // TXT records based on the query type.
        //
        // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
        // only one record (SRV or TXT) in the answer section, but we
        // still try to read the other records from additional data
        // section in case server provided them.

        srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
        txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;

        error = response->ReadServiceInfo(srvSection, name, aServiceInfo);

        // An SRV record is only required when it was explicitly queried
        // (Answer section); its absence from Additional Data is tolerated.
        if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
        {
            error = kErrorNone;
        }

        SuccessOrExit(error);

        SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
    }

    // No TXT record was read: report an empty TXT data buffer.
    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}
658
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const659 Error Client::ServiceResponse::GetHostAddress(const char *aHostName,
660 uint16_t aIndex,
661 Ip6::Address &aAddress,
662 uint32_t &aTtl) const
663 {
664 Error error = kErrorNotFound;
665
666 for (const Response *response = this; response != nullptr; response = response->mNext)
667 {
668 Section section = kAdditionalDataSection;
669 QueryInfo info;
670
671 info.ReadFrom(*response->mQuery);
672
673 switch (info.mQueryType)
674 {
675 case kIp6AddressQuery:
676 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
677 case kIp4AddressQuery:
678 #endif
679 section = kAnswerSection;
680 break;
681
682 default:
683 break;
684 }
685
686 error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
687
688 if (error == kErrorNone)
689 {
690 break;
691 }
692 }
693
694 return error;
695 }
696
697 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
698
699 //---------------------------------------------------------------------------------------------------------------------
700 // Client
701
// Resource record types requested by each query type. A query carrying
// several record types (SRV+TXT) lists them all in one array.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[]  = {ResourceRecord::kTypePtr};
// Order matters: index 0 (SRV) and index 1 (TXT) are also referenced
// individually by `kQuestionRecordTypes[]` below.
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of record types (entries) for each query type, in `QueryType`
// enumerator order (see the `/* kXxx -> */` annotations). The constructor's
// `static_assert`s pin the enumerator values this order relies on.
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ GetArrayLength(kBrowseQueryRecordTypes),  // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv    -> */ 1,                                        // SRV record only
    /* kServiceQueryTxt    -> */ 1,                                        // TXT record only
#endif
};

// Pointer to the first record type for each query type (same order as
// `kQuestionCount[]`). The SRV-only and TXT-only entries point at a single
// element inside `kServiceQueryRecordTypes[]`.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery        -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv    -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt    -> */ &kServiceQueryRecordTypes[1],

#endif
};
737
Client(Instance & aInstance)738 Client::Client(Instance &aInstance)
739 : InstanceLocator(aInstance)
740 , mSocket(aInstance, *this)
741 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
742 , mTcpState(kTcpUninitialized)
743 #endif
744 , mTimer(aInstance)
745 , mDefaultConfig(QueryConfig::kInitFromDefaults)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
747 , mUserDidSetDefaultAddress(false)
748 #endif
749 {
750 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
752 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
753 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
754 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
755 static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
756 static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
757 static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
758 #endif
759 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
760 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
761 static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
762 static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
763 static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
764 #endif
765 }
766
Start(void)767 Error Client::Start(void)
768 {
769 Error error;
770
771 SuccessOrExit(error = mSocket.Open());
772 SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
773
774 exit:
775 return error;
776 }
777
Stop(void)778 void Client::Stop(void)
779 {
780 Query *query;
781
782 while ((query = mMainQueries.GetHead()) != nullptr)
783 {
784 FinalizeQuery(*query, kErrorAbort);
785 }
786
787 IgnoreError(mSocket.Close());
788 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
789 if (mTcpState != kTcpUninitialized)
790 {
791 IgnoreError(mEndpoint.Deinitialize());
792 }
793 #endif
794
795 mLimitedQueryServers.Clear();
796 }
797
798 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)799 Error Client::InitTcpSocket(void)
800 {
801 Error error;
802 otTcpEndpointInitializeArgs endpointArgs;
803
804 ClearAllBytes(endpointArgs);
805 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
806 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
807 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
808 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
809 endpointArgs.mContext = this;
810 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
811 endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes);
812
813 mSendLink.mNext = nullptr;
814 mSendLink.mData = mSendBufferBytes;
815 mSendLink.mLength = 0;
816
817 SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
818 exit:
819 return error;
820 }
821 #endif
822
SetDefaultConfig(const QueryConfig & aQueryConfig)823 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
824 {
825 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
826
827 mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
828
829 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
830 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
831 UpdateDefaultConfigAddress();
832 #endif
833 }
834
ResetDefaultConfig(void)835 void Client::ResetDefaultConfig(void)
836 {
837 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
838
839 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
840 mUserDidSetDefaultAddress = false;
841 UpdateDefaultConfigAddress();
842 #endif
843 }
844
845 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)846 void Client::UpdateDefaultConfigAddress(void)
847 {
848 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
849
850 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
851 !srpServerAddr.IsUnspecified())
852 {
853 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
854 }
855 }
856 #endif
857
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)858 Error Client::ResolveAddress(const char *aHostName,
859 AddressCallback aCallback,
860 void *aContext,
861 const QueryConfig *aConfig)
862 {
863 QueryInfo info;
864
865 info.Clear();
866 info.mQueryType = kIp6AddressQuery;
867 info.mConfig.SetFrom(aConfig, mDefaultConfig);
868 info.mCallback.mAddressCallback = aCallback;
869 info.mCallbackContext = aContext;
870
871 return StartQuery(info, nullptr, aHostName);
872 }
873
874 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)875 Error Client::ResolveIp4Address(const char *aHostName,
876 AddressCallback aCallback,
877 void *aContext,
878 const QueryConfig *aConfig)
879 {
880 QueryInfo info;
881
882 info.Clear();
883 info.mQueryType = kIp4AddressQuery;
884 info.mConfig.SetFrom(aConfig, mDefaultConfig);
885 info.mCallback.mAddressCallback = aCallback;
886 info.mCallbackContext = aContext;
887
888 return StartQuery(info, nullptr, aHostName);
889 }
890 #endif
891
892 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
893
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)894 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
895 {
896 QueryInfo info;
897
898 info.Clear();
899 info.mQueryType = kBrowseQuery;
900 info.mConfig.SetFrom(aConfig, mDefaultConfig);
901 info.mCallback.mBrowseCallback = aCallback;
902 info.mCallbackContext = aContext;
903
904 return StartQuery(info, nullptr, aServiceName);
905 }
906
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)907 Error Client::ResolveService(const char *aInstanceLabel,
908 const char *aServiceName,
909 ServiceCallback aCallback,
910 void *aContext,
911 const QueryConfig *aConfig)
912 {
913 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
914 }
915
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)916 Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel,
917 const char *aServiceName,
918 ServiceCallback aCallback,
919 void *aContext,
920 const QueryConfig *aConfig)
921 {
922 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
923 }
924
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)925 Error Client::Resolve(const char *aInstanceLabel,
926 const char *aServiceName,
927 ServiceCallback aCallback,
928 void *aContext,
929 const QueryConfig *aConfig,
930 bool aShouldResolveHostAddr)
931 {
932 QueryInfo info;
933 Error error;
934 QueryType secondQueryType = kNoQuery;
935
936 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
937
938 info.Clear();
939
940 info.mConfig.SetFrom(aConfig, mDefaultConfig);
941 info.mShouldResolveHostAddr = aShouldResolveHostAddr;
942
943 CheckAndUpdateServiceMode(info.mConfig, aConfig);
944
945 switch (info.mConfig.GetServiceMode())
946 {
947 case QueryConfig::kServiceModeSrvTxtSeparate:
948 secondQueryType = kServiceQueryTxt;
949
950 OT_FALL_THROUGH;
951
952 case QueryConfig::kServiceModeSrv:
953 info.mQueryType = kServiceQuerySrv;
954 break;
955
956 case QueryConfig::kServiceModeTxt:
957 info.mQueryType = kServiceQueryTxt;
958 VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
959 break;
960
961 case QueryConfig::kServiceModeSrvTxt:
962 case QueryConfig::kServiceModeSrvTxtOptimize:
963 default:
964 info.mQueryType = kServiceQuerySrvTxt;
965 break;
966 }
967
968 info.mCallback.mServiceCallback = aCallback;
969 info.mCallbackContext = aContext;
970
971 error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
972
973 exit:
974 return error;
975 }
976
977 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
978
Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
{
    // Allocates, enqueues and sends a new main query described by
    // `aInfo`. If `aSecondType` is not `kNoQuery`, a second query of
    // that type is also allocated for the same name, sent, and linked
    // after the main query through `mNextQuery`.
    //
    // The `aLabel` can be `nullptr` and then `aName` provides the
    // full name, otherwise the name is appended as `{aLabel}.
    // {aName}`.
    //
    // Returns the error from validating, allocating or sending the
    // main query. A failure to allocate the second query is not
    // reported. Note that `aInfo` is modified: it is reused as scratch
    // space for the second query.

    Error error;
    Query *query;

    VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    // An IPv4 address query is only accepted when the config allows
    // NAT64 and the leader's Network Data provides a preferred NAT64
    // prefix.
    if (aInfo.mQueryType == kIp4AddressQuery)
    {
        NetworkData::ExternalRouteConfig nat64Prefix;

        VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
        VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
                     error = kErrorInvalidState);
    }
#endif

    SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));

    // Enqueue before sending. On a send failure, `FreeQuery()` below
    // dequeues it again.
    mMainQueries.Enqueue(*query);

    error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
    VerifyOrExit(error == kErrorNone, FreeQuery(*query));

    if (aSecondType != kNoQuery)
    {
        Query *secondQuery;

        // Reuse `aInfo` for the second query. It gets the new type,
        // `mMessageId` 0 (so `SendQuery()` picks a fresh ID), a reset
        // transmission count, and a back-link to the main query.
        aInfo.mQueryType = aSecondType;
        aInfo.mMessageId = 0;
        aInfo.mTransmissionCount = 0;
        aInfo.mMainQuery = query;

        // We intentionally do not use `error` here so in the unlikely
        // case where we cannot allocate the second query we can proceed
        // with the first one.
        SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));

        IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));

        // Update first query to link to second one by updating
        // its `mNextQuery`.
        aInfo.ReadFrom(*query);
        aInfo.mNextQuery = secondQuery;
        UpdateQuery(*query, aInfo);
    }

exit:
    return error;
}
1034
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1035 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1036 {
1037 Error error = kErrorNone;
1038
1039 aQuery = nullptr;
1040
1041 VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1042
1043 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1044 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1045
1046 SuccessOrExit(error = aQuery->Append(aInfo));
1047
1048 if (aLabel != nullptr)
1049 {
1050 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1051 }
1052
1053 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1054
1055 exit:
1056 FreeAndNullMessageOnError(aQuery, error);
1057 return error;
1058 }
1059
FindMainQuery(Query & aQuery)1060 Client::Query &Client::FindMainQuery(Query &aQuery)
1061 {
1062 QueryInfo info;
1063
1064 info.ReadFrom(aQuery);
1065
1066 return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1067 }
1068
FreeQuery(Query & aQuery)1069 void Client::FreeQuery(Query &aQuery)
1070 {
1071 Query &mainQuery = FindMainQuery(aQuery);
1072 QueryInfo info;
1073
1074 mMainQueries.Dequeue(mainQuery);
1075
1076 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1077 {
1078 info.ReadFrom(*query);
1079 FreeMessage(info.mSavedResponse);
1080 query->Free();
1081 }
1082 }
1083
Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
{
    // This method prepares and sends a query message represented by
    // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
    // the new `mRetransmissionTime`) and updates it in `aQuery` as
    // well. `aUpdateTimer` indicates whether the timer should be
    // updated when query is sent or not (used in the case where timer
    // is handled by caller).
    //
    // `aInfo` is written back to `aQuery` even when sending fails.
    // The incremented transmission count, the new retransmission time
    // and any newly selected message ID are therefore kept, and the
    // query is retried or timed out later from the timer.

    Error error = kErrorNone;
    Message *message = nullptr;
    Header header;
    Ip6::MessageInfo messageInfo;
    uint16_t length = 0;

    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    // A zero `mMessageId` means no ID has been assigned yet. In that
    // case, pick a random non-zero ID not used by any pending query.
    // Otherwise (a retransmission) reuse the existing ID.
    if (aInfo.mMessageId == 0)
    {
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section. Every question repeats the name
    // stored in `aQuery`, with the record type taken from
    // `kQuestionRecordTypes` for this query type.

    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    // Size of the DNS message (header plus questions).
    length = message->GetLength() - message->GetOffset();

    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
        if (mTcpState != kTcpUninitialized)
        {
            VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
                         error = kErrorFailed);
        }

        switch (mTcpState)
        {
        case kTcpUninitialized:
            // No connection yet: open one and buffer the query in
            // `mSendLink`. `HandleTcpEstablished()` sends the buffer
            // once the connection is up.
            SuccessOrExit(error = InitTcpSocket());
            SuccessOrExit(
                error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
            mTcpState = kTcpConnecting;
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedIdle:
            // Connected with nothing in flight: buffer and send now.
            PrepareTcpMessage(*message);
            SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
            mTcpState = kTcpConnectedSending;
            break;
        case kTcpConnecting:
            // Connection still pending: only buffer the query.
            PrepareTcpMessage(*message);
            break;
        case kTcpConnectedSending:
            // A send is in progress. Write the length-prefixed query
            // right after the data already in the send buffer, then
            // extend the ongoing send. `mSendLink.mLength` is not
            // updated here; presumably `SendByExtension()` grows the
            // last linked buffer by that amount. Confirm against
            // `otTcpSendByExtension()`.
            BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
            SuccessOrAssert(error = message->Read(message->GetOffset(),
                                                  (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
            IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
            break;
        }
        // The query bytes were copied into `mSendBufferBytes`, so the
        // message itself is no longer needed.
        message->Free();
        message = nullptr;
    }
#else
    {
        error = kErrorInvalidArgs;
        LogWarn("DNS query over TCP not supported.");
        ExitNow();
    }
#endif
    else
    {
        // UDP: the query must not exceed `kUdpQueryMaxSize`.
        VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
        messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
        messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
        SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
    }

exit:

    // `message` is freed only on error. On UDP success it was passed
    // to `mSocket.SendTo()`; on the TCP path it was already freed and
    // nulled above.
    FreeMessageOnError(message, error);
    if (aUpdateTimer)
    {
        mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
    }

    // Persist the updated `aInfo` regardless of `error`.
    UpdateQuery(aQuery, aInfo);

    return error;
}
1210
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1211 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1212 {
1213 // The name is encoded and included after the `Info` in `aQuery`
1214 // starting at `kNameOffsetInQuery`.
1215
1216 return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1217 }
1218
FinalizeQuery(Query & aQuery,Error aError)1219 void Client::FinalizeQuery(Query &aQuery, Error aError)
1220 {
1221 Response response;
1222 Query &mainQuery = FindMainQuery(aQuery);
1223
1224 response.mInstance = &Get<Instance>();
1225 response.mQuery = &mainQuery;
1226
1227 FinalizeQuery(response, aError);
1228 }
1229
FinalizeQuery(Response & aResponse,Error aError)1230 void Client::FinalizeQuery(Response &aResponse, Error aError)
1231 {
1232 QueryType type;
1233 Callback callback;
1234 void *context;
1235
1236 GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1237
1238 switch (type)
1239 {
1240 case kIp6AddressQuery:
1241 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1242 case kIp4AddressQuery:
1243 #endif
1244 if (callback.mAddressCallback != nullptr)
1245 {
1246 callback.mAddressCallback(aError, &aResponse, context);
1247 }
1248 break;
1249
1250 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1251 case kBrowseQuery:
1252 if (callback.mBrowseCallback != nullptr)
1253 {
1254 callback.mBrowseCallback(aError, &aResponse, context);
1255 }
1256 break;
1257
1258 case kServiceQuerySrvTxt:
1259 case kServiceQuerySrv:
1260 case kServiceQueryTxt:
1261 if (callback.mServiceCallback != nullptr)
1262 {
1263 callback.mServiceCallback(aError, &aResponse, context);
1264 }
1265 break;
1266 #endif
1267 case kNoQuery:
1268 break;
1269 }
1270
1271 FreeQuery(*aResponse.mQuery);
1272 }
1273
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1274 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1275 {
1276 QueryInfo info;
1277
1278 info.ReadFrom(aQuery);
1279
1280 aType = info.mQueryType;
1281 aCallback = info.mCallback;
1282 aContext = info.mCallbackContext;
1283 }
1284
FindQueryById(uint16_t aMessageId)1285 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1286 {
1287 Query *matchedQuery = nullptr;
1288 QueryInfo info;
1289
1290 for (Query &mainQuery : mMainQueries)
1291 {
1292 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1293 {
1294 info.ReadFrom(*query);
1295
1296 if (info.mMessageId == aMessageId)
1297 {
1298 matchedQuery = query;
1299 ExitNow();
1300 }
1301 }
1302 }
1303
1304 exit:
1305 return matchedQuery;
1306 }
1307
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMsgInfo)1308 void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMsgInfo)
1309 {
1310 OT_UNUSED_VARIABLE(aMsgInfo);
1311 ProcessResponse(aMessage);
1312 }
1313
void Client::ProcessResponse(const Message &aResponseMessage)
{
    // Handles a DNS response received over UDP or TCP. Messages
    // rejected by `ParseResponse()` (malformed, or not matching a
    // pending query) are silently ignored.
    //
    // On an error response code, the query may be replaced by a
    // fallback query: an IPv4 address query via NAT64, or separate SRV
    // and TXT queries. If no fallback applies, the whole chain is
    // finalized with the error.
    //
    // On success, a host address query may be added to the chain. The
    // chain is finalized once every query in it has a response; until
    // then this response is saved in its query.

    Error responseError;
    Query *query;

    SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));

    if (responseError != kErrorNone)
    {
        // Received an error from server, check if we can replace
        // the query.

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        if (ReplaceWithIp4Query(*query) == kErrorNone)
        {
            ExitNow();
        }
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
        if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
        {
            ExitNow();
        }
#endif

        FinalizeQuery(*query, responseError);
        ExitNow();
    }

    // Received successful response from server.

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    // May append a host address query to the chain, which then delays
    // finalization below.
    ResolveHostAddressIfNeeded(*query, aResponseMessage);
#endif

    if (!CanFinalizeQuery(*query))
    {
        // Other queries in the chain are still pending; keep this
        // response until they complete.
        SaveQueryResponse(*query, aResponseMessage);
        ExitNow();
    }

    PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);

exit:
    return;
}
1360
Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
{
    // Validates `aResponseMessage` and matches it to a pending query.
    //
    // On success (`kErrorNone`), `aQuery` is the query with the same
    // message ID and `aResponseError` holds the DNS response code
    // converted to an `Error`. `aResponseError` is only set on success.
    //
    // Any other return value means the message should be ignored:
    //  - `kErrorDrop`: not a standard, non-truncated response.
    //  - `kErrorNotFound`: no pending query has this message ID.
    //  - `kErrorParse`: unexpected question count.
    //  - Any error from reading or comparing names, or from parsing
    //    records.

    Error error = kErrorNone;
    uint16_t offset = aResponseMessage.GetOffset();
    Header header;
    QueryInfo info;
    Name queryName;

    SuccessOrExit(error = aResponseMessage.Read(offset, header));
    offset += sizeof(Header);

    VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
                     !header.IsTruncationFlagSet(),
                 error = kErrorDrop);

    aQuery = FindQueryById(header.GetMessageId());
    VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);

    info.ReadFrom(*aQuery);

    // The name we asked for is stored in the query after its `QueryInfo`.
    queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);

    // Check the Question Section

    if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
    {
        // Each echoed question must carry our query name.
        // `CompareName()` presumably advances `offset` past the name
        // (TODO confirm). The fixed-size `Question` that follows is
        // then skipped without checking its fields.
        for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
        {
            SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
            offset += sizeof(Question);
        }
    }
    else
    {
        // A different question count is only accepted for an error
        // response that carries no questions at all.
        VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
                     error = kErrorParse);
    }

    // Check the answer, authority and additional record sections

    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
    SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));

    // Read the response code

    aResponseError = Header::ResponseCodeToError(header.GetResponseCode());

    // A successful answer to a combined SRV+TXT (multi-question) query
    // shows that this server handles multi-question queries.
    if ((aResponseError == kErrorNone) && (info.mQueryType == kServiceQuerySrvTxt))
    {
        RecordServerAsCapableOfMultiQuestions(info.mConfig.GetServerSockAddr().GetAddress());
    }

exit:
    return error;
}
1417
CanFinalizeQuery(Query & aQuery)1418 bool Client::CanFinalizeQuery(Query &aQuery)
1419 {
1420 // Determines whether we can finalize a main query by checking if
1421 // we have received and saved responses for all other related
1422 // queries associated with `aQuery`. Note that this method is
1423 // called when we receive a response for `aQuery`, so no need to
1424 // check for a saved response for `aQuery` itself.
1425
1426 bool canFinalize = true;
1427 QueryInfo info;
1428
1429 for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1430 {
1431 info.ReadFrom(*query);
1432
1433 if (query == &aQuery)
1434 {
1435 continue;
1436 }
1437
1438 if (info.mSavedResponse == nullptr)
1439 {
1440 canFinalize = false;
1441 ExitNow();
1442 }
1443 }
1444
1445 exit:
1446 return canFinalize;
1447 }
1448
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1449 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1450 {
1451 QueryInfo info;
1452
1453 info.ReadFrom(aQuery);
1454 VerifyOrExit(info.mSavedResponse == nullptr);
1455
1456 // If `Clone()` fails we let retry or timeout handle the error.
1457 info.mSavedResponse = aResponseMessage.Clone();
1458
1459 UpdateQuery(aQuery, info);
1460
1461 exit:
1462 return;
1463 }
1464
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1465 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1466 {
1467 // Populate `aResponse` for `aQuery`. If there is a saved response
1468 // message for `aQuery` we use it, otherwise, we use
1469 // `aResponseMessage`.
1470
1471 QueryInfo info;
1472
1473 info.ReadFrom(aQuery);
1474
1475 aResponse.mInstance = &Get<Instance>();
1476 aResponse.mQuery = &aQuery;
1477 aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1478
1479 return info.mNextQuery;
1480 }
1481
void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
{
    // This method prepares a list of chained `Response` instances
    // corresponding to all related (chained) queries. It uses
    // recursion to go through the queries and construct the
    // `Response` chain.
    //
    // Each `Response` lives in its own stack frame of this recursion,
    // so the whole chain stays valid while `FinalizeQuery()` runs at
    // the deepest level. The recursion depth equals the number of
    // queries in the chain.
    //
    // The response passed to `FinalizeQuery()` belongs to the last
    // query of the chain. Its `mNext` links lead back to the first
    // query's response, whose `mNext` is `aPrevResponse` (`nullptr`
    // for the top-level call).

    Response response;
    Query *nextQuery;

    nextQuery = PopulateResponse(response, aQuery, aResponseMessage);
    response.mNext = aPrevResponse;

    if (nextQuery != nullptr)
    {
        PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
    }
    else
    {
        FinalizeQuery(response, kErrorNone);
    }
}
1504
void Client::HandleTimer(void)
{
    // Retransmission and timeout handler. It walks every query chain.
    // For each query still waiting for a response (no saved response)
    // whose retransmission time has passed, it does one of:
    //  - Finalize the whole chain with `kErrorResponseTimeout` once
    //    `GetMaxTxAttempts()` transmissions were made.
    //  - Otherwise, with service discovery enabled, first try to split
    //    an "optimize" SRV+TXT query into separate SRV and TXT queries.
    //  - If that split does not apply, re-send the query.
    //
    // The timer is then re-armed for the earliest pending
    // retransmission time. With TCP enabled, the TCP stream is ended
    // when no remaining pending query uses TCP.

    NextFireTime nextTime;
    QueryInfo info;
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    bool hasTcpQuery = false;
#endif

    // NOTE(review): `FinalizeQuery()` below frees and dequeues the
    // current main query while this range-for iterates `mMainQueries`.
    // This assumes the queue iterator tolerates removal of the current
    // entry; confirm.
    for (Query &mainQuery : mMainQueries)
    {
        for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
        {
            info.ReadFrom(*query);

            // This query already has its response; it only waits for the
            // other queries in its chain.
            if (info.mSavedResponse != nullptr)
            {
                continue;
            }

            if (nextTime.GetNow() >= info.mRetransmissionTime)
            {
                if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
                {
                    // This frees the whole chain, so stop walking it.
                    FinalizeQuery(*query, kErrorResponseTimeout);
                    break;
                }

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
                if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
                {
                    LogInfo("Switching to separate SRV/TXT on response timeout");
                    // Re-read the info: `query` is now an SRV query linked
                    // to a new TXT query, with a new retransmission time.
                    info.ReadFrom(*query);
                }
                else
#endif
                {
                    // The timer is not updated here; it is re-armed below
                    // from `nextTime`.
                    IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
                }
            }

            nextTime.UpdateIfEarlier(info.mRetransmissionTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
            if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
            {
                hasTcpQuery = true;
            }
#endif
        }
    }

    mTimer.FireAtIfEarlier(nextTime);

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    // No pending query uses TCP anymore: close our side of the stream.
    if (!hasTcpQuery && mTcpState != kTcpUninitialized)
    {
        IgnoreError(mEndpoint.SendEndOfStream());
    }
#endif
}
1565
1566 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1567
ReplaceWithIp4Query(Query & aQuery)1568 Error Client::ReplaceWithIp4Query(Query &aQuery)
1569 {
1570 Error error = kErrorFailed;
1571 QueryInfo info;
1572
1573 info.ReadFrom(aQuery);
1574
1575 VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1576 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1577
1578 // We send a new query for IPv4 address resolution
1579 // for the same host name. We reuse the existing `aQuery`
1580 // instance and keep all the info but clear `mTransmissionCount`
1581 // and `mMessageId` (so that a new random message ID is
1582 // selected). The new `info` will be saved in the query in
1583 // `SendQuery()`. Note that the current query is still in the
1584 // `mMainQueries` list when `SendQuery()` selects a new random
1585 // message ID, so the existing message ID for this query will
1586 // not be reused.
1587
1588 info.mQueryType = kIp4AddressQuery;
1589 info.mMessageId = 0;
1590 info.mTransmissionCount = 0;
1591
1592 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1593 error = kErrorNone;
1594
1595 exit:
1596 return error;
1597 }
1598
1599 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1600
1601 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1602
CheckAndUpdateServiceMode(QueryConfig & aConfig,const QueryConfig * aRequestConfig) const1603 void Client::CheckAndUpdateServiceMode(QueryConfig &aConfig, const QueryConfig *aRequestConfig) const
1604 {
1605 // If the user explicitly requested "optimize" mode, we honor that
1606 // request. Otherwise, if "optimize" is chosen from the default
1607 // config, we check if the DNS server is known to have trouble
1608 // with multiple-question queries. If so, we switch to "separate"
1609 // mode.
1610
1611 if ((aRequestConfig != nullptr) && (aRequestConfig->GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize))
1612 {
1613 ExitNow();
1614 }
1615
1616 VerifyOrExit(aConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1617
1618 if (mLimitedQueryServers.Contains(aConfig.GetServerSockAddr().GetAddress()))
1619 {
1620 aConfig.SetServiceMode(QueryConfig::kServiceModeSrvTxtSeparate);
1621 }
1622
1623 exit:
1624 return;
1625 }
1626
RecordServerAsLimitedToSingleQuestion(const Ip6::Address & aServerAddress)1627 void Client::RecordServerAsLimitedToSingleQuestion(const Ip6::Address &aServerAddress)
1628 {
1629 VerifyOrExit(!aServerAddress.IsUnspecified());
1630
1631 VerifyOrExit(!mLimitedQueryServers.Contains(aServerAddress));
1632
1633 if (mLimitedQueryServers.IsFull())
1634 {
1635 uint8_t randomIndex = Random::NonCrypto::GetUint8InRange(0, mLimitedQueryServers.GetMaxSize());
1636
1637 mLimitedQueryServers.Remove(mLimitedQueryServers[randomIndex]);
1638 }
1639
1640 IgnoreError(mLimitedQueryServers.PushBack(aServerAddress));
1641
1642 exit:
1643 return;
1644 }
1645
RecordServerAsCapableOfMultiQuestions(const Ip6::Address & aServerAddress)1646 void Client::RecordServerAsCapableOfMultiQuestions(const Ip6::Address &aServerAddress)
1647 {
1648 Ip6::Address *entry = mLimitedQueryServers.Find(aServerAddress);
1649
1650 VerifyOrExit(entry != nullptr);
1651 mLimitedQueryServers.Remove(*entry);
1652
1653 exit:
1654 return;
1655 }
1656
Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
{
    // If `aQuery` is a combined SRV+TXT query using "optimize" service
    // mode, replace it with two separate queries. `aQuery` is reused as
    // the SRV query, and a clone of it becomes a TXT query linked after
    // it.
    //
    // The server is also recorded as limited to single-question
    // queries, so `CheckAndUpdateServiceMode()` can switch later
    // default-"optimize" requests straight to separate mode.
    //
    // Returns `kErrorNone` if the query was replaced. Returns
    // `kErrorFailed` if the query is not eligible or the clone could
    // not be allocated; in the latter case the server has still been
    // recorded as limited.

    Error error = kErrorFailed;
    QueryInfo info;
    Query *secondQuery;

    info.ReadFrom(aQuery);

    VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
    VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);

    RecordServerAsLimitedToSingleQuestion(info.mConfig.GetServerSockAddr().GetAddress());

    // The clone copies the `QueryInfo` header and the encoded name.
    secondQuery = aQuery.Clone();
    VerifyOrExit(secondQuery != nullptr);

    // Second query: TXT, with a fresh message ID and transmission
    // count, linked back to `aQuery` as its main query.
    info.mQueryType = kServiceQueryTxt;
    info.mMessageId = 0;
    info.mTransmissionCount = 0;
    info.mMainQuery = &aQuery;
    IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));

    // First query: re-sent as SRV, with a fresh message ID and
    // transmission count, linked forward to the TXT query.
    //
    // NOTE(review): `info.mMainQuery` still holds `&aQuery` from above,
    // so `aQuery` stores itself as its main query. `FindMainQuery()`
    // still resolves to `aQuery`, but confirm this is intended.
    //
    // NOTE(review): When `SendQuery()` picks the new ID for `aQuery`,
    // the stored link to `secondQuery` is not written yet. That link
    // is only persisted at the end of `SendQuery()`, so
    // `FindQueryById()` presumably cannot reach `secondQuery` and its
    // new ID is not excluded. Confirm whether a collision matters.
    info.mQueryType = kServiceQuerySrv;
    info.mMessageId = 0;
    info.mTransmissionCount = 0;
    info.mNextQuery = secondQuery;
    IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
    error = kErrorNone;

exit:
    return error;
}
1689
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1690 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1691 {
1692 QueryInfo info;
1693 Response response;
1694 ServiceInfo serviceInfo;
1695 char hostName[Name::kMaxNameSize];
1696
1697 info.ReadFrom(aQuery);
1698
1699 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1700 VerifyOrExit(info.mShouldResolveHostAddr);
1701
1702 PopulateResponse(response, aQuery, aResponseMessage);
1703
1704 ClearAllBytes(serviceInfo);
1705 serviceInfo.mHostNameBuffer = hostName;
1706 serviceInfo.mHostNameBufferSize = sizeof(hostName);
1707 SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1708
1709 // Check whether AAAA record for host address is provided in the SRV query response
1710
1711 if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1712 {
1713 Query *newQuery;
1714
1715 info.mQueryType = kIp6AddressQuery;
1716 info.mMessageId = 0;
1717 info.mTransmissionCount = 0;
1718 info.mMainQuery = &FindMainQuery(aQuery);
1719
1720 SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1721 IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1722
1723 // Update `aQuery` to be linked with new query (inserting
1724 // the `newQuery` into the linked-list after `aQuery`).
1725
1726 info.ReadFrom(aQuery);
1727 info.mNextQuery = newQuery;
1728 UpdateQuery(aQuery, info);
1729 }
1730
1731 exit:
1732 return;
1733 }
1734
1735 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1736
1737 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1738 void Client::PrepareTcpMessage(Message &aMessage)
1739 {
1740 uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1741
1742 // Prepending the DNS query with length of the packet according to RFC1035.
1743 BigEndian::WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1744 SuccessOrAssert(
1745 aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1746 mSendLink.mLength += length + sizeof(uint16_t);
1747 }
1748
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1749 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1750 {
1751 OT_UNUSED_VARIABLE(aEndpoint);
1752 OT_UNUSED_VARIABLE(aData);
1753 OT_ASSERT(mTcpState == kTcpConnectedSending);
1754
1755 mSendLink.mLength = 0;
1756 mTcpState = kTcpConnectedIdle;
1757 }
1758
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1759 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1760 {
1761 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1762 }
1763
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1764 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1765 {
1766 OT_UNUSED_VARIABLE(aEndpoint);
1767 IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1768 mTcpState = kTcpConnectedSending;
1769 }
1770
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1771 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1772 {
1773 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1774 }
1775
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1776 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1777 size_t &aOffset,
1778 Message &aMessage,
1779 uint16_t aLength)
1780 {
1781 // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1782 // and copy the content into `aMessage`. As we read we can move
1783 // to the next `aLinkedBuffer` and update `aOffset`.
1784 // Returns:
1785 // - `kErrorNone` if `aLength` bytes are successfully read and
1786 // `aOffset` and `aLinkedBuffer` are updated.
1787 // - `kErrorNotFound` is not enough bytes available to read
1788 // from `aLinkedBuffer`.
1789 // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1790
1791 Error error = kErrorNone;
1792
1793 while (aLength > 0)
1794 {
1795 uint16_t bytesToRead = aLength;
1796
1797 VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1798
1799 if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1800 {
1801 bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1802 }
1803
1804 SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1805
1806 aLength -= bytesToRead;
1807 aOffset += bytesToRead;
1808
1809 if (aOffset == aLinkedBuffer->mLength)
1810 {
1811 aLinkedBuffer = aLinkedBuffer->mNext;
1812 aOffset = 0;
1813 }
1814 }
1815
1816 exit:
1817 return error;
1818 }
1819
void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
                                       size_t aBytesAvailable,
                                       bool aEndOfStream,
                                       size_t aBytesRemaining)
{
    // Called when received TCP data is available. Each DNS message in
    // the stream is a 2-byte big-endian length followed by the message
    // (RFC 1035 section 4.2.2).
    //
    // This extracts as many complete messages as are available, passes
    // each one to `ProcessResponse()`, and commits only the bytes of
    // fully processed messages. A trailing partial message stays in
    // the receive buffer for a later call.
    OT_UNUSED_VARIABLE(aEndpoint);
    OT_UNUSED_VARIABLE(aBytesRemaining);

    Message *message = nullptr;
    size_t totalRead = 0;
    size_t offset = 0;
    const otLinkedBuffer *data;

    if (aEndOfStream)
    {
        // The peer closed its side; close ours as well.
        // Cleanup is done in disconnected callback.
        IgnoreError(mEndpoint.SendEndOfStream());
    }

    SuccessOrExit(mEndpoint.ReceiveByReference(data));
    VerifyOrExit(data != nullptr);

    // Scratch message reused for every framed response below.
    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr);

    while (aBytesAvailable > totalRead)
    {
        uint16_t length;

        // Read the `length` field.
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));

        // The two bytes just appended sit at offset 0. This assumes
        // `message` is empty at this point: it starts new and is reset
        // at the end of each iteration.
        IgnoreError(message->Read(/* aOffset */ 0, length));
        length = BigEndian::HostSwap16(length);

        // Try to read `length` bytes.
        IgnoreError(message->SetLength(0));
        SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));

        totalRead += length + sizeof(uint16_t);

        // Now process the read message as query response.
        ProcessResponse(*message);

        IgnoreError(message->SetLength(0));

        // Loop again to see if we can read another response.
    }

exit:
    // Inform `mEndPoint` about the total read and processed bytes
    IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
    FreeMessage(message);
}
1874
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1875 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1876 size_t aBytesAvailable,
1877 bool aEndOfStream,
1878 size_t aBytesRemaining)
1879 {
1880 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1881 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1882 }
1883
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1884 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1885 {
1886 OT_UNUSED_VARIABLE(aEndpoint);
1887 OT_UNUSED_VARIABLE(aReason);
1888 QueryInfo info;
1889
1890 IgnoreError(mEndpoint.Deinitialize());
1891 mTcpState = kTcpUninitialized;
1892
1893 // Abort queries in case of connection failures
1894 for (Query &mainQuery : mMainQueries)
1895 {
1896 info.ReadFrom(mainQuery);
1897
1898 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1899 {
1900 FinalizeQuery(mainQuery, kErrorAbort);
1901 }
1902 }
1903 }
1904
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1905 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1906 {
1907 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1908 }
1909
1910 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1911
1912 } // namespace Dns
1913 } // namespace ot
1914
1915 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1916