/* * Copyright (c) 2017-2021, The OpenThread Authors. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * 3. Neither the name of the copyright holder nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include "dns_client.hpp" #if OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE #include "common/array.hpp" #include "common/as_core_type.hpp" #include "common/code_utils.hpp" #include "common/debug.hpp" #include "common/instance.hpp" #include "common/locator_getters.hpp" #include "common/log.hpp" #include "net/udp6.hpp" #include "thread/network_data_types.hpp" #include "thread/thread_netif.hpp" /** * @file * This file implements the DNS client. 
*/

namespace ot {
namespace Dns {

#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
using ot::Encoding::BigEndian::ReadUint16;
using ot::Encoding::BigEndian::WriteUint16;
#endif

RegisterLogModule("DnsClient");

//---------------------------------------------------------------------------------------------------------------------
// Client::QueryConfig

const char Client::QueryConfig::kDefaultServerAddressString[] = OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_IP6_ADDRESS;

// Constructs a config populated from the `kDefault...` constants: server
// address (parsed from `kDefaultServerAddressString`) and port, response
// timeout, max TX attempts, recursion flag, service mode, NAT64 mode (when
// NAT64 support is enabled) and UDP as transport. `aMode` is not used.
Client::QueryConfig::QueryConfig(InitMode aMode)
{
    OT_UNUSED_VARIABLE(aMode);

    // A failure to parse the default address string is deliberately ignored.
    IgnoreError(GetServerSockAddr().GetAddress().FromString(kDefaultServerAddressString));
    GetServerSockAddr().SetPort(kDefaultServerPort);
    SetResponseTimeout(kDefaultResponseTimeout);
    SetMaxTxAttempts(kDefaultMaxTxAttempts);
    SetRecursionFlag(kDefaultRecursionDesired ? kFlagRecursionDesired : kFlagNoRecursion);
    SetServiceMode(kDefaultServiceMode);
#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    SetNat64Mode(kDefaultNat64Allowed ? kNat64Allow : kNat64Disallow);
#endif
    SetTransportProto(kDnsTransportUdp);
}

void Client::QueryConfig::SetFrom(const QueryConfig *aConfig, const QueryConfig &aDefaultConfig)
{
    // This method sets the config from `aConfig` replacing any
    // unspecified fields (value zero) with the fields from
    // `aDefaultConfig`. If `aConfig` is `nullptr` then
    // `aDefaultConfig` is used.
// No per-query config was given: take the default config as is.
    if (aConfig == nullptr)
    {
        *this = aDefaultConfig;
        ExitNow();
    }

    // Start from the caller's config, then fill each field that was left
    // unspecified with the corresponding value from `aDefaultConfig`.
    *this = *aConfig;

    if (GetServerSockAddr().GetAddress().IsUnspecified())
    {
        GetServerSockAddr().GetAddress() = aDefaultConfig.GetServerSockAddr().GetAddress();
    }

    if (GetServerSockAddr().GetPort() == 0)
    {
        GetServerSockAddr().SetPort(aDefaultConfig.GetServerSockAddr().GetPort());
    }

    if (GetResponseTimeout() == 0)
    {
        SetResponseTimeout(aDefaultConfig.GetResponseTimeout());
    }

    if (GetMaxTxAttempts() == 0)
    {
        SetMaxTxAttempts(aDefaultConfig.GetMaxTxAttempts());
    }

    if (GetRecursionFlag() == kFlagUnspecified)
    {
        SetRecursionFlag(aDefaultConfig.GetRecursionFlag());
    }

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE
    if (GetNat64Mode() == kNat64Unspecified)
    {
        SetNat64Mode(aDefaultConfig.GetNat64Mode());
    }
#endif

    if (GetServiceMode() == kServiceModeUnspecified)
    {
        SetServiceMode(aDefaultConfig.GetServiceMode());
    }

    if (GetTransportProto() == kDnsTransportUnspecified)
    {
        SetTransportProto(aDefaultConfig.GetTransportProto());
    }

exit:
    return;
}

//---------------------------------------------------------------------------------------------------------------------
// Client::Response

// Returns, through `aOffset` and `aNumRecord`, the start offset and record
// count saved for `aSection`. Any value other than `kAnswerSection` selects
// the Additional Data section.
void Client::Response::SelectSection(Section aSection, uint16_t &aOffset, uint16_t &aNumRecord) const
{
    switch (aSection)
    {
    case kAnswerSection:
        aOffset    = mAnswerOffset;
        aNumRecord = mAnswerRecordCount;
        break;
    case kAdditionalDataSection:
    default:
        aOffset    = mAdditionalOffset;
        aNumRecord = mAdditionalRecordCount;
        break;
    }
}

// Reads the name stored in the query (`mQuery`) at `kNameOffsetInQuery`
// into `aNameBuffer` and returns the result of `Name::ReadName()`.
Error Client::Response::GetName(char *aNameBuffer, uint16_t aNameBufferSize) const
{
    uint16_t offset = kNameOffsetInQuery;

    return Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize);
}

Error Client::Response::CheckForHostNameAlias(Section aSection, Name &aHostName) const
{
    // If the response includes a CNAME record mapping the query host
    // name to a canonical name, we update `aHostName` to the new alias
    // name. Otherwise `aHostName` remains as before.
// Locals used for the CNAME record lookup.
    Error       error;
    uint16_t    offset;
    uint16_t    numRecords;
    CnameRecord cnameRecord;

    // Without a response message there is nothing to search.
    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(aSection, offset, numRecords);
    error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aHostName, cnameRecord);

    switch (error)
    {
    case kErrorNone:
        // A CNAME record was found. `offset` now points to after the
        // last read byte within the `mMessage` into the `cnameRecord`
        // (which is the start of the new canonical name).

        aHostName.SetFromMessage(*mMessage, offset);
        error = Name::ParseName(*mMessage, offset);
        break;

    case kErrorNotFound:
        // No CNAME record is not a failure; `aHostName` is left as is.
        error = kErrorNone;
        break;

    default:
        break;
    }

exit:
    return error;
}

// Finds the `aIndex`-th AAAA record in `aSection` for `aHostName` (after
// following a CNAME alias, if any) and returns its address and TTL through
// `aAddress` and `aTtl`. Returns the first error encountered otherwise.
Error Client::Response::FindHostAddress(Section       aSection,
                                        const Name   &aHostName,
                                        uint16_t      aIndex,
                                        Ip6::Address &aAddress,
                                        uint32_t     &aTtl) const
{
    Error      error;
    uint16_t   offset;
    uint16_t   numRecords;
    Name       name = aHostName;
    AaaaRecord aaaaRecord;

    // This also fails with `kErrorNotFound` when there is no response
    // message, so `mMessage` is non-null below.
    SuccessOrExit(error = CheckForHostNameAlias(aSection, name));

    SelectSection(aSection, offset, numRecords);
    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aaaaRecord));
    aAddress = aaaaRecord.GetAddress();
    aTtl     = aaaaRecord.GetTtl();

exit:
    return error;
}

#if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE

// Finds the `aIndex`-th A record in `aSection` for `aHostName` (after
// following a CNAME alias, if any) and reads it into `aARecord`.
Error Client::Response::FindARecord(Section aSection, const Name &aHostName, uint16_t aIndex, ARecord &aARecord) const
{
    Error    error;
    uint16_t offset;
    uint16_t numRecords;
    Name     name = aHostName;

    SuccessOrExit(error = CheckForHostNameAlias(aSection, name));

    SelectSection(aSection, offset, numRecords);
    error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, name, aARecord);

exit:
    return error;
}

#endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE

#if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE

void Client::Response::InitServiceInfo(ServiceInfo &aServiceInfo) const
{
    // This method initializes `aServiceInfo` setting all
    // TTLs to zero and host name to empty string.
// Reset the TTLs, the TXT truncation flag and the host address.
    aServiceInfo.mTtl              = 0;
    aServiceInfo.mHostAddressTtl   = 0;
    aServiceInfo.mTxtDataTtl       = 0;
    aServiceInfo.mTxtDataTruncated = false;

    AsCoreType(&aServiceInfo.mHostAddress).Clear();

    // The host name buffer is optional; only touch it when one was provided.
    if ((aServiceInfo.mHostNameBuffer != nullptr) && (aServiceInfo.mHostNameBufferSize > 0))
    {
        aServiceInfo.mHostNameBuffer[0] = '\0';
    }
}

Error Client::Response::ReadServiceInfo(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches for SRV record in the given `aSection`
    // matching the record name against `aName`, and updates the
    // `aServiceInfo` accordingly. It also searches for AAAA record
    // for host name associated with the service (from SRV record).
    // The search for AAAA record is always performed in Additional
    // Data section (independent of the value given in `aSection`).

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    Name      hostName;
    SrvRecord srvRecord;

    // A non-zero `mTtl` indicates that SRV record is already found
    // and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTtl == 0);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    // Search for a matching SRV record
    SelectSection(aSection, offset, numRecords);
    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, srvRecord));

    aServiceInfo.mTtl      = srvRecord.GetTtl();
    aServiceInfo.mPort     = srvRecord.GetPort();
    aServiceInfo.mPriority = srvRecord.GetPriority();
    aServiceInfo.mWeight   = srvRecord.GetWeight();

    // Remember where the target host name starts in the message; it is
    // used for the AAAA lookup further below.
    hostName.SetFromMessage(*mMessage, offset);

    // Copy the host name out when the caller gave a buffer; otherwise only
    // parse (validate and skip over) it.
    if (aServiceInfo.mHostNameBuffer != nullptr)
    {
        SuccessOrExit(error = srvRecord.ReadTargetHostName(*mMessage, offset, aServiceInfo.mHostNameBuffer,
                                                           aServiceInfo.mHostNameBufferSize));
    }
    else
    {
        SuccessOrExit(error = Name::ParseName(*mMessage, offset));
    }

    // Search in additional section for AAAA record for the host name.
// Keep a host address that is already set; search only while it is unspecified.
    VerifyOrExit(AsCoreType(&aServiceInfo.mHostAddress).IsUnspecified());

    error = FindHostAddress(kAdditionalDataSection, hostName, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress),
                            aServiceInfo.mHostAddressTtl);

    // A missing AAAA record is not treated as an error here.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

exit:
    return error;
}

Error Client::Response::ReadTxtRecord(Section aSection, const Name &aName, ServiceInfo &aServiceInfo) const
{
    // This method searches a TXT record in the given `aSection`
    // matching the record name against `aName` and updates the TXT
    // related properties in `aServiceInfo`.
    //
    // If no match is found `mTxtDataTtl` (which is initialized to zero)
    // remains unchanged to indicate this. In this case this method still
    // returns `kErrorNone`.

    Error     error = kErrorNone;
    uint16_t  offset;
    uint16_t  numRecords;
    TxtRecord txtRecord;

    // A non-zero `mTxtDataTtl` indicates that TXT record is already
    // found and parsed from a previous response.
    VerifyOrExit(aServiceInfo.mTxtDataTtl == 0);

    // A null `mTxtData` indicates that caller does not want to retrieve
    // TXT data.
    VerifyOrExit(aServiceInfo.mTxtData != nullptr);

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(aSection, offset, numRecords);

    aServiceInfo.mTxtDataTruncated = false;

    SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, /* aIndex */ 0, aName, txtRecord));

    error = txtRecord.ReadTxtData(*mMessage, offset, aServiceInfo.mTxtData, aServiceInfo.mTxtDataSize);

    if (error == kErrorNoBufs)
    {
        error = kErrorNone;

        // Mark `mTxtDataTruncated` to indicate that we could not read
        // the full TXT record into the given `mTxtData` buffer.
        aServiceInfo.mTxtDataTruncated = true;
    }

    SuccessOrExit(error);
    aServiceInfo.mTxtDataTtl = txtRecord.GetTtl();

exit:
    // "Not found" (no message or no matching TXT record) is reported as
    // success; `mTxtDataTtl` staying zero tells the caller nothing was read.
    if (error == kErrorNotFound)
    {
        error = kErrorNone;
    }

    return error;
}

#endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE

void Client::Response::PopulateFrom(const Message &aMessage)
{
    // Populate `Response` with info from `aMessage`.
uint16_t offset = aMessage.GetOffset(); Header header; mMessage = &aMessage; IgnoreError(aMessage.Read(offset, header)); offset += sizeof(Header); for (uint16_t num = 0; num < header.GetQuestionCount(); num++) { IgnoreError(Name::ParseName(aMessage, offset)); offset += sizeof(Question); } mAnswerOffset = offset; IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAnswerCount())); IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAuthorityRecordCount())); mAdditionalOffset = offset; IgnoreError(ResourceRecord::ParseRecords(aMessage, offset, header.GetAdditionalRecordCount())); mAnswerRecordCount = header.GetAnswerCount(); mAdditionalRecordCount = header.GetAdditionalRecordCount(); } //--------------------------------------------------------------------------------------------------------------------- // Client::AddressResponse Error Client::AddressResponse::GetAddress(uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const { Error error = kErrorNone; Name name(*mQuery, kNameOffsetInQuery); #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE // If the response is for an IPv4 address query or if it is an // IPv6 address query response with no IPv6 address but with // an IPv4 in its additional section, we read the IPv4 address // and translate it to an IPv6 address. QueryInfo info; info.ReadFrom(*mQuery); if ((info.mQueryType == kIp4AddressQuery) || mIp6QueryResponseRequiresNat64) { Section section; ARecord aRecord; NetworkData::ExternalRouteConfig nat64Prefix; VerifyOrExit(mInstance->Get().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone, error = kErrorInvalidState); section = (info.mQueryType == kIp4AddressQuery) ? 
kAnswerSection : kAdditionalDataSection; SuccessOrExit(error = FindARecord(section, name, aIndex, aRecord)); aAddress.SynthesizeFromIp4Address(nat64Prefix.GetPrefix(), aRecord.GetAddress()); aTtl = aRecord.GetTtl(); ExitNow(); } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE ExitNow(error = FindHostAddress(kAnswerSection, name, aIndex, aAddress, aTtl)); exit: return error; } #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE //--------------------------------------------------------------------------------------------------------------------- // Client::BrowseResponse Error Client::BrowseResponse::GetServiceInstance(uint16_t aIndex, char *aLabelBuffer, uint8_t aLabelBufferSize) const { Error error; uint16_t offset; uint16_t numRecords; Name serviceName(*mQuery, kNameOffsetInQuery); PtrRecord ptrRecord; VerifyOrExit(mMessage != nullptr, error = kErrorNotFound); SelectSection(kAnswerSection, offset, numRecords); SuccessOrExit(error = ResourceRecord::FindRecord(*mMessage, offset, numRecords, aIndex, serviceName, ptrRecord)); error = ptrRecord.ReadPtrName(*mMessage, offset, aLabelBuffer, aLabelBufferSize, nullptr, 0); exit: return error; } Error Client::BrowseResponse::GetServiceInfo(const char *aInstanceLabel, ServiceInfo &aServiceInfo) const { Error error; Name instanceName; // Find a matching PTR record for the service instance label. Then // search and read SRV, TXT and AAAA records in Additional Data // section matching the same name to populate `aServiceInfo`. 
// Resolve the instance label to the full instance name found in a PTR record.
    SuccessOrExit(error = FindPtrRecord(aInstanceLabel, instanceName));

    InitServiceInfo(aServiceInfo);
    SuccessOrExit(error = ReadServiceInfo(kAdditionalDataSection, instanceName, aServiceInfo));
    SuccessOrExit(error = ReadTxtRecord(kAdditionalDataSection, instanceName, aServiceInfo));

    // A zero TXT TTL means no TXT record was read; report an empty TXT data.
    if (aServiceInfo.mTxtDataTtl == 0)
    {
        aServiceInfo.mTxtDataSize = 0;
    }

exit:
    return error;
}

// Looks up the `aIndex`-th AAAA record for `aHostName` in the Additional
// Data section of the browse response.
Error Client::BrowseResponse::GetHostAddress(const char   *aHostName,
                                             uint16_t      aIndex,
                                             Ip6::Address &aAddress,
                                             uint32_t     &aTtl) const
{
    return FindHostAddress(kAdditionalDataSection, Name(aHostName), aIndex, aAddress, aTtl);
}

Error Client::BrowseResponse::FindPtrRecord(const char *aInstanceLabel, Name &aInstanceName) const
{
    // This method searches within the Answer Section for a PTR record
    // matching a given instance label `aInstanceLabel`. If found,
    // `aInstanceName` is updated to return the name in the message.

    Error     error;
    uint16_t  offset;
    Name      serviceName(*mQuery, kNameOffsetInQuery);
    uint16_t  numRecords;
    uint16_t  labelOffset;
    PtrRecord ptrRecord;

    VerifyOrExit(mMessage != nullptr, error = kErrorNotFound);

    SelectSection(kAnswerSection, offset, numRecords);

    for (; numRecords > 0; numRecords--)
    {
        SuccessOrExit(error = Name::CompareName(*mMessage, offset, serviceName));

        error = ResourceRecord::ReadRecord(*mMessage, offset, ptrRecord);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` updates `offset` to skip over a
            // non-matching record.
            continue;
        }

        SuccessOrExit(error);

        // It is a PTR record. Check the first label to match the
        // instance label.

        // Compare using a copy of the offset so that `offset` still points
        // to the start of the PTR name afterwards.
        labelOffset = offset;
        error       = Name::CompareLabel(*mMessage, labelOffset, aInstanceLabel);

        if (error == kErrorNone)
        {
            aInstanceName.SetFromMessage(*mMessage, offset);
            ExitNow();
        }

        // Any error other than a label mismatch ends the search.
        VerifyOrExit(error == kErrorNotFound);

        // Update offset to skip over the PTR record.
offset += static_cast(ptrRecord.GetSize()) - sizeof(ptrRecord); } error = kErrorNotFound; exit: return error; } //--------------------------------------------------------------------------------------------------------------------- // Client::ServiceResponse Error Client::ServiceResponse::GetServiceName(char *aLabelBuffer, uint8_t aLabelBufferSize, char *aNameBuffer, uint16_t aNameBufferSize) const { Error error; uint16_t offset = kNameOffsetInQuery; SuccessOrExit(error = Name::ReadLabel(*mQuery, offset, aLabelBuffer, aLabelBufferSize)); VerifyOrExit(aNameBuffer != nullptr); SuccessOrExit(error = Name::ReadName(*mQuery, offset, aNameBuffer, aNameBufferSize)); exit: return error; } Error Client::ServiceResponse::GetServiceInfo(ServiceInfo &aServiceInfo) const { // Search and read SRV, TXT records matching name from query. Error error = kErrorNotFound; InitServiceInfo(aServiceInfo); for (const Response *response = this; response != nullptr; response = response->mNext) { Name name(*response->mQuery, kNameOffsetInQuery); QueryInfo info; Section srvSection; Section txtSection; info.ReadFrom(*response->mQuery); switch (info.mQueryType) { case kIp6AddressQuery: #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE case kIp4AddressQuery: #endif IgnoreError(response->FindHostAddress(kAnswerSection, name, /* aIndex */ 0, AsCoreType(&aServiceInfo.mHostAddress), aServiceInfo.mHostAddressTtl)); continue; // to `for()` loop case kServiceQuerySrvTxt: case kServiceQuerySrv: case kServiceQueryTxt: break; default: continue; } // Determine from which section we should try to read the SRV and // TXT records based on the query type. // // In `kServiceQuerySrv` or `kServiceQueryTxt` we expect to see // only one record (SRV or TXT) in the answer section, but we // still try to read the other records from additional data // section in case server provided them. srvSection = (info.mQueryType != kServiceQueryTxt) ? 
kAnswerSection : kAdditionalDataSection; txtSection = (info.mQueryType != kServiceQuerySrv) ? kAnswerSection : kAdditionalDataSection; error = response->ReadServiceInfo(srvSection, name, aServiceInfo); if ((srvSection == kAdditionalDataSection) && (error == kErrorNotFound)) { error = kErrorNone; } SuccessOrExit(error); SuccessOrExit(error = response->ReadTxtRecord(txtSection, name, aServiceInfo)); } if (aServiceInfo.mTxtDataTtl == 0) { aServiceInfo.mTxtDataSize = 0; } exit: return error; } Error Client::ServiceResponse::GetHostAddress(const char *aHostName, uint16_t aIndex, Ip6::Address &aAddress, uint32_t &aTtl) const { Error error = kErrorNotFound; for (const Response *response = this; response != nullptr; response = response->mNext) { Section section = kAdditionalDataSection; QueryInfo info; info.ReadFrom(*response->mQuery); switch (info.mQueryType) { case kIp6AddressQuery: #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE case kIp4AddressQuery: #endif section = kAnswerSection; break; default: break; } error = response->FindHostAddress(section, Name(aHostName), aIndex, aAddress, aTtl); if (error == kErrorNone) { break; } } return error; } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE //--------------------------------------------------------------------------------------------------------------------- // Client const uint16_t Client::kIp6AddressQueryRecordTypes[] = {ResourceRecord::kTypeAaaa}; #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE const uint16_t Client::kIp4AddressQueryRecordTypes[] = {ResourceRecord::kTypeA}; #endif #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE const uint16_t Client::kBrowseQueryRecordTypes[] = {ResourceRecord::kTypePtr}; const uint16_t Client::kServiceQueryRecordTypes[] = {ResourceRecord::kTypeSrv, ResourceRecord::kTypeTxt}; #endif const uint8_t Client::kQuestionCount[] = { /* kIp6AddressQuery -> */ GetArrayLength(kIp6AddressQueryRecordTypes), // AAAA record #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE /* 
kIp4AddressQuery -> */ GetArrayLength(kIp4AddressQueryRecordTypes), // A record #endif #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE /* kBrowseQuery -> */ GetArrayLength(kBrowseQueryRecordTypes), // PTR record /* kServiceQuerySrvTxt -> */ GetArrayLength(kServiceQueryRecordTypes), // SRV and TXT records /* kServiceQuerySrv -> */ 1, // SRV record only /* kServiceQueryTxt -> */ 1, // TXT record only #endif }; const uint16_t *const Client::kQuestionRecordTypes[] = { /* kIp6AddressQuery -> */ kIp6AddressQueryRecordTypes, #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE /* kIp4AddressQuery -> */ kIp4AddressQueryRecordTypes, #endif #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE /* kBrowseQuery -> */ kBrowseQueryRecordTypes, /* kServiceQuerySrvTxt -> */ kServiceQueryRecordTypes, /* kServiceQuerySrv -> */ &kServiceQueryRecordTypes[0], /* kServiceQueryTxt -> */ &kServiceQueryRecordTypes[1], #endif }; Client::Client(Instance &aInstance) : InstanceLocator(aInstance) , mSocket(aInstance) #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE , mTcpState(kTcpUninitialized) #endif , mTimer(aInstance) , mDefaultConfig(QueryConfig::kInitFromDefaults) #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE , mUserDidSetDefaultAddress(false) #endif { static_assert(kIp6AddressQuery == 0, "kIp6AddressQuery value is not correct"); #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE static_assert(kIp4AddressQuery == 1, "kIp4AddressQuery value is not correct"); #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE static_assert(kBrowseQuery == 2, "kBrowseQuery value is not correct"); static_assert(kServiceQuerySrvTxt == 3, "kServiceQuerySrvTxt value is not correct"); static_assert(kServiceQuerySrv == 4, "kServiceQuerySrv value is not correct"); static_assert(kServiceQueryTxt == 5, "kServiceQueryTxt value is not correct"); #endif #elif OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE static_assert(kBrowseQuery == 1, "kBrowseQuery value is not correct"); 
static_assert(kServiceQuerySrvTxt == 2, "kServiceQuerySrvTxt value is not correct"); static_assert(kServiceQuerySrv == 3, "kServiceQuerySrv value is not correct"); static_assert(kServiceQueryTxt == 4, "kServiceQuerySrv value is not correct"); #endif } Error Client::Start(void) { Error error; SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this)); SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified)); exit: return error; } void Client::Stop(void) { Query *query; while ((query = mMainQueries.GetHead()) != nullptr) { FinalizeQuery(*query, kErrorAbort); } IgnoreError(mSocket.Close()); #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE if (mTcpState != kTcpUninitialized) { IgnoreError(mEndpoint.Deinitialize()); } #endif } #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE Error Client::InitTcpSocket(void) { Error error; otTcpEndpointInitializeArgs endpointArgs; memset(&endpointArgs, 0x00, sizeof(endpointArgs)); endpointArgs.mSendDoneCallback = HandleTcpSendDoneCallback; endpointArgs.mEstablishedCallback = HandleTcpEstablishedCallback; endpointArgs.mReceiveAvailableCallback = HandleTcpReceiveAvailableCallback; endpointArgs.mDisconnectedCallback = HandleTcpDisconnectedCallback; endpointArgs.mContext = this; endpointArgs.mReceiveBuffer = mReceiveBufferBytes; endpointArgs.mReceiveBufferSize = sizeof(mReceiveBufferBytes); mSendLink.mNext = nullptr; mSendLink.mData = mSendBufferBytes; mSendLink.mLength = 0; SuccessOrExit(error = mEndpoint.Initialize(Get(), endpointArgs)); exit: return error; } #endif void Client::SetDefaultConfig(const QueryConfig &aQueryConfig) { QueryConfig startingDefault(QueryConfig::kInitFromDefaults); mDefaultConfig.SetFrom(&aQueryConfig, startingDefault); #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE mUserDidSetDefaultAddress = !aQueryConfig.GetServerSockAddr().GetAddress().IsUnspecified(); UpdateDefaultConfigAddress(); #endif } void Client::ResetDefaultConfig(void) { mDefaultConfig = 
QueryConfig(QueryConfig::kInitFromDefaults); #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE mUserDidSetDefaultAddress = false; UpdateDefaultConfigAddress(); #endif } #if OPENTHREAD_CONFIG_DNS_CLIENT_DEFAULT_SERVER_ADDRESS_AUTO_SET_ENABLE void Client::UpdateDefaultConfigAddress(void) { const Ip6::Address &srpServerAddr = Get().GetServerAddress().GetAddress(); if (!mUserDidSetDefaultAddress && Get().IsServerSelectedByAutoStart() && !srpServerAddr.IsUnspecified()) { mDefaultConfig.GetServerSockAddr().SetAddress(srpServerAddr); } } #endif Error Client::ResolveAddress(const char *aHostName, AddressCallback aCallback, void *aContext, const QueryConfig *aConfig) { QueryInfo info; info.Clear(); info.mQueryType = kIp6AddressQuery; info.mConfig.SetFrom(aConfig, mDefaultConfig); info.mCallback.mAddressCallback = aCallback; info.mCallbackContext = aContext; return StartQuery(info, nullptr, aHostName); } #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE Error Client::ResolveIp4Address(const char *aHostName, AddressCallback aCallback, void *aContext, const QueryConfig *aConfig) { QueryInfo info; info.Clear(); info.mQueryType = kIp4AddressQuery; info.mConfig.SetFrom(aConfig, mDefaultConfig); info.mCallback.mAddressCallback = aCallback; info.mCallbackContext = aContext; return StartQuery(info, nullptr, aHostName); } #endif #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE Error Client::Browse(const char *aServiceName, BrowseCallback aCallback, void *aContext, const QueryConfig *aConfig) { QueryInfo info; info.Clear(); info.mQueryType = kBrowseQuery; info.mConfig.SetFrom(aConfig, mDefaultConfig); info.mCallback.mBrowseCallback = aCallback; info.mCallbackContext = aContext; return StartQuery(info, nullptr, aServiceName); } Error Client::ResolveService(const char *aInstanceLabel, const char *aServiceName, ServiceCallback aCallback, void *aContext, const QueryConfig *aConfig) { return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, 
false); } Error Client::ResolveServiceAndHostAddress(const char *aInstanceLabel, const char *aServiceName, ServiceCallback aCallback, void *aContext, const QueryConfig *aConfig) { return Resolve(aInstanceLabel, aServiceName, aCallback, aContext, aConfig, true); } Error Client::Resolve(const char *aInstanceLabel, const char *aServiceName, ServiceCallback aCallback, void *aContext, const QueryConfig *aConfig, bool aShouldResolveHostAddr) { QueryInfo info; Error error; QueryType secondQueryType = kNoQuery; VerifyOrExit(aInstanceLabel != nullptr, error = kErrorInvalidArgs); info.Clear(); info.mConfig.SetFrom(aConfig, mDefaultConfig); info.mShouldResolveHostAddr = aShouldResolveHostAddr; switch (info.mConfig.GetServiceMode()) { case QueryConfig::kServiceModeSrvTxtSeparate: secondQueryType = kServiceQueryTxt; OT_FALL_THROUGH; case QueryConfig::kServiceModeSrv: info.mQueryType = kServiceQuerySrv; break; case QueryConfig::kServiceModeTxt: info.mQueryType = kServiceQueryTxt; VerifyOrExit(!info.mShouldResolveHostAddr, error = kErrorInvalidArgs); break; case QueryConfig::kServiceModeSrvTxt: case QueryConfig::kServiceModeSrvTxtOptimize: default: info.mQueryType = kServiceQuerySrvTxt; break; } info.mCallback.mServiceCallback = aCallback; info.mCallbackContext = aContext; error = StartQuery(info, aInstanceLabel, aServiceName, secondQueryType); exit: return error; } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE Error Client::StartQuery(QueryInfo &aInfo, const char *aLabel, const char *aName, QueryType aSecondType) { // The `aLabel` can be `nullptr` and then `aName` provides the // full name, otherwise the name is appended as `{aLabel}. // {aName}`. 
Error error; Query *query; VerifyOrExit(mSocket.IsBound(), error = kErrorInvalidState); #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE if (aInfo.mQueryType == kIp4AddressQuery) { NetworkData::ExternalRouteConfig nat64Prefix; VerifyOrExit(aInfo.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow, error = kErrorInvalidArgs); VerifyOrExit(Get().GetPreferredNat64Prefix(nat64Prefix) == kErrorNone, error = kErrorInvalidState); } #endif SuccessOrExit(error = AllocateQuery(aInfo, aLabel, aName, query)); mMainQueries.Enqueue(*query); error = SendQuery(*query, aInfo, /* aUpdateTimer */ true); VerifyOrExit(error == kErrorNone, FreeQuery(*query)); if (aSecondType != kNoQuery) { Query *secondQuery; aInfo.mQueryType = aSecondType; aInfo.mMessageId = 0; aInfo.mTransmissionCount = 0; aInfo.mMainQuery = query; // We intentionally do not use `error` here so in the unlikely // case where we cannot allocate the second query we can proceed // with the first one. SuccessOrExit(AllocateQuery(aInfo, aLabel, aName, secondQuery)); IgnoreError(SendQuery(*secondQuery, aInfo, /* aUpdateTimer */ true)); // Update first query to link to second one by updating // its `mNextQuery`. 
aInfo.ReadFrom(*query); aInfo.mNextQuery = secondQuery; UpdateQuery(*query, aInfo); } exit: return error; } Error Client::AllocateQuery(const QueryInfo &aInfo, const char *aLabel, const char *aName, Query *&aQuery) { Error error = kErrorNone; aQuery = nullptr; VerifyOrExit(aInfo.mConfig.GetResponseTimeout() <= TimerMilli::kMaxDelay, error = kErrorInvalidArgs); aQuery = Get().Allocate(Message::kTypeOther); VerifyOrExit(aQuery != nullptr, error = kErrorNoBufs); SuccessOrExit(error = aQuery->Append(aInfo)); if (aLabel != nullptr) { SuccessOrExit(error = Name::AppendLabel(aLabel, *aQuery)); } SuccessOrExit(error = Name::AppendName(aName, *aQuery)); exit: FreeAndNullMessageOnError(aQuery, error); return error; } Client::Query &Client::FindMainQuery(Query &aQuery) { QueryInfo info; info.ReadFrom(aQuery); return (info.mMainQuery == nullptr) ? aQuery : *info.mMainQuery; } void Client::FreeQuery(Query &aQuery) { Query &mainQuery = FindMainQuery(aQuery); QueryInfo info; mMainQueries.Dequeue(mainQuery); for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery) { info.ReadFrom(*query); FreeMessage(info.mSavedResponse); query->Free(); } } Error Client::SendQuery(Query &aQuery, QueryInfo &aInfo, bool aUpdateTimer) { // This method prepares and sends a query message represented by // `aQuery` and `aInfo`. This method updates `aInfo` (e.g., sets // the new `mRetransmissionTime`) and updates it in `aQuery` as // well. `aUpdateTimer` indicates whether the timer should be // updated when query is sent or not (used in the case where timer // is handled by caller). 
// Locals used to build and send the DNS query message.
    Error            error   = kErrorNone;
    Message         *message = nullptr;
    Header           header;
    Ip6::MessageInfo messageInfo;
    uint16_t         length = 0;

    // Count this attempt and compute when a retransmission becomes due.
    aInfo.mTransmissionCount++;
    aInfo.mRetransmissionTime = TimerMilli::GetNow() + aInfo.mConfig.GetResponseTimeout();

    if (aInfo.mMessageId == 0)
    {
        // No message ID assigned yet: pick a random non-zero ID that is
        // not used by any other pending query.
        do
        {
            SuccessOrExit(error = header.SetRandomMessageId());
        } while ((header.GetMessageId() == 0) || (FindQueryById(header.GetMessageId()) != nullptr));

        aInfo.mMessageId = header.GetMessageId();
    }
    else
    {
        // Reuse the ID that is already assigned to this query.
        header.SetMessageId(aInfo.mMessageId);
    }

    header.SetType(Header::kTypeQuery);
    header.SetQueryType(Header::kQueryTypeStandard);

    if (aInfo.mConfig.GetRecursionFlag() == QueryConfig::kFlagRecursionDesired)
    {
        header.SetRecursionDesiredFlag();
    }

    header.SetQuestionCount(kQuestionCount[aInfo.mQueryType]);

    message = mSocket.NewMessage();
    VerifyOrExit(message != nullptr, error = kErrorNoBufs);

    SuccessOrExit(error = message->Append(header));

    // Prepare the question section.

    // One question per record type of this query type, each using the name
    // stored in the query.
    for (uint8_t num = 0; num < kQuestionCount[aInfo.mQueryType]; num++)
    {
        SuccessOrExit(error = AppendNameFromQuery(aQuery, *message));
        SuccessOrExit(error = message->Append(Question(kQuestionRecordTypes[aInfo.mQueryType][num])));
    }

    length = message->GetLength() - message->GetOffset();

    if (aInfo.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp)
#if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE
    {
        // Check if query will fit into tcp buffer if not return error.
        VerifyOrExit(length + sizeof(uint16_t) + mSendLink.mLength <=
                         OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_QUERY_MAX_SIZE,
                     error = kErrorNoBufs);

        // In case of initialized connection check if connected peer and new query have the same address.
if (mTcpState != kTcpUninitialized) { VerifyOrExit(mEndpoint.GetPeerAddress() == AsCoreType(&aInfo.mConfig.mServerSockAddr), error = kErrorFailed); } switch (mTcpState) { case kTcpUninitialized: SuccessOrExit(error = InitTcpSocket()); SuccessOrExit( error = mEndpoint.Connect(AsCoreType(&aInfo.mConfig.mServerSockAddr), OT_TCP_CONNECT_NO_FAST_OPEN)); mTcpState = kTcpConnecting; PrepareTcpMessage(*message); break; case kTcpConnectedIdle: PrepareTcpMessage(*message); SuccessOrExit(error = mEndpoint.SendByReference(mSendLink, /* aFlags */ 0)); mTcpState = kTcpConnectedSending; break; case kTcpConnecting: PrepareTcpMessage(*message); break; case kTcpConnectedSending: WriteUint16(length, mSendBufferBytes + mSendLink.mLength); SuccessOrAssert(error = message->Read(message->GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length)); IgnoreError(mEndpoint.SendByExtension(length + sizeof(uint16_t), /* aFlags */ 0)); break; } message->Free(); message = nullptr; } #else { error = kErrorInvalidArgs; LogWarn("DNS query over TCP not supported."); ExitNow(); } #endif else { VerifyOrExit(length <= kUdpQueryMaxSize, error = kErrorInvalidArgs); messageInfo.SetPeerAddr(aInfo.mConfig.GetServerSockAddr().GetAddress()); messageInfo.SetPeerPort(aInfo.mConfig.GetServerSockAddr().GetPort()); SuccessOrExit(error = mSocket.SendTo(*message, messageInfo)); } exit: FreeMessageOnError(message, error); if (aUpdateTimer) { mTimer.FireAtIfEarlier(aInfo.mRetransmissionTime); } UpdateQuery(aQuery, aInfo); return error; } Error Client::AppendNameFromQuery(const Query &aQuery, Message &aMessage) { // The name is encoded and included after the `Info` in `aQuery` // starting at `kNameOffsetInQuery`. 
return aMessage.AppendBytesFromMessage(aQuery, kNameOffsetInQuery, aQuery.GetLength() - kNameOffsetInQuery); } void Client::FinalizeQuery(Query &aQuery, Error aError) { Response response; Query &mainQuery = FindMainQuery(aQuery); response.mInstance = &Get(); response.mQuery = &mainQuery; FinalizeQuery(response, aError); } void Client::FinalizeQuery(Response &aResponse, Error aError) { QueryType type; Callback callback; void *context; GetQueryTypeAndCallback(*aResponse.mQuery, type, callback, context); switch (type) { case kIp6AddressQuery: #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE case kIp4AddressQuery: #endif if (callback.mAddressCallback != nullptr) { callback.mAddressCallback(aError, &aResponse, context); } break; #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE case kBrowseQuery: if (callback.mBrowseCallback != nullptr) { callback.mBrowseCallback(aError, &aResponse, context); } break; case kServiceQuerySrvTxt: case kServiceQuerySrv: case kServiceQueryTxt: if (callback.mServiceCallback != nullptr) { callback.mServiceCallback(aError, &aResponse, context); } break; #endif case kNoQuery: break; } FreeQuery(*aResponse.mQuery); } void Client::GetQueryTypeAndCallback(const Query &aQuery, QueryType &aType, Callback &aCallback, void *&aContext) { QueryInfo info; info.ReadFrom(aQuery); aType = info.mQueryType; aCallback = info.mCallback; aContext = info.mCallbackContext; } Client::Query *Client::FindQueryById(uint16_t aMessageId) { Query *matchedQuery = nullptr; QueryInfo info; for (Query &mainQuery : mMainQueries) { for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery) { info.ReadFrom(*query); if (info.mMessageId == aMessageId) { matchedQuery = query; ExitNow(); } } } exit: return matchedQuery; } void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMsgInfo) { OT_UNUSED_VARIABLE(aMsgInfo); static_cast(aContext)->ProcessResponse(AsCoreType(aMessage)); } void Client::ProcessResponse(const Message 
&aResponseMessage) { Error responseError; Query *query; SuccessOrExit(ParseResponse(aResponseMessage, query, responseError)); if (responseError != kErrorNone) { // Received an error from server, check if we can replace // the query. #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE if (ReplaceWithIp4Query(*query) == kErrorNone) { ExitNow(); } #endif #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE if (ReplaceWithSeparateSrvTxtQueries(*query) == kErrorNone) { ExitNow(); } #endif FinalizeQuery(*query, responseError); ExitNow(); } // Received successful response from server. #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE ResolveHostAddressIfNeeded(*query, aResponseMessage); #endif if (!CanFinalizeQuery(*query)) { SaveQueryResponse(*query, aResponseMessage); ExitNow(); } PrepareResponseAndFinalize(FindMainQuery(*query), aResponseMessage, nullptr); exit: return; } Error Client::ParseResponse(const Message &aResponseMessage, Query *&aQuery, Error &aResponseError) { Error error = kErrorNone; uint16_t offset = aResponseMessage.GetOffset(); Header header; QueryInfo info; Name queryName; SuccessOrExit(error = aResponseMessage.Read(offset, header)); offset += sizeof(Header); VerifyOrExit((header.GetType() == Header::kTypeResponse) && (header.GetQueryType() == Header::kQueryTypeStandard) && !header.IsTruncationFlagSet(), error = kErrorDrop); aQuery = FindQueryById(header.GetMessageId()); VerifyOrExit(aQuery != nullptr, error = kErrorNotFound); info.ReadFrom(*aQuery); queryName.SetFromMessage(*aQuery, kNameOffsetInQuery); // Check the Question Section if (header.GetQuestionCount() == kQuestionCount[info.mQueryType]) { for (uint8_t num = 0; num < kQuestionCount[info.mQueryType]; num++) { SuccessOrExit(error = Name::CompareName(aResponseMessage, offset, queryName)); offset += sizeof(Question); } } else { VerifyOrExit((header.GetResponseCode() != Header::kResponseSuccess) && (header.GetQuestionCount() == 0), error = kErrorParse); } // Check the answer, authority 
and additional record sections SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAnswerCount())); SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAuthorityRecordCount())); SuccessOrExit(error = ResourceRecord::ParseRecords(aResponseMessage, offset, header.GetAdditionalRecordCount())); // Read the response code aResponseError = Header::ResponseCodeToError(header.GetResponseCode()); exit: return error; } bool Client::CanFinalizeQuery(Query &aQuery) { // Determines whether we can finalize a main query by checking if // we have received and saved responses for all other related // queries associated with `aQuery`. Note that this method is // called when we receive a response for `aQuery`, so no need to // check for a saved response for `aQuery` itself. bool canFinalize = true; QueryInfo info; for (Query *query = &FindMainQuery(aQuery); query != nullptr; query = info.mNextQuery) { info.ReadFrom(*query); if (query == &aQuery) { continue; } if (info.mSavedResponse == nullptr) { canFinalize = false; ExitNow(); } } exit: return canFinalize; } void Client::SaveQueryResponse(Query &aQuery, const Message &aResponseMessage) { QueryInfo info; info.ReadFrom(aQuery); VerifyOrExit(info.mSavedResponse == nullptr); // If `Clone()` fails we let retry or timeout handle the error. info.mSavedResponse = aResponseMessage.Clone(); UpdateQuery(aQuery, info); exit: return; } Client::Query *Client::PopulateResponse(Response &aResponse, Query &aQuery, const Message &aResponseMessage) { // Populate `aResponse` for `aQuery`. If there is a saved response // message for `aQuery` we use it, otherwise, we use // `aResponseMessage`. QueryInfo info; info.ReadFrom(aQuery); aResponse.mInstance = &Get(); aResponse.mQuery = &aQuery; aResponse.PopulateFrom((info.mSavedResponse == nullptr) ? 
aResponseMessage : *info.mSavedResponse); return info.mNextQuery; } void Client::PrepareResponseAndFinalize(Query &aQuery, const Message &aResponseMessage, Response *aPrevResponse) { // This method prepares a list of chained `Response` instances // corresponding to all related (chained) queries. It uses // recursion to go through the queries and construct the // `Response` chain. Response response; Query *nextQuery; nextQuery = PopulateResponse(response, aQuery, aResponseMessage); response.mNext = aPrevResponse; if (nextQuery != nullptr) { PrepareResponseAndFinalize(*nextQuery, aResponseMessage, &response); } else { FinalizeQuery(response, kErrorNone); } } void Client::HandleTimer(void) { TimeMilli now = TimerMilli::GetNow(); TimeMilli nextTime = now.GetDistantFuture(); QueryInfo info; #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE bool hasTcpQuery = false; #endif for (Query &mainQuery : mMainQueries) { for (Query *query = &mainQuery; query != nullptr; query = info.mNextQuery) { info.ReadFrom(*query); if (info.mSavedResponse != nullptr) { continue; } if (now >= info.mRetransmissionTime) { if (info.mTransmissionCount >= info.mConfig.GetMaxTxAttempts()) { FinalizeQuery(*query, kErrorResponseTimeout); break; } IgnoreError(SendQuery(*query, info, /* aUpdateTimer */ false)); } if (nextTime > info.mRetransmissionTime) { nextTime = info.mRetransmissionTime; } #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp) { hasTcpQuery = true; } #endif } } if (nextTime < now.GetDistantFuture()) { mTimer.FireAt(nextTime); } #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE if (!hasTcpQuery && mTcpState != kTcpUninitialized) { IgnoreError(mEndpoint.SendEndOfStream()); } #endif } #if OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE Error Client::ReplaceWithIp4Query(Query &aQuery) { Error error = kErrorFailed; QueryInfo info; info.ReadFrom(aQuery); VerifyOrExit(info.mQueryType == kIp4AddressQuery); 
VerifyOrExit(info.mConfig.GetNat64Mode() == QueryConfig::kNat64Allow); // We send a new query for IPv4 address resolution // for the same host name. We reuse the existing `aQuery` // instance and keep all the info but clear `mTransmissionCount` // and `mMessageId` (so that a new random message ID is // selected). The new `info` will be saved in the query in // `SendQuery()`. Note that the current query is still in the // `mMainQueries` list when `SendQuery()` selects a new random // message ID, so the existing message ID for this query will // not be reused. info.mQueryType = kIp4AddressQuery; info.mMessageId = 0; info.mTransmissionCount = 0; IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true)); error = kErrorNone; exit: return error; } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_NAT64_ENABLE #if OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE Error Client::ReplaceWithSeparateSrvTxtQueries(Query &aQuery) { Error error = kErrorFailed; QueryInfo info; Query *secondQuery; info.ReadFrom(aQuery); VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt); VerifyOrExit(info.mConfig.GetServiceMode() == QueryConfig::kServiceModeSrvTxtOptimize); secondQuery = aQuery.Clone(); VerifyOrExit(secondQuery != nullptr); info.mQueryType = kServiceQueryTxt; info.mMessageId = 0; info.mTransmissionCount = 0; info.mMainQuery = &aQuery; IgnoreError(SendQuery(*secondQuery, info, /* aUpdateTimer */ true)); info.mQueryType = kServiceQuerySrv; info.mMessageId = 0; info.mTransmissionCount = 0; info.mNextQuery = secondQuery; IgnoreError(SendQuery(aQuery, info, /* aUpdateTimer */ true)); error = kErrorNone; exit: return error; } void Client::ResolveHostAddressIfNeeded(Query &aQuery, const Message &aResponseMessage) { QueryInfo info; Response response; ServiceInfo serviceInfo; char hostName[Name::kMaxNameSize]; info.ReadFrom(aQuery); VerifyOrExit(info.mQueryType == kServiceQuerySrvTxt || info.mQueryType == kServiceQuerySrv); VerifyOrExit(info.mShouldResolveHostAddr); 
PopulateResponse(response, aQuery, aResponseMessage); memset(&serviceInfo, 0, sizeof(serviceInfo)); serviceInfo.mHostNameBuffer = hostName; serviceInfo.mHostNameBufferSize = sizeof(hostName); SuccessOrExit(response.ReadServiceInfo(Response::kAnswerSection, Name(aQuery, kNameOffsetInQuery), serviceInfo)); // Check whether AAAA record for host address is provided in the SRV query response if (AsCoreType(&serviceInfo.mHostAddress).IsUnspecified()) { Query *newQuery; info.mQueryType = kIp6AddressQuery; info.mMessageId = 0; info.mTransmissionCount = 0; info.mMainQuery = &FindMainQuery(aQuery); SuccessOrExit(AllocateQuery(info, nullptr, hostName, newQuery)); IgnoreError(SendQuery(*newQuery, info, /* aUpdateTimer */ true)); // Update `aQuery` to be linked with new query (inserting // the `newQuery` into the linked-list after `aQuery`). info.ReadFrom(aQuery); info.mNextQuery = newQuery; UpdateQuery(aQuery, info); } exit: return; } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_SERVICE_DISCOVERY_ENABLE #if OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE void Client::PrepareTcpMessage(Message &aMessage) { uint16_t length = aMessage.GetLength() - aMessage.GetOffset(); // Prepending the DNS query with length of the packet according to RFC1035. 
WriteUint16(length, mSendBufferBytes + mSendLink.mLength); SuccessOrAssert( aMessage.Read(aMessage.GetOffset(), (mSendBufferBytes + sizeof(uint16_t) + mSendLink.mLength), length)); mSendLink.mLength += length + sizeof(uint16_t); } void Client::HandleTcpSendDone(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData) { OT_UNUSED_VARIABLE(aEndpoint); OT_UNUSED_VARIABLE(aData); OT_ASSERT(mTcpState == kTcpConnectedSending); mSendLink.mLength = 0; mTcpState = kTcpConnectedIdle; } void Client::HandleTcpSendDoneCallback(otTcpEndpoint *aEndpoint, otLinkedBuffer *aData) { static_cast(otTcpEndpointGetContext(aEndpoint))->HandleTcpSendDone(aEndpoint, aData); } void Client::HandleTcpEstablished(otTcpEndpoint *aEndpoint) { OT_UNUSED_VARIABLE(aEndpoint); IgnoreError(mEndpoint.SendByReference(mSendLink, /* aFlags */ 0)); mTcpState = kTcpConnectedSending; } void Client::HandleTcpEstablishedCallback(otTcpEndpoint *aEndpoint) { static_cast(otTcpEndpointGetContext(aEndpoint))->HandleTcpEstablished(aEndpoint); } Error Client::ReadFromLinkBuffer(const otLinkedBuffer *&aLinkedBuffer, size_t &aOffset, Message &aMessage, uint16_t aLength) { // Read `aLength` bytes from `aLinkedBuffer` starting at `aOffset` // and copy the content into `aMessage`. As we read we can move // to the next `aLinkedBuffer` and update `aOffset`. // Returns: // - `kErrorNone` if `aLength` bytes are successfully read and // `aOffset` and `aLinkedBuffer` are updated. // - `kErrorNotFound` is not enough bytes available to read // from `aLinkedBuffer`. // - `kErrorNotBufs` if cannot grow `aMessage` to append bytes. 
Error error = kErrorNone; while (aLength > 0) { uint16_t bytesToRead = aLength; VerifyOrExit(aLinkedBuffer != nullptr, error = kErrorNotFound); if (bytesToRead > aLinkedBuffer->mLength - aOffset) { bytesToRead = static_cast(aLinkedBuffer->mLength - aOffset); } SuccessOrExit(error = aMessage.AppendBytes(&aLinkedBuffer->mData[aOffset], bytesToRead)); aLength -= bytesToRead; aOffset += bytesToRead; if (aOffset == aLinkedBuffer->mLength) { aLinkedBuffer = aLinkedBuffer->mNext; aOffset = 0; } } exit: return error; } void Client::HandleTcpReceiveAvailable(otTcpEndpoint *aEndpoint, size_t aBytesAvailable, bool aEndOfStream, size_t aBytesRemaining) { OT_UNUSED_VARIABLE(aEndpoint); OT_UNUSED_VARIABLE(aBytesRemaining); Message *message = nullptr; size_t totalRead = 0; size_t offset = 0; const otLinkedBuffer *data; if (aEndOfStream) { // Cleanup is done in disconnected callback. IgnoreError(mEndpoint.SendEndOfStream()); } SuccessOrExit(mEndpoint.ReceiveByReference(data)); VerifyOrExit(data != nullptr); message = mSocket.NewMessage(); VerifyOrExit(message != nullptr); while (aBytesAvailable > totalRead) { uint16_t length; // Read the `length` field. SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, sizeof(uint16_t))); IgnoreError(message->Read(/* aOffset */ 0, length)); length = HostSwap16(length); // Try to read `length` bytes. IgnoreError(message->SetLength(0)); SuccessOrExit(ReadFromLinkBuffer(data, offset, *message, length)); totalRead += length + sizeof(uint16_t); // Now process the read message as query response. ProcessResponse(*message); IgnoreError(message->SetLength(0)); // Loop again to see if we can read another response. 
} exit: // Inform `mEndPoint` about the total read and processed bytes IgnoreError(mEndpoint.CommitReceive(totalRead, /* aFlags */ 0)); FreeMessage(message); } void Client::HandleTcpReceiveAvailableCallback(otTcpEndpoint *aEndpoint, size_t aBytesAvailable, bool aEndOfStream, size_t aBytesRemaining) { static_cast(otTcpEndpointGetContext(aEndpoint)) ->HandleTcpReceiveAvailable(aEndpoint, aBytesAvailable, aEndOfStream, aBytesRemaining); } void Client::HandleTcpDisconnected(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason) { OT_UNUSED_VARIABLE(aEndpoint); OT_UNUSED_VARIABLE(aReason); QueryInfo info; IgnoreError(mEndpoint.Deinitialize()); mTcpState = kTcpUninitialized; // Abort queries in case of connection failures for (Query &mainQuery : mMainQueries) { info.ReadFrom(mainQuery); if (info.mConfig.GetTransportProto() == QueryConfig::kDnsTransportTcp) { FinalizeQuery(mainQuery, kErrorAbort); } } } void Client::HandleTcpDisconnectedCallback(otTcpEndpoint *aEndpoint, otTcpDisconnectedReason aReason) { static_cast(otTcpEndpointGetContext(aEndpoint))->HandleTcpDisconnected(aEndpoint, aReason); } #endif // OPENTHREAD_CONFIG_DNS_CLIENT_OVER_TCP_ENABLE } // namespace Dns } // namespace ot #endif // OPENTHREAD_CONFIG_DNS_CLIENT_ENABLE