1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "resolver.hpp"
30 
31 #include "platform-posix.h"
32 
33 #include <openthread/logging.h>
34 #include <openthread/message.h>
35 #include <openthread/udp.h>
36 #include <openthread/platform/dns.h>
37 #include <openthread/platform/time.h>
38 
39 #include "common/code_utils.hpp"
40 
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <cassert>
#include <cerrno>
#include <cstring>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>
48 
49 #include <fstream>
50 #include <string>
51 
52 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
53 
namespace {

// Location of the system resolver configuration that lists the upstream name servers.
constexpr char kResolvConfFullPath[] = "/etc/resolv.conf";

// Keyword that opens a name server entry in resolv.conf. Kept as a char array so that sizeof() yields the
// keyword length plus one (the terminating NUL).
constexpr char kNameserverItem[] = "nameserver";

} // namespace
58 
// The global resolver instance, defined in another translation unit. The otPlatDns* entry points at the end
// of this file forward every request to it.
extern ot::Posix::Resolver gResolver;
60 
61 namespace ot {
62 namespace Posix {
63 
Init(void)64 void Resolver::Init(void)
65 {
66     memset(mUpstreamTransaction, 0, sizeof(mUpstreamTransaction));
67     LoadDnsServerListFromConf();
68 }
69 
TryRefreshDnsServerList(void)70 void Resolver::TryRefreshDnsServerList(void)
71 {
72     uint64_t now = otPlatTimeGet();
73 
74     if (now > mUpstreamDnsServerListFreshness + kDnsServerListCacheTimeoutMs ||
75         (mUpstreamDnsServerCount == 0 && now > mUpstreamDnsServerListFreshness + kDnsServerListNullCacheTimeoutMs))
76     {
77         LoadDnsServerListFromConf();
78     }
79 }
80 
LoadDnsServerListFromConf(void)81 void Resolver::LoadDnsServerListFromConf(void)
82 {
83     std::string   line;
84     std::ifstream fp;
85 
86     mUpstreamDnsServerCount = 0;
87 
88     fp.open(kResolvConfFullPath);
89 
90     while (fp.good() && std::getline(fp, line) && mUpstreamDnsServerCount < kMaxUpstreamServerCount)
91     {
92         if (line.find(kNameserverItem, 0) == 0)
93         {
94             in_addr_t addr;
95 
96             if (inet_pton(AF_INET, &line.c_str()[sizeof(kNameserverItem)], &addr) == 1)
97             {
98                 otLogInfoPlat("Got nameserver #%d: %s", mUpstreamDnsServerCount,
99                               &line.c_str()[sizeof(kNameserverItem)]);
100                 mUpstreamDnsServerList[mUpstreamDnsServerCount] = addr;
101                 mUpstreamDnsServerCount++;
102             }
103         }
104     }
105 
106     if (mUpstreamDnsServerCount == 0)
107     {
108         otLogCritPlat("No domain name servers found in %s, default to 127.0.0.1", kResolvConfFullPath);
109     }
110 
111     mUpstreamDnsServerListFreshness = otPlatTimeGet();
112 }
113 
Query(otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)114 void Resolver::Query(otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
115 {
116     char        packet[kMaxDnsMessageSize];
117     otError     error  = OT_ERROR_NONE;
118     uint16_t    length = otMessageGetLength(aQuery);
119     sockaddr_in serverAddr;
120 
121     Transaction *txn = nullptr;
122 
123     VerifyOrExit(length <= kMaxDnsMessageSize, error = OT_ERROR_NO_BUFS);
124     VerifyOrExit(otMessageRead(aQuery, 0, &packet, sizeof(packet)) == length, error = OT_ERROR_NO_BUFS);
125 
126     txn = AllocateTransaction(aTxn);
127     VerifyOrExit(txn != nullptr, error = OT_ERROR_NO_BUFS);
128 
129     TryRefreshDnsServerList();
130 
131     serverAddr.sin_family = AF_INET;
132     serverAddr.sin_port   = htons(53);
133     for (int i = 0; i < mUpstreamDnsServerCount; i++)
134     {
135         serverAddr.sin_addr.s_addr = mUpstreamDnsServerList[i];
136         VerifyOrExit(
137             sendto(txn->mUdpFd, packet, length, MSG_DONTWAIT, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) > 0,
138             error = OT_ERROR_NO_ROUTE);
139     }
140     otLogInfoPlat("Forwarded DNS query %p to %d server(s).", static_cast<void *>(aTxn), mUpstreamDnsServerCount);
141 
142 exit:
143     if (error != OT_ERROR_NONE)
144     {
145         otLogCritPlat("Failed to forward DNS query %p to server: %d", static_cast<void *>(aTxn), error);
146     }
147     return;
148 }
149 
Cancel(otPlatDnsUpstreamQuery * aTxn)150 void Resolver::Cancel(otPlatDnsUpstreamQuery *aTxn)
151 {
152     Transaction *txn = GetTransaction(aTxn);
153 
154     if (txn != nullptr)
155     {
156         CloseTransaction(txn);
157     }
158 
159     otPlatDnsUpstreamQueryDone(gInstance, aTxn, nullptr);
160 }
161 
AllocateTransaction(otPlatDnsUpstreamQuery * aThreadTxn)162 Resolver::Transaction *Resolver::AllocateTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
163 {
164     int          fdOrError = 0;
165     Transaction *ret       = nullptr;
166 
167     for (Transaction &txn : mUpstreamTransaction)
168     {
169         if (txn.mThreadTxn == nullptr)
170         {
171             fdOrError = socket(AF_INET, SOCK_DGRAM, 0);
172             if (fdOrError < 0)
173             {
174                 otLogInfoPlat("Failed to create socket for upstream resolver: %d", fdOrError);
175                 break;
176             }
177             ret             = &txn;
178             ret->mUdpFd     = fdOrError;
179             ret->mThreadTxn = aThreadTxn;
180             break;
181         }
182     }
183 
184     return ret;
185 }
186 
ForwardResponse(Transaction * aTxn)187 void Resolver::ForwardResponse(Transaction *aTxn)
188 {
189     char       response[kMaxDnsMessageSize];
190     ssize_t    readSize;
191     otError    error   = OT_ERROR_NONE;
192     otMessage *message = nullptr;
193 
194     VerifyOrExit((readSize = read(aTxn->mUdpFd, response, sizeof(response))) > 0);
195 
196     message = otUdpNewMessage(gInstance, nullptr);
197     VerifyOrExit(message != nullptr, error = OT_ERROR_NO_BUFS);
198     SuccessOrExit(error = otMessageAppend(message, response, readSize));
199 
200     otPlatDnsUpstreamQueryDone(gInstance, aTxn->mThreadTxn, message);
201     message = nullptr;
202 
203 exit:
204     if (readSize < 0)
205     {
206         otLogInfoPlat("Failed to read response from upstream resolver socket: %d", errno);
207     }
208     if (error != OT_ERROR_NONE)
209     {
210         otLogInfoPlat("Failed to forward upstream DNS response: %s", otThreadErrorToString(error));
211     }
212     if (message != nullptr)
213     {
214         otMessageFree(message);
215     }
216 }
217 
GetTransaction(int aFd)218 Resolver::Transaction *Resolver::GetTransaction(int aFd)
219 {
220     Transaction *ret = nullptr;
221 
222     for (Transaction &txn : mUpstreamTransaction)
223     {
224         if (txn.mThreadTxn != nullptr && txn.mUdpFd == aFd)
225         {
226             ret = &txn;
227             break;
228         }
229     }
230 
231     return ret;
232 }
233 
GetTransaction(otPlatDnsUpstreamQuery * aThreadTxn)234 Resolver::Transaction *Resolver::GetTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
235 {
236     Transaction *ret = nullptr;
237 
238     for (Transaction &txn : mUpstreamTransaction)
239     {
240         if (txn.mThreadTxn == aThreadTxn)
241         {
242             ret = &txn;
243             break;
244         }
245     }
246 
247     return ret;
248 }
249 
CloseTransaction(Transaction * aTxn)250 void Resolver::CloseTransaction(Transaction *aTxn)
251 {
252     if (aTxn->mUdpFd >= 0)
253     {
254         close(aTxn->mUdpFd);
255         aTxn->mUdpFd = -1;
256     }
257     aTxn->mThreadTxn = nullptr;
258 }
259 
UpdateFdSet(fd_set * aReadFdSet,fd_set * aErrorFdSet,int * aMaxFd)260 void Resolver::UpdateFdSet(fd_set *aReadFdSet, fd_set *aErrorFdSet, int *aMaxFd)
261 {
262     for (Transaction &txn : mUpstreamTransaction)
263     {
264         if (txn.mThreadTxn != nullptr)
265         {
266             FD_SET(txn.mUdpFd, aReadFdSet);
267             FD_SET(txn.mUdpFd, aErrorFdSet);
268             if (txn.mUdpFd > *aMaxFd)
269             {
270                 *aMaxFd = txn.mUdpFd;
271             }
272         }
273     }
274 }
275 
Process(const fd_set * aReadFdSet,const fd_set * aErrorFdSet)276 void Resolver::Process(const fd_set *aReadFdSet, const fd_set *aErrorFdSet)
277 {
278     for (Transaction &txn : mUpstreamTransaction)
279     {
280         if (txn.mThreadTxn != nullptr)
281         {
282             // Note: On Linux, we can only get the error via read, so they should share the same logic.
283             if (FD_ISSET(txn.mUdpFd, aErrorFdSet) || FD_ISSET(txn.mUdpFd, aReadFdSet))
284             {
285                 ForwardResponse(&txn);
286                 CloseTransaction(&txn);
287             }
288         }
289     }
290 }
291 
292 } // namespace Posix
293 } // namespace ot
294 
otPlatDnsStartUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)295 void otPlatDnsStartUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
296 {
297     OT_UNUSED_VARIABLE(aInstance);
298 
299     gResolver.Query(aTxn, aQuery);
300 }
301 
otPlatDnsCancelUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn)302 void otPlatDnsCancelUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn)
303 {
304     OT_UNUSED_VARIABLE(aInstance);
305 
306     gResolver.Cancel(aTxn);
307 }
308 
309 #endif // OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
310