1 /*
2 * Copyright (c) 2017-2021, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "dns_client.hpp"
30
31 #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
32
33 #include "common/array.hpp"
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/instance.hpp"
38 #include "common/locator_getters.hpp"
39 #include "common/log.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/network_data_types.hpp"
42 #include "thread/thread_netif.hpp"
43
44 /**
45 * @file
46 * This file implements the DNS client.
47 */
48
49 namespace ot {
50 namespace Dns {
51
52 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
53 using ot::Encoding::BigEndian::ReadUint16;
54 using ot::Encoding::BigEndian::WriteUint16;
55 #endif
56
57 RegisterLogModule("DnsClient");
58
59 //---------------------------------------------------------------------------------------------------------------------
60 // Client::QueryConfig
61
// Default DNS server IPv6 address in textual form, taken from the build-time
// config macro. The `QueryConfig` constructor parses it (via `FromString()`)
// to seed the default server socket address.
const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;
63
QueryConfig(InitMode aMode)64 Client::QueryConfig::QueryConfig(InitMode aMode)
65 {
66 OT_UNUSED_VARIABLE(aMode);
67
68 IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
69 GetServerSockAddr().SetPort(kDefaultServerPort);
70 SetResponseTimeout(kDefaultResponseTimeout);
71 SetMaxTxAttempts(kDefaultMaxTxAttempts);
72 SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
73 SetServiceMode(kDefaultServiceMode);
74 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
75 SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
76 #endif
77 SetTransportProto(kDnsTransportUdp);
78 }
79
SetFrom(const QueryConfig * aConfig,const QueryConfig & aDefaultConfig)80 void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
81 {
82 // This method sets the config from `aConfig` replacing any
83 // unspecified fields (value zero) with the fields from
84 // `aDefaultConfig`. If `aConfig` is `nullptr` then
85 // `aDefaultConfig` is used.
86
87 if (aConfig == nullptr)
88 {
89 *this = aDefaultConfig;
90 ExitNow();
91 }
92
93 *this = *aConfig;
94
95 if (GetServerSockAddr().GetAddress().IsUnspecified())
96 {
97 GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
98 }
99
100 if (GetServerSockAddr().GetPort() == 0)
101 {
102 GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
103 }
104
105 if (GetResponseTimeout() == 0)
106 {
107 SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
108 }
109
110 if (GetMaxTxAttempts() == 0)
111 {
112 SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
113 }
114
115 if (GetRecursionFlag() == kFlagUnspecified)
116 {
117 SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
118 }
119
120 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
121 if (GetNat64Mode() == kNat64Unspecified)
122 {
123 SetNat64Mode(aDefaultConfig.GetNat64Mode());
124 }
125 #endif
126
127 if (GetServiceMode() == kServiceModeUnspecified)
128 {
129 SetServiceMode(aDefaultConfig.GetServiceMode());
130 }
131
132 if (GetTransportProto() == kDnsTransportUnspecified)
133 {
134 SetTransportProto(aDefaultConfig.GetTransportProto());
135 }
136
137 exit:
138 return;
139 }
140
141 //---------------------------------------------------------------------------------------------------------------------
142 // Client::Response
143
SelectSection(Section aSection,uint16_t & aOffset,uint16_t & aNumRecord) const144 void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
145 {
146 switch (aSection)
147 {
148 case kAnswerSection:
149 aOffset = mAnswerOffset;
150 aNumRecord = mAnswerRecordCount;
151 break;
152 case kAdditionalDataSection:
153 default:
154 aOffset = mAdditionalOffset;
155 aNumRecord = mAdditionalRecordCount;
156 break;
157 }
158 }
159
GetName(char * aNameBuffer,uint16_t aNameBufferSize) const160 Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
161 {
162 uint16_t offset = kNameOffsetInQuery;
163
164 return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
165 }
166
CheckForHostNameAlias(Section aSection,Name & aHostName) const167 Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
168 {
169 // If the response includes a CNAME record mapping the query host
170 // name to a canonical name, we update `aHostName` to the new alias
171 // name. Otherwise `aHostName` remains as before. This method handles
172 // when there are multiple CNAME records mapping the host name multiple
173 // times. We limit number of changes to `kMaxCnameAliasNameChanges`
174 // to detect and handle if the response contains CNAME record loops.
175
176 Error error;
177 uint16_t offset;
178 uint16_t numRecords;
179 CnameRecord cnameRecord;
180
181 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
182
183 for (uint16_t counter = 0; counter < kMaxCnameAliasNameChanges; counter++)
184 {
185 SelectSection(aSection, offset, numRecords);
186 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);
187
188 if (error == kErrorNotFound)
189 {
190 error = kErrorNone;
191 ExitNow();
192 }
193
194 SuccessOrExit(error);
195
196 // A CNAME record was found. `offset` now points to after the
197 // last read byte within the `mMessage` into the `cnameRecord`
198 // (which is the start of the new canonical name).
199
200 aHostName.SetFromMessage(*mMessage, offset);
201 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
202
203 // Loop back to check if there may be a CNAME record for the
204 // new `aHostName`.
205 }
206
207 error = kErrorParse;
208
209 exit:
210 return error;
211 }
212
FindHostAddress(Section aSection,const Name & aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const213 Error Client::Response::FindHostAddress(Section aSection,
214 const Name &aHostName,
215 uint16_t aIndex,
216 Ip6::Address &aAddress,
217 uint32_t &aTtl) const
218 {
219 Error error;
220 uint16_t offset;
221 uint16_t numRecords;
222 Name name = aHostName;
223 AaaaRecord aaaaRecord;
224
225 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
226
227 SelectSection(aSection, offset, numRecords);
228 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
229 aAddress = aaaaRecord.GetAddress();
230 aTtl = aaaaRecord.GetTtl();
231
232 exit:
233 return error;
234 }
235
236 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
237
FindARecord(Section aSection,const Name & aHostName,uint16_t aIndex,ARecord & aARecord) const238 Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
239 {
240 Error error;
241 uint16_t offset;
242 uint16_t numRecords;
243 Name name = aHostName;
244
245 SuccessOrExit(error = CheckForHostNameAlias(aSection, name));
246
247 SelectSection(aSection, offset, numRecords);
248 error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);
249
250 exit:
251 return error;
252 }
253
254 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
255
256 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
257
InitServiceInfo(ServiceInfo & aServiceInfo) const258 void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
259 {
260 // This method initializes `aServiceInfo` setting all
261 // TTLs to zero and host name to empty string.
262
263 aServiceInfo.mTtl = 0;
264 aServiceInfo.mHostAddressTtl = 0;
265 aServiceInfo.mTxtDataTtl = 0;
266 aServiceInfo.mTxtDataTruncated = false;
267
268 AsCoreType(&aServiceInfo.mHostAddress).Clear();
269
270 if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
271 {
272 aServiceInfo.mHostNameBuffer[0] = '\0';
273 }
274 }
275
ReadServiceInfo(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const276 Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
277 {
278 // This method searches for SRV record in the given `aSection`
279 // matching the record name against `aName`, and updates the
280 // `aServiceInfo` accordingly. It also searches for AAAA record
281 // for host name associated with the service (from SRV record).
282 // The search for AAAA record is always performed in Additional
283 // Data section (independent of the value given in `aSection`).
284
285 Error error = kErrorNone;
286 uint16_t offset;
287 uint16_t numRecords;
288 Name hostName;
289 SrvRecord srvRecord;
290
291 // A non-zero `mTtl` indicates that SRV record is already found
292 // and parsed from a previous response.
293 VerifyOrExit(aServiceInfo.mTtl == 0);
294
295 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
296
297 // Search for a matching SRV record
298 SelectSection(aSection, offset, numRecords);
299 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));
300
301 aServiceInfo.mTtl = srvRecord.GetTtl();
302 aServiceInfo.mPort = srvRecord.GetPort();
303 aServiceInfo.mPriority = srvRecord.GetPriority();
304 aServiceInfo.mWeight = srvRecord.GetWeight();
305
306 hostName.SetFromMessage(*mMessage, offset);
307
308 if (aServiceInfo.mHostNameBuffer != nullptr)
309 {
310 SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
311 aServiceInfo.mHostNameBufferSize));
312 }
313 else
314 {
315 SuccessOrExit(error = Name::ParseName(*mMessage, offset));
316 }
317
318 // Search in additional section for AAAA record for the host name.
319
320 VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());
321
322 error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
323 aServiceInfo.mHostAddressTtl);
324
325 if (error == kErrorNotFound)
326 {
327 error = kErrorNone;
328 }
329
330 exit:
331 return error;
332 }
333
ReadTxtRecord(Section aSection,const Name & aName,ServiceInfo & aServiceInfo) const334 Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
335 {
336 // This method searches a TXT record in the given `aSection`
337 // matching the record name against `aName` and updates the TXT
338 // related properties in `aServicesInfo`.
339 //
340 // If no match is found `mTxtDataTtl` (which is initialized to zero)
341 // remains unchanged to indicate this. In this case this method still
342 // returns `kErrorNone`.
343
344 Error error = kErrorNone;
345 uint16_t offset;
346 uint16_t numRecords;
347 TxtRecord txtRecord;
348
349 // A non-zero `mTxtDataTtl` indicates that TXT record is already
350 // found and parsed from a previous response.
351 VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);
352
353 // A null `mTxtData` indicates that caller does not want to retrieve
354 // TXT data.
355 VerifyOrExit(aServiceInfo.mTxtData != nullptr);
356
357 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
358
359 SelectSection(aSection, offset, numRecords);
360
361 aServiceInfo.mTxtDataTruncated = false;
362
363 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));
364
365 error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);
366
367 if (error == kErrorNoBufs)
368 {
369 error = kErrorNone;
370
371 // Mark `mTxtDataTruncated` to indicate that we could not read
372 // the full TXT record into the given `mTxtData` buffer.
373 aServiceInfo.mTxtDataTruncated = true;
374 }
375
376 SuccessOrExit(error);
377 aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();
378
379 exit:
380 if (error == kErrorNotFound)
381 {
382 error = kErrorNone;
383 }
384
385 return error;
386 }
387
388 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
389
PopulateFrom(const Message & aMessage)390 void Client::Response::PopulateFrom(const Message &aMessage)
391 {
392 // Populate `Response` with info from `aMessage`.
393
394 uint16_t offset = aMessage.GetOffset();
395 Header header;
396
397 mMessage = &aMessage;
398
399 IgnoreError(aMessage.Read(offset, header));
400 offset += sizeof(Header);
401
402 for (uint16_t num = 0; num < header.GetQuestionCount(); num++)
403 {
404 IgnoreError(Name::ParseName(aMessage, offset));
405 offset += sizeof(Question);
406 }
407
408 mAnswerOffset = offset;
409 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount()));
410 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount()));
411 mAdditionalOffset = offset;
412 IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount()));
413
414 mAnswerRecordCount = header.GetAnswerCount();
415 mAdditionalRecordCount = header.GetAdditionalRecordCount();
416 }
417
418 //---------------------------------------------------------------------------------------------------------------------
419 // Client::AddressResponse
420
GetAddress(uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const421 Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const
422 {
423 Error error = kErrorNone;
424 Name name(*mQuery, kNameOffsetInQuery);
425
426 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
427
428 // If the response is for an IPv4 address query or if it is an
429 // IPv6 address query response with no IPv6 address but with
430 // an IPv4 in its additional section, we read the IPv4 address
431 // and translate it to an IPv6 address.
432
433 QueryInfo info;
434
435 info.ReadFrom(*mQuery);
436
437 if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64)
438 {
439 Section section;
440 ARecord aRecord;
441 NetworkData::ExternalRouteConfig nat64Prefix;
442
443 VerifyOrExit(mInstance->Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
444 error = kErrorInvalidState);
445
446 section = (info.mQueryType == kIp4AddressQuery) ? kAnswerSection : kAdditionalDataSection;
447 SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord));
448
449 aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress());
450 aTtl = aRecord.GetTtl();
451
452 ExitNow();
453 }
454
455 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
456
457 ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl));
458
459 exit:
460 return error;
461 }
462
463 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
464
465 //---------------------------------------------------------------------------------------------------------------------
466 // Client::BrowseResponse
467
GetServiceInstance(uint16_t aIndex,char * aLabelBuffer,uint8_t aLabelBufferSize) const468 Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const
469 {
470 Error error;
471 uint16_t offset;
472 uint16_t numRecords;
473 Name serviceName(*mQuery, kNameOffsetInQuery);
474 PtrRecord ptrRecord;
475
476 VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);
477
478 SelectSection(kAnswerSection, offset, numRecords);
479 SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord));
480 error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0);
481
482 exit:
483 return error;
484 }
485
GetServiceInfo(const char * aInstanceLabel,ServiceInfo & aServiceInfo) const486 Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const
487 {
488 Error error;
489 Name instanceName;
490
491 // Find a matching PTR record for the service instance label. Then
492 // search and read SRV, TXT and AAAA records in Additional Data
493 // section matching the same name to populate `aServiceInfo`.
494
495 SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));
496
497 InitServiceInfo(aServiceInfo);
498 SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
499 SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));
500
501 if (aServiceInfo.mTxtDataTtl == 0)
502 {
503 aServiceInfo.mTxtDataSize = 0;
504 }
505
506 exit:
507 return error;
508 }
509
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const510 Error Client::BrowseResponse::GetHostAddress(const char *aHostName,
511 uint16_t aIndex,
512 Ip6::Address &aAddress,
513 uint32_t &aTtl) const
514 {
515 return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
516 }
517
Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // named after the queried service name whose target's first label
    // matches `aInstanceLabel`. If found, `aInstanceName` is updated to
    // refer to the PTR target name within the message and `kErrorNone`
    // is returned.
    //
    // Returns `kErrorNotFound` if there is no response message or no
    // matching PTR record. Other parse/compare errors are returned as-is.
    //
    // NOTE(review): any non-`kErrorNone` result from `CompareName()`
    // (presumably including a simple name mismatch) ends the search
    // rather than skipping that record — confirm this is intended.

    Error error;
    uint16_t offset;
    Name serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t numRecords;
    uint16_t labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        // Advances `offset` past the record name.
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label. A scratch offset is used so that `offset`
        // keeps pointing at the start of the PTR target name.

        labelOffset = offset;
        error = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        // Anything other than "label does not match" is a real error.
        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record. `offset` is just past
        // the fixed `ptrRecord` part, so the remaining bytes are
        // `GetSize() - sizeof(ptrRecord)` (presumably the PTR target name).
        offset += static_cast<uint16_t>(ptrRecord.GetSize()) - sizeof(ptrRecord);
    }

    error = kErrorNotFound;

exit:
    return error;
}
573
574 //---------------------------------------------------------------------------------------------------------------------
575 // Client::ServiceResponse
576
GetServiceName(char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const577 Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer,
578 uint8_t aLabelBufferSize,
579 char *aNameBuffer,
580 uint16_t aNameBufferSize) const
581 {
582 Error error;
583 uint16_t offset = kNameOffsetInQuery;
584
585 SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize));
586
587 VerifyOrExit(aNameBuffer != nullptr);
588 SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize));
589
590 exit:
591 return error;
592 }
593
Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const
{
    // Search and read SRV, TXT records matching name from query.
    //
    // Walks this response and every chained response (`mNext`), and
    // accumulates results into `aServiceInfo`:
    // - Address-query responses: fill the host address from the Answer
    //   section. Failures there are ignored.
    // - Service-query responses: read SRV and TXT. The `Read*()` helpers
    //   skip fields an earlier response already filled (non-zero TTLs).
    // - Any other query type is skipped.
    //
    // Returns `kErrorNotFound` if the chain holds no service-query
    // response. Returns the first SRV/TXT read failure, except that a
    // missing SRV in the Additional Data section is tolerated. When the
    // loop completes, `mTxtDataSize` is reported as zero if no TXT
    // record was captured.

    Error error = kErrorNotFound;

    InitServiceInfo(aServiceInfo);

    for (const Response *response = this; response != nullptr; response = response->mNext)
    {
        Name name(*response->mQuery, kNameOffsetInQuery);
        QueryInfo info;
        Section srvSection;
        Section txtSection;

        info.ReadFrom(*response->mQuery);

        switch (info.mQueryType)
        {
        case kIp6AddressQuery:
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
        case kIp4AddressQuery:
#endif
            // Host address from a separate address query (best effort).
            IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0,
                                                  AsCoreType(&aServiceInfo.mHostAddress),
                                                  aServiceInfo.mHostAddressTtl));

            continue; // to `for()` loop

        case kServiceQuerySrvTxt:
        case kServiceQuerySrv:
        case kServiceQueryTxt:
            break;

        default:
            continue;
        }

        // Determine from which section we should try to read the SRV and
        // TXT records based on the query type.
        //
        // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see
        // only one record (SRV or TXT) in the answer section, but we
        // still try to read the other records from additional data
        // section in case server provided them.

        srvSection = (info.mQueryType != kServiceQueryTxt) ? kAnswerSection : kAdditionalDataSection;
        txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection;

        error = response->ReadServiceInfo(srvSection, name, aServiceInfo);

        // SRV data is optional when only looked up opportunistically in
        // the Additional Data section (TXT-only query).
        if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound))
        {
            error = kErrorNone;
        }

        SuccessOrExit(error);

        SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo));
    }

    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}
663
GetHostAddress(const char * aHostName,uint16_t aIndex,Ip6::Address & aAddress,uint32_t & aTtl) const664 Error Client::ServiceResponse::GetHostAddress(const char *aHostName,
665 uint16_t aIndex,
666 Ip6::Address &aAddress,
667 uint32_t &aTtl) const
668 {
669 Error error = kErrorNotFound;
670
671 for (const Response *response = this; response != nullptr; response = response->mNext)
672 {
673 Section section = kAdditionalDataSection;
674 QueryInfo info;
675
676 info.ReadFrom(*response->mQuery);
677
678 switch (info.mQueryType)
679 {
680 case kIp6AddressQuery:
681 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
682 case kIp4AddressQuery:
683 #endif
684 section = kAnswerSection;
685 break;
686
687 default:
688 break;
689 }
690
691 error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl);
692
693 if (error == kErrorNone)
694 {
695 break;
696 }
697 }
698
699 return error;
700 }
701
702 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
703
704 //---------------------------------------------------------------------------------------------------------------------
705 // Client
706
// Record types asked for by each query kind. Each array lists the question
// record types in the order the questions are placed in the query.
const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa};
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA};
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
const uint16_t Client::kBrowseQueryRecordTypes[] = {ResourceRecord::kTypePtr};
const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt};
#endif

// Number of questions per query kind, laid out in `QueryType` order (the
// enumerator values are pinned by `static_assert`s in the constructor).
const uint8_t Client::kQuestionCount[] = {
    /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes),         // PTR record
    /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records
    /* kServiceQuerySrv -> */ 1,                                           // SRV record only
    /* kServiceQueryTxt -> */ 1,                                           // TXT record only
#endif
};

// Pointer to the first question record type per query kind, in the same
// `QueryType` order. The SRV-only and TXT-only kinds point at single elements
// of `kServiceQueryRecordTypes`, paired with a count of 1 above.
const uint16_t *const Client::kQuestionRecordTypes[] = {
    /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes,
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes,
#endif
#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
    /* kBrowseQuery -> */ kBrowseQueryRecordTypes,
    /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes,
    /* kServiceQuerySrv -> */ &kServiceQueryRecordTypes[0],
    /* kServiceQueryTxt -> */ &kServiceQueryRecordTypes[1],

#endif
};
742
Client(Instance & aInstance)743 Client::Client(Instance &aInstance)
744 : InstanceLocator(aInstance)
745 , mSocket(aInstance)
746 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
747 , mTcpState(kTcpUninitialized)
748 #endif
749 , mTimer(aInstance)
750 , mDefaultConfig(QueryConfig::kInitFromDefaults)
751 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
752 , mUserDidSetDefaultAddress(false)
753 #endif
754 {
755 static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct");
756 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
757 static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct");
758 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
759 static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct");
760 static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct");
761 static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct");
762 static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct");
763 #endif
764 #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
765 static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct");
766 static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct");
767 static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct");
768 static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct");
769 #endif
770 }
771
Start(void)772 Error Client::Start(void)
773 {
774 Error error;
775
776 SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
777 SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
778
779 exit:
780 return error;
781 }
782
Stop(void)783 void Client::Stop(void)
784 {
785 Query *query;
786
787 while ((query = mMainQueries.GetHead()) != nullptr)
788 {
789 FinalizeQuery(*query, kErrorAbort);
790 }
791
792 IgnoreError(mSocket.Close());
793 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
794 if (mTcpState != kTcpUninitialized)
795 {
796 IgnoreError(mEndpoint.Deinitialize());
797 }
798 #endif
799 }
800
801 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
InitTcpSocket(void)802 Error Client::InitTcpSocket(void)
803 {
804 Error error;
805 otTcpEndpointInitializeArgs endpointArgs;
806
807 memset(&endpointArgs, 0x00, sizeof(endpointArgs));
808 endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback;
809 endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback;
810 endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback;
811 endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback;
812 endpointArgs.mContext = this;
813 endpointArgs.mReceiveBuffer = mReceiveBufferBytes;
814 endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes);
815
816 mSendLink.mNext = nullptr;
817 mSendLink.mData = mSendBufferBytes;
818 mSendLink.mLength = 0;
819
820 SuccessOrExit(error = mEndpoint.Initialize(Get<Instance>(), endpointArgs));
821 exit:
822 return error;
823 }
824 #endif
825
SetDefaultConfig(const QueryConfig & aQueryConfig)826 void Client::SetDefaultConfig(const QueryConfig &aQueryConfig)
827 {
828 QueryConfig startingDefault(QueryConfig::kInitFromDefaults);
829
830 mDefaultConfig.SetFrom(&aQueryConfig, startingDefault);
831
832 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
833 mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified();
834 UpdateDefaultConfigAddress();
835 #endif
836 }
837
ResetDefaultConfig(void)838 void Client::ResetDefaultConfig(void)
839 {
840 mDefaultConfig = QueryConfig(QueryConfig::kInitFromDefaults);
841
842 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
843 mUserDidSetDefaultAddress = false;
844 UpdateDefaultConfigAddress();
845 #endif
846 }
847
848 #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE
UpdateDefaultConfigAddress(void)849 void Client::UpdateDefaultConfigAddress(void)
850 {
851 const Ip6::Address &srpServerAddr = Get<Srp::Client>().GetServerAddress().GetAddress();
852
853 if (!mUserDidSetDefaultAddress && Get<Srp::Client>().IsServerSelectedByAutoStart() &&
854 !srpServerAddr.IsUnspecified())
855 {
856 mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr);
857 }
858 }
859 #endif
860
ResolveAddress(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)861 Error Client::ResolveAddress(const char *aHostName,
862 AddressCallback aCallback,
863 void *aContext,
864 const QueryConfig *aConfig)
865 {
866 QueryInfo info;
867
868 info.Clear();
869 info.mQueryType = kIp6AddressQuery;
870 info.mConfig.SetFrom(aConfig, mDefaultConfig);
871 info.mCallback.mAddressCallback = aCallback;
872 info.mCallbackContext = aContext;
873
874 return StartQuery(info, nullptr, aHostName);
875 }
876
877 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
ResolveIp4Address(const char * aHostName,AddressCallback aCallback,void * aContext,const QueryConfig * aConfig)878 Error Client::ResolveIp4Address(const char *aHostName,
879 AddressCallback aCallback,
880 void *aContext,
881 const QueryConfig *aConfig)
882 {
883 QueryInfo info;
884
885 info.Clear();
886 info.mQueryType = kIp4AddressQuery;
887 info.mConfig.SetFrom(aConfig, mDefaultConfig);
888 info.mCallback.mAddressCallback = aCallback;
889 info.mCallbackContext = aContext;
890
891 return StartQuery(info, nullptr, aHostName);
892 }
893 #endif
894
895 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
896
Browse(const char * aServiceName,BrowseCallback aCallback,void * aContext,const QueryConfig * aConfig)897 Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig)
898 {
899 QueryInfo info;
900
901 info.Clear();
902 info.mQueryType = kBrowseQuery;
903 info.mConfig.SetFrom(aConfig, mDefaultConfig);
904 info.mCallback.mBrowseCallback = aCallback;
905 info.mCallbackContext = aContext;
906
907 return StartQuery(info, nullptr, aServiceName);
908 }
909
ResolveService(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)910 Error Client::ResolveService(const char *aInstanceLabel,
911 const char *aServiceName,
912 ServiceCallback aCallback,
913 void *aContext,
914 const QueryConfig *aConfig)
915 {
916 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, false);
917 }
918
ResolveServiceAndHostAddress(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig)919 Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel,
920 const char *aServiceName,
921 ServiceCallback aCallback,
922 void *aContext,
923 const QueryConfig *aConfig)
924 {
925 return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true);
926 }
927
Resolve(const char * aInstanceLabel,const char * aServiceName,ServiceCallback aCallback,void * aContext,const QueryConfig * aConfig,bool aShouldResolveHostAddr)928 Error Client::Resolve(const char *aInstanceLabel,
929 const char *aServiceName,
930 ServiceCallback aCallback,
931 void *aContext,
932 const QueryConfig *aConfig,
933 bool aShouldResolveHostAddr)
934 {
935 QueryInfo info;
936 Error error;
937 QueryType secondQueryType = kNoQuery;
938
939 VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs);
940
941 info.Clear();
942
943 info.mConfig.SetFrom(aConfig, mDefaultConfig);
944 info.mShouldResolveHostAddr = aShouldResolveHostAddr;
945
946 switch (info.mConfig.GetServiceMode())
947 {
948 case QueryConfig::kServiceModeSrvTxtSeparate:
949 secondQueryType = kServiceQueryTxt;
950
951 OT_FALL_THROUGH;
952
953 case QueryConfig::kServiceModeSrv:
954 info.mQueryType = kServiceQuerySrv;
955 break;
956
957 case QueryConfig::kServiceModeTxt:
958 info.mQueryType = kServiceQueryTxt;
959 VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs);
960 break;
961
962 case QueryConfig::kServiceModeSrvTxt:
963 case QueryConfig::kServiceModeSrvTxtOptimize:
964 default:
965 info.mQueryType = kServiceQuerySrvTxt;
966 break;
967 }
968
969 info.mCallback.mServiceCallback = aCallback;
970 info.mCallbackContext = aContext;
971
972 error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType);
973
974 exit:
975 return error;
976 }
977
978 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
979
StartQuery(QueryInfo & aInfo,const char * aLabel,const char * aName,QueryType aSecondType)980 Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType)
981 {
982 // The `aLabel` can be `nullptr` and then `aName` provides the
983 // full name, otherwise the name is appended as `{aLabel}.
984 // {aName}`.
985
986 Error error;
987 Query *query;
988
989 VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState);
990
991 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
992 if (aInfo.mQueryType == kIp4AddressQuery)
993 {
994 NetworkData::ExternalRouteConfig nat64Prefix;
995
996 VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs);
997 VerifyOrExit(Get<NetworkData::Leader>().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone,
998 error = kErrorInvalidState);
999 }
1000 #endif
1001
1002 SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query));
1003
1004 mMainQueries.Enqueue(*query);
1005
1006 error = SendQuery(*query, aInfo, /* aUpdateTimer */ true);
1007 VerifyOrExit(error == kErrorNone, FreeQuery(*query));
1008
1009 if (aSecondType != kNoQuery)
1010 {
1011 Query *secondQuery;
1012
1013 aInfo.mQueryType = aSecondType;
1014 aInfo.mMessageId = 0;
1015 aInfo.mTransmissionCount = 0;
1016 aInfo.mMainQuery = query;
1017
1018 // We intentionally do not use `error` here so in the unlikely
1019 // case where we cannot allocate the second query we can proceed
1020 // with the first one.
1021 SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery));
1022
1023 IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true));
1024
1025 // Update first query to link to second one by updating
1026 // its `mNextQuery`.
1027 aInfo.ReadFrom(*query);
1028 aInfo.mNextQuery = secondQuery;
1029 UpdateQuery(*query, aInfo);
1030 }
1031
1032 exit:
1033 return error;
1034 }
1035
AllocateQuery(const QueryInfo & aInfo,const char * aLabel,const char * aName,Query * & aQuery)1036 Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery)
1037 {
1038 Error error = kErrorNone;
1039
1040 aQuery = nullptr;
1041
1042 VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs);
1043
1044 aQuery = Get<MessagePool>().Allocate(Message::kTypeOther);
1045 VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs);
1046
1047 SuccessOrExit(error = aQuery->Append(aInfo));
1048
1049 if (aLabel != nullptr)
1050 {
1051 SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery));
1052 }
1053
1054 SuccessOrExit(error = Name::AppendName(aName, *aQuery));
1055
1056 exit:
1057 FreeAndNullMessageOnError(aQuery, error);
1058 return error;
1059 }
1060
FindMainQuery(Query & aQuery)1061 Client::Query &Client::FindMainQuery(Query &aQuery)
1062 {
1063 QueryInfo info;
1064
1065 info.ReadFrom(aQuery);
1066
1067 return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery;
1068 }
1069
FreeQuery(Query & aQuery)1070 void Client::FreeQuery(Query &aQuery)
1071 {
1072 Query &mainQuery = FindMainQuery(aQuery);
1073 QueryInfo info;
1074
1075 mMainQueries.Dequeue(mainQuery);
1076
1077 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1078 {
1079 info.ReadFrom(*query);
1080 FreeMessage(info.mSavedResponse);
1081 query->Free();
1082 }
1083 }
1084
SendQuery(Query & aQuery,QueryInfo & aInfo,bool aUpdateTimer)1085 Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer)
1086 {
1087 // This method prepares and sends a query message represented by
1088 // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets
1089 // the new `mRetransmissionTime`) and updates it in `aQuery` as
1090 // well. `aUpdateTimer` indicates whether the timer should be
1091 // updated when query is sent or not (used in the case where timer
1092 // is handled by caller).
1093
1094 Error error = kErrorNone;
1095 Message *message = nullptr;
1096 Header header;
1097 Ip6::MessageInfo messageInfo;
1098 uint16_t length = 0;
1099
1100 aInfo.mTransmissionCount++;
1101 aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();
1102
1103 if (aInfo.mMessageId == 0)
1104 {
1105 do
1106 {
1107 SuccessOrExit(error = header.SetRandomMessageId());
1108 } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));
1109
1110 aInfo.mMessageId = header.GetMessageId();
1111 }
1112 else
1113 {
1114 header.SetMessageId(aInfo.mMessageId);
1115 }
1116
1117 header.SetType(Header::kTypeQuery);
1118 header.SetQueryType(Header::kQueryTypeStandard);
1119
1120 if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
1121 {
1122 header.SetRecursionDesiredFlag();
1123 }
1124
1125 header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);
1126
1127 message = mSocket.NewMessage();
1128 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
1129
1130 SuccessOrExit(error = message->Append(header));
1131
1132 // Prepare the question section.
1133
1134 for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
1135 {
1136 SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
1137 SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
1138 }
1139
1140 length = message->GetLength() - message->GetOffset();
1141
1142 if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1143 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1144 {
1145 // Check if query will fit into tcp buffer if not return error.
1146 VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
1147 OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
1148 error = kErrorNoBufs);
1149
1150 // In case of initialized connection check if connected peer and new query have the same address.
1151 if (mTcpState != kTcpUninitialized)
1152 {
1153 VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr),
1154 error = kErrorFailed);
1155 }
1156
1157 switch (mTcpState)
1158 {
1159 case kTcpUninitialized:
1160 SuccessOrExit(error = InitTcpSocket());
1161 SuccessOrExit(
1162 error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN));
1163 mTcpState = kTcpConnecting;
1164 PrepareTcpMessage(*message);
1165 break;
1166 case kTcpConnectedIdle:
1167 PrepareTcpMessage(*message);
1168 SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1169 mTcpState = kTcpConnectedSending;
1170 break;
1171 case kTcpConnecting:
1172 PrepareTcpMessage(*message);
1173 break;
1174 case kTcpConnectedSending:
1175 WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1176 SuccessOrAssert(error = message->Read(message->GetOffset(),
1177 (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1178 IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0));
1179 break;
1180 }
1181 message->Free();
1182 message = nullptr;
1183 }
1184 #else
1185 {
1186 error = kErrorInvalidArgs;
1187 LogWarn("DNS query over TCP not supported.");
1188 ExitNow();
1189 }
1190 #endif
1191 else
1192 {
1193 VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs);
1194 messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress());
1195 messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort());
1196 SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
1197 }
1198
1199 exit:
1200
1201 FreeMessageOnError(message, error);
1202 if (aUpdateTimer)
1203 {
1204 mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime);
1205 }
1206
1207 UpdateQuery(aQuery, aInfo);
1208
1209 return error;
1210 }
1211
AppendNameFromQuery(const Query & aQuery,Message & aMessage)1212 Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage)
1213 {
1214 // The name is encoded and included after the `Info` in `aQuery`
1215 // starting at `kNameOffsetInQuery`.
1216
1217 return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery);
1218 }
1219
FinalizeQuery(Query & aQuery,Error aError)1220 void Client::FinalizeQuery(Query &aQuery, Error aError)
1221 {
1222 Response response;
1223 Query &mainQuery = FindMainQuery(aQuery);
1224
1225 response.mInstance = &Get<Instance>();
1226 response.mQuery = &mainQuery;
1227
1228 FinalizeQuery(response, aError);
1229 }
1230
FinalizeQuery(Response & aResponse,Error aError)1231 void Client::FinalizeQuery(Response &aResponse, Error aError)
1232 {
1233 QueryType type;
1234 Callback callback;
1235 void *context;
1236
1237 GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context);
1238
1239 switch (type)
1240 {
1241 case kIp6AddressQuery:
1242 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1243 case kIp4AddressQuery:
1244 #endif
1245 if (callback.mAddressCallback != nullptr)
1246 {
1247 callback.mAddressCallback(aError, &aResponse, context);
1248 }
1249 break;
1250
1251 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1252 case kBrowseQuery:
1253 if (callback.mBrowseCallback != nullptr)
1254 {
1255 callback.mBrowseCallback(aError, &aResponse, context);
1256 }
1257 break;
1258
1259 case kServiceQuerySrvTxt:
1260 case kServiceQuerySrv:
1261 case kServiceQueryTxt:
1262 if (callback.mServiceCallback != nullptr)
1263 {
1264 callback.mServiceCallback(aError, &aResponse, context);
1265 }
1266 break;
1267 #endif
1268 case kNoQuery:
1269 break;
1270 }
1271
1272 FreeQuery(*aResponse.mQuery);
1273 }
1274
GetQueryTypeAndCallback(const Query & aQuery,QueryType & aType,Callback & aCallback,void * & aContext)1275 void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext)
1276 {
1277 QueryInfo info;
1278
1279 info.ReadFrom(aQuery);
1280
1281 aType = info.mQueryType;
1282 aCallback = info.mCallback;
1283 aContext = info.mCallbackContext;
1284 }
1285
FindQueryById(uint16_t aMessageId)1286 Client::Query *Client::FindQueryById(uint16_t aMessageId)
1287 {
1288 Query *matchedQuery = nullptr;
1289 QueryInfo info;
1290
1291 for (Query &mainQuery : mMainQueries)
1292 {
1293 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1294 {
1295 info.ReadFrom(*query);
1296
1297 if (info.mMessageId == aMessageId)
1298 {
1299 matchedQuery = query;
1300 ExitNow();
1301 }
1302 }
1303 }
1304
1305 exit:
1306 return matchedQuery;
1307 }
1308
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMsgInfo)1309 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo)
1310 {
1311 OT_UNUSED_VARIABLE(aMsgInfo);
1312
1313 static_cast<Client *>(aContext)->ProcessResponse(AsCoreType(aMessage));
1314 }
1315
ProcessResponse(const Message & aResponseMessage)1316 void Client::ProcessResponse(const Message &aResponseMessage)
1317 {
1318 Error responseError;
1319 Query *query;
1320
1321 SuccessOrExit(ParseResponse(aResponseMessage, query, responseError));
1322
1323 if (responseError != kErrorNone)
1324 {
1325 // Received an error from server, check if we can replace
1326 // the query.
1327
1328 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1329 if (ReplaceWithIp4Query(*query) == kErrorNone)
1330 {
1331 ExitNow();
1332 }
1333 #endif
1334 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1335 if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone)
1336 {
1337 ExitNow();
1338 }
1339 #endif
1340
1341 FinalizeQuery(*query, responseError);
1342 ExitNow();
1343 }
1344
1345 // Received successful response from server.
1346
1347 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1348 ResolveHostAddressIfNeeded(*query, aResponseMessage);
1349 #endif
1350
1351 if (!CanFinalizeQuery(*query))
1352 {
1353 SaveQueryResponse(*query, aResponseMessage);
1354 ExitNow();
1355 }
1356
1357 PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr);
1358
1359 exit:
1360 return;
1361 }
1362
ParseResponse(const Message & aResponseMessage,Query * & aQuery,Error & aResponseError)1363 Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError)
1364 {
1365 Error error = kErrorNone;
1366 uint16_t offset = aResponseMessage.GetOffset();
1367 Header header;
1368 QueryInfo info;
1369 Name queryName;
1370
1371 SuccessOrExit(error = aResponseMessage.Read(offset, header));
1372 offset += sizeof(Header);
1373
1374 VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) &&
1375 !header.IsTruncationFlagSet(),
1376 error = kErrorDrop);
1377
1378 aQuery = FindQueryById(header.GetMessageId());
1379 VerifyOrExit(aQuery != nullptr, error = kErrorNotFound);
1380
1381 info.ReadFrom(*aQuery);
1382
1383 queryName.SetFromMessage(*aQuery, kNameOffsetInQuery);
1384
1385 // Check the Question Section
1386
1387 if (header.GetQuestionCount() == kQuestionCount[info.mQueryType])
1388 {
1389 for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++)
1390 {
1391 SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName));
1392 offset += sizeof(Question);
1393 }
1394 }
1395 else
1396 {
1397 VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0),
1398 error = kErrorParse);
1399 }
1400
1401 // Check the answer, authority and additional record sections
1402
1403 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount()));
1404 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount()));
1405 SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount()));
1406
1407 // Read the response code
1408
1409 aResponseError = Header::ResponseCodeToError(header.GetResponseCode());
1410
1411 exit:
1412 return error;
1413 }
1414
CanFinalizeQuery(Query & aQuery)1415 bool Client::CanFinalizeQuery(Query &aQuery)
1416 {
1417 // Determines whether we can finalize a main query by checking if
1418 // we have received and saved responses for all other related
1419 // queries associated with `aQuery`. Note that this method is
1420 // called when we receive a response for `aQuery`, so no need to
1421 // check for a saved response for `aQuery` itself.
1422
1423 bool canFinalize = true;
1424 QueryInfo info;
1425
1426 for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery)
1427 {
1428 info.ReadFrom(*query);
1429
1430 if (query == &aQuery)
1431 {
1432 continue;
1433 }
1434
1435 if (info.mSavedResponse == nullptr)
1436 {
1437 canFinalize = false;
1438 ExitNow();
1439 }
1440 }
1441
1442 exit:
1443 return canFinalize;
1444 }
1445
SaveQueryResponse(Query & aQuery,const Message & aResponseMessage)1446 void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage)
1447 {
1448 QueryInfo info;
1449
1450 info.ReadFrom(aQuery);
1451 VerifyOrExit(info.mSavedResponse == nullptr);
1452
1453 // If `Clone()` fails we let retry or timeout handle the error.
1454 info.mSavedResponse = aResponseMessage.Clone();
1455
1456 UpdateQuery(aQuery, info);
1457
1458 exit:
1459 return;
1460 }
1461
PopulateResponse(Response & aResponse,Query & aQuery,const Message & aResponseMessage)1462 Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage)
1463 {
1464 // Populate `aResponse` for `aQuery`. If there is a saved response
1465 // message for `aQuery` we use it, otherwise, we use
1466 // `aResponseMessage`.
1467
1468 QueryInfo info;
1469
1470 info.ReadFrom(aQuery);
1471
1472 aResponse.mInstance = &Get<Instance>();
1473 aResponse.mQuery = &aQuery;
1474 aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? aResponseMessage : *info.mSavedResponse);
1475
1476 return info.mNextQuery;
1477 }
1478
PrepareResponseAndFinalize(Query & aQuery,const Message & aResponseMessage,Response * aPrevResponse)1479 void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse)
1480 {
1481 // This method prepares a list of chained `Response` instances
1482 // corresponding to all related (chained) queries. It uses
1483 // recursion to go through the queries and construct the
1484 // `Response` chain.
1485
1486 Response response;
1487 Query *nextQuery;
1488
1489 nextQuery = PopulateResponse(response, aQuery, aResponseMessage);
1490 response.mNext = aPrevResponse;
1491
1492 if (nextQuery != nullptr)
1493 {
1494 PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response);
1495 }
1496 else
1497 {
1498 FinalizeQuery(response, kErrorNone);
1499 }
1500 }
1501
HandleTimer(void)1502 void Client::HandleTimer(void)
1503 {
1504 TimeMilli now = TimerMilli::GetNow();
1505 TimeMilli nextTime = now.GetDistantFuture();
1506 QueryInfo info;
1507 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1508 bool hasTcpQuery = false;
1509 #endif
1510
1511 for (Query &mainQuery : mMainQueries)
1512 {
1513 for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery)
1514 {
1515 info.ReadFrom(*query);
1516
1517 if (info.mSavedResponse != nullptr)
1518 {
1519 continue;
1520 }
1521
1522 if (now >= info.mRetransmissionTime)
1523 {
1524 if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts())
1525 {
1526 FinalizeQuery(*query, kErrorResponseTimeout);
1527 break;
1528 }
1529
1530 IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false));
1531 }
1532
1533 if (nextTime > info.mRetransmissionTime)
1534 {
1535 nextTime = info.mRetransmissionTime;
1536 }
1537
1538 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1539 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1540 {
1541 hasTcpQuery = true;
1542 }
1543 #endif
1544 }
1545 }
1546
1547 if (nextTime < now.GetDistantFuture())
1548 {
1549 mTimer.FireAt(nextTime);
1550 }
1551
1552 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1553 if (!hasTcpQuery && mTcpState != kTcpUninitialized)
1554 {
1555 IgnoreError(mEndpoint.SendEndOfStream());
1556 }
1557 #endif
1558 }
1559
1560 #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1561
ReplaceWithIp4Query(Query & aQuery)1562 Error Client::ReplaceWithIp4Query(Query &aQuery)
1563 {
1564 Error error = kErrorFailed;
1565 QueryInfo info;
1566
1567 info.ReadFrom(aQuery);
1568
1569 VerifyOrExit(info.mQueryType == kIp4AddressQuery);
1570 VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow);
1571
1572 // We send a new query for IPv4 address resolution
1573 // for the same host name. We reuse the existing `aQuery`
1574 // instance and keep all the info but clear `mTransmissionCount`
1575 // and `mMessageId` (so that a new random message ID is
1576 // selected). The new `info` will be saved in the query in
1577 // `SendQuery()`. Note that the current query is still in the
1578 // `mMainQueries` list when `SendQuery()` selects a new random
1579 // message ID, so the existing message ID for this query will
1580 // not be reused.
1581
1582 info.mQueryType = kIp4AddressQuery;
1583 info.mMessageId = 0;
1584 info.mTransmissionCount = 0;
1585
1586 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1587 error = kErrorNone;
1588
1589 exit:
1590 return error;
1591 }
1592
1593 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
1594
1595 #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1596
ReplaceWithSeparateSrvTxtQueries(Query & aQuery)1597 Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery)
1598 {
1599 Error error = kErrorFailed;
1600 QueryInfo info;
1601 Query *secondQuery;
1602
1603 info.ReadFrom(aQuery);
1604
1605 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt);
1606 VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize);
1607
1608 secondQuery = aQuery.Clone();
1609 VerifyOrExit(secondQuery != nullptr);
1610
1611 info.mQueryType = kServiceQueryTxt;
1612 info.mMessageId = 0;
1613 info.mTransmissionCount = 0;
1614 info.mMainQuery = &aQuery;
1615 IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true));
1616
1617 info.mQueryType = kServiceQuerySrv;
1618 info.mMessageId = 0;
1619 info.mTransmissionCount = 0;
1620 info.mNextQuery = secondQuery;
1621 IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true));
1622 error = kErrorNone;
1623
1624 exit:
1625 return error;
1626 }
1627
ResolveHostAddressIfNeeded(Query & aQuery,const Message & aResponseMessage)1628 void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage)
1629 {
1630 QueryInfo info;
1631 Response response;
1632 ServiceInfo serviceInfo;
1633 char hostName[Name::kMaxNameSize];
1634
1635 info.ReadFrom(aQuery);
1636
1637 VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv);
1638 VerifyOrExit(info.mShouldResolveHostAddr);
1639
1640 PopulateResponse(response, aQuery, aResponseMessage);
1641
1642 memset(&serviceInfo, 0, sizeof(serviceInfo));
1643 serviceInfo.mHostNameBuffer = hostName;
1644 serviceInfo.mHostNameBufferSize = sizeof(hostName);
1645 SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo));
1646
1647 // Check whether AAAA record for host address is provided in the SRV query response
1648
1649 if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified())
1650 {
1651 Query *newQuery;
1652
1653 info.mQueryType = kIp6AddressQuery;
1654 info.mMessageId = 0;
1655 info.mTransmissionCount = 0;
1656 info.mMainQuery = &FindMainQuery(aQuery);
1657
1658 SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery));
1659 IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true));
1660
1661 // Update `aQuery` to be linked with new query (inserting
1662 // the `newQuery` into the linked-list after `aQuery`).
1663
1664 info.ReadFrom(aQuery);
1665 info.mNextQuery = newQuery;
1666 UpdateQuery(aQuery, info);
1667 }
1668
1669 exit:
1670 return;
1671 }
1672
1673 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE
1674
1675 #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
PrepareTcpMessage(Message & aMessage)1676 void Client::PrepareTcpMessage(Message &aMessage)
1677 {
1678 uint16_t length = aMessage.GetLength() - aMessage.GetOffset();
1679
1680 // Prepending the DNS query with length of the packet according to RFC1035.
1681 WriteUint16(length, mSendBufferBytes + mSendLink.mLength);
1682 SuccessOrAssert(
1683 aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length));
1684 mSendLink.mLength += length + sizeof(uint16_t);
1685 }
1686
HandleTcpSendDone(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1687 void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1688 {
1689 OT_UNUSED_VARIABLE(aEndpoint);
1690 OT_UNUSED_VARIABLE(aData);
1691 OT_ASSERT(mTcpState == kTcpConnectedSending);
1692
1693 mSendLink.mLength = 0;
1694 mTcpState = kTcpConnectedIdle;
1695 }
1696
HandleTcpSendDoneCallback(otTcpEndpoint * aEndpoint,otLinkedBuffer * aData)1697 void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData)
1698 {
1699 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData);
1700 }
1701
HandleTcpEstablished(otTcpEndpoint * aEndpoint)1702 void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint)
1703 {
1704 OT_UNUSED_VARIABLE(aEndpoint);
1705 IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0));
1706 mTcpState = kTcpConnectedSending;
1707 }
1708
HandleTcpEstablishedCallback(otTcpEndpoint * aEndpoint)1709 void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint)
1710 {
1711 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint);
1712 }
1713
ReadFromLinkBuffer(const otLinkedBuffer * & aLinkedBuffer,size_t & aOffset,Message & aMessage,uint16_t aLength)1714 Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer,
1715 size_t &aOffset,
1716 Message &aMessage,
1717 uint16_t aLength)
1718 {
1719 // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset`
1720 // and copy the content into `aMessage`. As we read we can move
1721 // to the next `aLinkedBuffer` and update `aOffset`.
1722 // Returns:
1723 // - `kErrorNone` if `aLength` bytes are successfully read and
1724 // `aOffset` and `aLinkedBuffer` are updated.
1725 // - `kErrorNotFound` is not enough bytes available to read
1726 // from `aLinkedBuffer`.
1727 // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes.
1728
1729 Error error = kErrorNone;
1730
1731 while (aLength > 0)
1732 {
1733 uint16_t bytesToRead = aLength;
1734
1735 VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound);
1736
1737 if (bytesToRead > aLinkedBuffer->mLength - aOffset)
1738 {
1739 bytesToRead = static_cast<uint16_t>(aLinkedBuffer->mLength - aOffset);
1740 }
1741
1742 SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead));
1743
1744 aLength -= bytesToRead;
1745 aOffset += bytesToRead;
1746
1747 if (aOffset == aLinkedBuffer->mLength)
1748 {
1749 aLinkedBuffer = aLinkedBuffer->mNext;
1750 aOffset = 0;
1751 }
1752 }
1753
1754 exit:
1755 return error;
1756 }
1757
HandleTcpReceiveAvailable(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1758 void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint,
1759 size_t aBytesAvailable,
1760 bool aEndOfStream,
1761 size_t aBytesRemaining)
1762 {
1763 OT_UNUSED_VARIABLE(aEndpoint);
1764 OT_UNUSED_VARIABLE(aBytesRemaining);
1765
1766 Message *message = nullptr;
1767 size_t totalRead = 0;
1768 size_t offset = 0;
1769 const otLinkedBuffer *data;
1770
1771 if (aEndOfStream)
1772 {
1773 // Cleanup is done in disconnected callback.
1774 IgnoreError(mEndpoint.SendEndOfStream());
1775 }
1776
1777 SuccessOrExit(mEndpoint.ReceiveByReference(data));
1778 VerifyOrExit(data != nullptr);
1779
1780 message = mSocket.NewMessage();
1781 VerifyOrExit(message != nullptr);
1782
1783 while (aBytesAvailable > totalRead)
1784 {
1785 uint16_t length;
1786
1787 // Read the `length` field.
1788 SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t)));
1789
1790 IgnoreError(message->Read(/* aOffset */ 0, length));
1791 length = HostSwap16(length);
1792
1793 // Try to read `length` bytes.
1794 IgnoreError(message->SetLength(0));
1795 SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length));
1796
1797 totalRead += length + sizeof(uint16_t);
1798
1799 // Now process the read message as query response.
1800 ProcessResponse(*message);
1801
1802 IgnoreError(message->SetLength(0));
1803
1804 // Loop again to see if we can read another response.
1805 }
1806
1807 exit:
1808 // Inform `mEndPoint` about the total read and processed bytes
1809 IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0));
1810 FreeMessage(message);
1811 }
1812
HandleTcpReceiveAvailableCallback(otTcpEndpoint * aEndpoint,size_t aBytesAvailable,bool aEndOfStream,size_t aBytesRemaining)1813 void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint,
1814 size_t aBytesAvailable,
1815 bool aEndOfStream,
1816 size_t aBytesRemaining)
1817 {
1818 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))
1819 ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining);
1820 }
1821
HandleTcpDisconnected(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1822 void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1823 {
1824 OT_UNUSED_VARIABLE(aEndpoint);
1825 OT_UNUSED_VARIABLE(aReason);
1826 QueryInfo info;
1827
1828 IgnoreError(mEndpoint.Deinitialize());
1829 mTcpState = kTcpUninitialized;
1830
1831 // Abort queries in case of connection failures
1832 for (Query &mainQuery : mMainQueries)
1833 {
1834 info.ReadFrom(mainQuery);
1835
1836 if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
1837 {
1838 FinalizeQuery(mainQuery, kErrorAbort);
1839 }
1840 }
1841 }
1842
HandleTcpDisconnectedCallback(otTcpEndpoint * aEndpoint,otTcpDisconnectedReason aReason)1843 void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason)
1844 {
1845 static_cast<Client *>(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason);
1846 }
1847
1848 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
1849
1850 } // namespace Dns
1851 } // namespace ot
1852
1853 #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE
1854