/*
 *  Copyright (c) 2021, The OpenThread Authors.
 *  All rights reserved.
 *
 *  Redistribution and use in source and binary forms, with or without
 *  modification, are permitted provided that the following conditions are met:
 *  1. Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *  3. Neither the name of the copyright holder nor the
 *     names of its contributors may be used to endorse or promote products
 *     derived from this software without specific prior written permission.
 *
 *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *  POSSIBILITY OF SUCH DAMAGE.
 */
28 
29 #ifndef DNS_SERVER_HPP_
30 #define DNS_SERVER_HPP_
31 
32 #include "openthread-core-config.h"
33 
34 #if OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
35 
36 #include <openthread/dnssd_server.h>
37 
38 #include "common/message.hpp"
39 #include "common/non_copyable.hpp"
40 #include "common/timer.hpp"
41 #include "net/dns_types.hpp"
42 #include "net/ip6.hpp"
43 #include "net/netif.hpp"
44 #include "net/srp_server.hpp"
45 
/**
 * @file
 *   This file includes definitions for the DNS-SD server.
 */
50 
51 namespace ot {
52 namespace Dns {
53 namespace ServiceDiscovery {
54 
55 /**
56  * This class implements DNS-SD server.
57  *
58  */
59 class Server : public InstanceLocator, private NonCopyable
60 {
61 public:
62     /**
63      * This enumeration specifies a DNS-SD query type.
64      *
65      */
66     enum DnsQueryType : uint8_t
67     {
68         kDnsQueryNone        = OT_DNSSD_QUERY_TYPE_NONE,         ///< Service type unspecified.
69         kDnsQueryBrowse      = OT_DNSSD_QUERY_TYPE_BROWSE,       ///< Service type browse service.
70         kDnsQueryResolve     = OT_DNSSD_QUERY_TYPE_RESOLVE,      ///< Service type resolve service instance.
71         kDnsQueryResolveHost = OT_DNSSD_QUERY_TYPE_RESOLVE_HOST, ///< Service type resolve hostname.
72     };
73 
74     static constexpr uint16_t kPort = OPENTHREAD_CONFIG_DNSSD_SERVER_PORT; ///< The DNS-SD server port.
75 
76     /**
77      * This constructor initializes the object.
78      *
79      * @param[in]  aInstance     A reference to the OpenThread instance.
80      *
81      */
82     explicit Server(Instance &aInstance);
83 
84     /**
85      * This method starts the DNS-SD server.
86      *
87      * @retval kErrorNone     Successfully started the DNS-SD server.
88      * @retval kErrorFailed   If failed to open or bind the UDP socket.
89      *
90      */
91     Error Start(void);
92 
93     /**
94      * This method stops the DNS-SD server.
95      *
96      */
97     void Stop(void);
98 
99     /**
100      * This method sets DNS-SD query callbacks.
101      *
102      * @param[in] aSubscribe    A pointer to the callback function to subscribe a service or service instance.
103      * @param[in] aUnsubscribe  A pointer to the callback function to unsubscribe a service or service instance.
104      * @param[in] aContext      A pointer to the application-specific context.
105      *
106      */
107     void SetQueryCallbacks(otDnssdQuerySubscribeCallback   aSubscribe,
108                            otDnssdQueryUnsubscribeCallback aUnsubscribe,
109                            void *                          aContext);
110 
111     /**
112      * This method notifies a discovered service instance.
113      *
114      * @param[in] aServiceFullName  The null-terminated full service name.
115      * @param[in] aInstanceInfo     A reference to the discovered service instance information.
116      *
117      */
118     void HandleDiscoveredServiceInstance(const char *aServiceFullName, const otDnssdServiceInstanceInfo &aInstanceInfo);
119 
120     /**
121      * This method notifies a discovered host.
122      *
123      * @param[in] aHostFullName     The null-terminated full host name.
124      * @param[in] aHostInfo         A reference to the discovered host information.
125      *
126      */
127     void HandleDiscoveredHost(const char *aHostFullName, const otDnssdHostInfo &aHostInfo);
128 
129     /**
130      * This method acquires the next query in the server.
131      *
132      * @param[in] aQuery            The query pointer. Pass nullptr to get the first query.
133      *
134      * @returns  A pointer to the query or nullptr if no more queries.
135      *
136      */
137     const otDnssdQuery *GetNextQuery(const otDnssdQuery *aQuery) const;
138 
139     /**
140      * This method acquires the DNS-SD query type and name for a specific query.
141      *
142      * @param[in]   aQuery            The query pointer.
143      * @param[out]  aNameOutput       The name output buffer.
144      *
145      * @returns The DNS-SD query type.
146      *
147      */
148     static DnsQueryType GetQueryTypeAndName(const otDnssdQuery *aQuery, char (&aName)[Name::kMaxNameSize]);
149 
150 private:
151     class NameCompressInfo : public Clearable<NameCompressInfo>
152     {
153     public:
154         explicit NameCompressInfo(void) = default;
155 
NameCompressInfo(const char * aDomainName)156         explicit NameCompressInfo(const char *aDomainName)
157             : mDomainName(aDomainName)
158             , mDomainNameOffset(kUnknownOffset)
159             , mServiceNameOffset(kUnknownOffset)
160             , mInstanceNameOffset(kUnknownOffset)
161             , mHostNameOffset(kUnknownOffset)
162         {
163         }
164 
165         static constexpr uint16_t kUnknownOffset = 0; // Unknown offset value (used when offset is not yet set).
166 
GetDomainNameOffset(void) const167         uint16_t GetDomainNameOffset(void) const { return mDomainNameOffset; }
168 
SetDomainNameOffset(uint16_t aOffset)169         void SetDomainNameOffset(uint16_t aOffset) { mDomainNameOffset = aOffset; }
170 
GetDomainName(void) const171         const char *GetDomainName(void) const { return mDomainName; }
172 
GetServiceNameOffset(const Message & aMessage,const char * aServiceName) const173         uint16_t GetServiceNameOffset(const Message &aMessage, const char *aServiceName) const
174         {
175             return MatchCompressedName(aMessage, mServiceNameOffset, aServiceName)
176                        ? mServiceNameOffset
177                        : static_cast<uint16_t>(kUnknownOffset);
178         };
179 
SetServiceNameOffset(uint16_t aOffset)180         void SetServiceNameOffset(uint16_t aOffset)
181         {
182             if (mServiceNameOffset == kUnknownOffset)
183             {
184                 mServiceNameOffset = aOffset;
185             }
186         }
187 
GetInstanceNameOffset(const Message & aMessage,const char * aName) const188         uint16_t GetInstanceNameOffset(const Message &aMessage, const char *aName) const
189         {
190             return MatchCompressedName(aMessage, mInstanceNameOffset, aName) ? mInstanceNameOffset
191                                                                              : static_cast<uint16_t>(kUnknownOffset);
192         }
193 
SetInstanceNameOffset(uint16_t aOffset)194         void SetInstanceNameOffset(uint16_t aOffset)
195         {
196             if (mInstanceNameOffset == kUnknownOffset)
197             {
198                 mInstanceNameOffset = aOffset;
199             }
200         }
201 
GetHostNameOffset(const Message & aMessage,const char * aName) const202         uint16_t GetHostNameOffset(const Message &aMessage, const char *aName) const
203         {
204             return MatchCompressedName(aMessage, mHostNameOffset, aName) ? mHostNameOffset
205                                                                          : static_cast<uint16_t>(kUnknownOffset);
206         }
207 
SetHostNameOffset(uint16_t aOffset)208         void SetHostNameOffset(uint16_t aOffset)
209         {
210             if (mHostNameOffset == kUnknownOffset)
211             {
212                 mHostNameOffset = aOffset;
213             }
214         }
215 
216     private:
MatchCompressedName(const Message & aMessage,uint16_t aOffset,const char * aName)217         static bool MatchCompressedName(const Message &aMessage, uint16_t aOffset, const char *aName)
218         {
219             return aOffset != kUnknownOffset && Name::CompareName(aMessage, aOffset, aName) == kErrorNone;
220         }
221 
222         const char *mDomainName;         // The serialized domain name.
223         uint16_t    mDomainNameOffset;   // Offset of domain name serialization into the response message.
224         uint16_t    mServiceNameOffset;  // Offset of service name serialization into the response message.
225         uint16_t    mInstanceNameOffset; // Offset of instance name serialization into the response message.
226         uint16_t    mHostNameOffset;     // Offset of host name serialization into the response message.
227     };
228 
229     static constexpr bool     kBindUnspecifiedNetif = OPENTHREAD_CONFIG_DNSSD_SERVER_BIND_UNSPECIFIED_NETIF;
230     static constexpr uint8_t  kProtocolLabelLength  = 4;
231     static constexpr uint8_t  kSubTypeLabelLength   = 4;
232     static constexpr uint16_t kMaxConcurrentQueries = 32;
233 
234     // This structure represents the splitting information of a full name.
235     struct NameComponentsOffsetInfo
236     {
237         static constexpr uint8_t kNotPresent = 0xff; // Indicates the component is not present.
238 
NameComponentsOffsetInfoot::Dns::ServiceDiscovery::Server::NameComponentsOffsetInfo239         explicit NameComponentsOffsetInfo(void)
240             : mDomainOffset(kNotPresent)
241             , mProtocolOffset(kNotPresent)
242             , mServiceOffset(kNotPresent)
243             , mSubTypeOffset(kNotPresent)
244             , mInstanceOffset(kNotPresent)
245         {
246         }
247 
IsServiceInstanceNameot::Dns::ServiceDiscovery::Server::NameComponentsOffsetInfo248         bool IsServiceInstanceName(void) const { return mInstanceOffset != kNotPresent; }
249 
IsServiceNameot::Dns::ServiceDiscovery::Server::NameComponentsOffsetInfo250         bool IsServiceName(void) const { return mServiceOffset != kNotPresent && mInstanceOffset == kNotPresent; }
251 
IsHostNameot::Dns::ServiceDiscovery::Server::NameComponentsOffsetInfo252         bool IsHostName(void) const { return mProtocolOffset == kNotPresent && mDomainOffset != 0; }
253 
254         uint8_t mDomainOffset;   // The offset to the beginning of <Domain>.
255         uint8_t mProtocolOffset; // The offset to the beginning of <Protocol> (i.e. _tcp or _udp) or `kNotPresent` if
256                                  // the name is not a service or instance.
257         uint8_t mServiceOffset;  // The offset to the beginning of <Service> or `kNotPresent` if the name is not a
258                                  // service or instance.
259         uint8_t mSubTypeOffset;  // The offset to the beginning of sub-type label or `kNotPresent` is not a sub-type.
260         uint8_t mInstanceOffset; // The offset to the beginning of <Instance> or `kNotPresent` if the name is not a
261                                  // instance.
262     };
263 
264     /**
265      * This class contains the compress information for a dns packet.
266      *
267      */
268     class QueryTransaction
269     {
270     public:
QueryTransaction(void)271         explicit QueryTransaction(void)
272             : mResponseMessage(nullptr)
273         {
274         }
275 
276         void                    Init(const Header &          aResponseHeader,
277                                      Message &               aResponseMessage,
278                                      const NameCompressInfo &aCompressInfo,
279                                      const Ip6::MessageInfo &aMessageInfo);
IsValid(void) const280         bool                    IsValid(void) const { return mResponseMessage != nullptr; }
GetMessageInfo(void) const281         const Ip6::MessageInfo &GetMessageInfo(void) const { return mMessageInfo; }
GetResponseHeader(void) const282         const Header &          GetResponseHeader(void) const { return mResponseHeader; }
GetResponseHeader(void)283         Header &                GetResponseHeader(void) { return mResponseHeader; }
GetResponseMessage(void) const284         const Message &         GetResponseMessage(void) const { return *mResponseMessage; }
GetResponseMessage(void)285         Message &               GetResponseMessage(void) { return *mResponseMessage; }
GetStartTime(void) const286         TimeMilli               GetStartTime(void) const { return mStartTime; }
GetNameCompressInfo(void)287         NameCompressInfo &      GetNameCompressInfo(void) { return mCompressInfo; };
288         void                    Finalize(Header::Response aResponseMessage, Ip6::Udp::Socket &aSocket);
289 
290         Header           mResponseHeader;
291         Message *        mResponseMessage;
292         NameCompressInfo mCompressInfo;
293         Ip6::MessageInfo mMessageInfo;
294         TimeMilli        mStartTime;
295     };
296 
297     static constexpr uint32_t kQueryTimeout = OPENTHREAD_CONFIG_DNSSD_QUERY_TIMEOUT;
298 
IsRunning(void) const299     bool        IsRunning(void) const { return mSocket.IsBound(); }
300     static void HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo);
301     void        HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo);
302     void ProcessQuery(const Header &aRequestHeader, Message &aRequestMessage, const Ip6::MessageInfo &aMessageInfo);
303     static Header::Response AddQuestions(const Header &    aRequestHeader,
304                                          const Message &   aRequestMessage,
305                                          Header &          aResponseHeader,
306                                          Message &         aResponseMessage,
307                                          NameCompressInfo &aCompressInfo);
308     static Error            AppendQuestion(const char *      aName,
309                                            const Question &  aQuestion,
310                                            Message &         aMessage,
311                                            NameCompressInfo &aCompressInfo);
312     static Error            AppendPtrRecord(Message &         aMessage,
313                                             const char *      aServiceName,
314                                             const char *      aInstanceName,
315                                             uint32_t          aTtl,
316                                             NameCompressInfo &aCompressInfo);
317     static Error            AppendSrvRecord(Message &         aMessage,
318                                             const char *      aInstanceName,
319                                             const char *      aHostName,
320                                             uint32_t          aTtl,
321                                             uint16_t          aPriority,
322                                             uint16_t          aWeight,
323                                             uint16_t          aPort,
324                                             NameCompressInfo &aCompressInfo);
325     static Error            AppendTxtRecord(Message &         aMessage,
326                                             const char *      aInstanceName,
327                                             const void *      aTxtData,
328                                             uint16_t          aTxtLength,
329                                             uint32_t          aTtl,
330                                             NameCompressInfo &aCompressInfo);
331     static Error            AppendAaaaRecord(Message &           aMessage,
332                                              const char *        aHostName,
333                                              const Ip6::Address &aAddress,
334                                              uint32_t            aTtl,
335                                              NameCompressInfo &  aCompressInfo);
336     static Error            AppendServiceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo);
337     static Error            AppendInstanceName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo);
338     static Error            AppendHostName(Message &aMessage, const char *aName, NameCompressInfo &aCompressInfo);
339     static void             IncResourceRecordCount(Header &aHeader, bool aAdditional);
340     static Error            FindNameComponents(const char *aName, const char *aDomain, NameComponentsOffsetInfo &aInfo);
341     static Error            FindPreviousLabel(const char *aName, uint8_t &aStart, uint8_t &aStop);
342     static void             SendResponse(Header                  aHeader,
343                                          Header::Response        aResponseCode,
344                                          Message &               aMessage,
345                                          const Ip6::MessageInfo &aMessageInfo,
346                                          Ip6::Udp::Socket &      aSocket);
347 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
348     Header::Response                   ResolveBySrp(Header &                  aResponseHeader,
349                                                     Message &                 aResponseMessage,
350                                                     Server::NameCompressInfo &aCompressInfo);
351     Header::Response                   ResolveQuestionBySrp(const char *      aName,
352                                                             const Question &  aQuestion,
353                                                             Header &          aResponseHeader,
354                                                             Message &         aResponseMessage,
355                                                             NameCompressInfo &aCompressInfo,
356                                                             bool              aAdditional);
357     const Srp::Server::Host *          GetNextSrpHost(const Srp::Server::Host *aHost);
358     static const Srp::Server::Service *GetNextSrpService(const Srp::Server::Host &   aHost,
359                                                          const Srp::Server::Service *aService);
360 #endif
361 
362     Error             ResolveByQueryCallbacks(Header &                aResponseHeader,
363                                               Message &               aResponseMessage,
364                                               NameCompressInfo &      aCompressInfo,
365                                               const Ip6::MessageInfo &aMessageInfo);
366     QueryTransaction *NewQuery(const Header &          aResponseHeader,
367                                Message &               aResponseMessage,
368                                const NameCompressInfo &aCompressInfo,
369                                const Ip6::MessageInfo &aMessageInfo);
370     static bool       CanAnswerQuery(const QueryTransaction &          aQuery,
371                                      const char *                      aServiceFullName,
372                                      const otDnssdServiceInstanceInfo &aInstanceInfo);
373     void              AnswerQuery(QueryTransaction &                aQuery,
374                                   const char *                      aServiceFullName,
375                                   const otDnssdServiceInstanceInfo &aInstanceInfo);
376     static bool       CanAnswerQuery(const Server::QueryTransaction &aQuery, const char *aHostFullName);
377     void AnswerQuery(QueryTransaction &aQuery, const char *aHostFullName, const otDnssdHostInfo &aHostInfo);
378     void FinalizeQuery(QueryTransaction &aQuery, Header::Response aResponseCode);
379     static DnsQueryType GetQueryTypeAndName(const Header & aHeader,
380                                             const Message &aMessage,
381                                             char (&aName)[Name::kMaxNameSize]);
382     static bool HasQuestion(const Header &aHeader, const Message &aMessage, const char *aName, uint16_t aQuestionType);
383     static void HandleTimer(Timer &aTimer);
384     void        HandleTimer(void);
385     void        ResetTimer(void);
386 
387     static const char kDnssdProtocolUdp[4];
388     static const char kDnssdProtocolTcp[4];
389     static const char kDnssdSubTypeLabel[];
390     static const char kDefaultDomainName[];
391     Ip6::Udp::Socket  mSocket;
392 
393     QueryTransaction                mQueryTransactions[kMaxConcurrentQueries];
394     void *                          mQueryCallbackContext;
395     otDnssdQuerySubscribeCallback   mQuerySubscribe;
396     otDnssdQueryUnsubscribeCallback mQueryUnsubscribe;
397     TimerMilli                      mTimer;
398 };
399 
400 } // namespace ServiceDiscovery
401 } // namespace Dns
402 } // namespace ot
403 
404 #endif // OPENTHREAD_CONFIG_DNSSD_SERVER_ENABLE
405 
406 #endif // DNS_SERVER_HPP_
407