1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "resolver.hpp"
30 
31 #include "platform-posix.h"
32 
33 #include <openthread/logging.h>
34 #include <openthread/message.h>
35 #include <openthread/nat64.h>
36 #include <openthread/openthread-system.h>
37 #include <openthread/udp.h>
38 #include <openthread/platform/dns.h>
39 #include <openthread/platform/time.h>
40 
41 #include "common/code_utils.hpp"
42 
43 #include <arpa/inet.h>
44 #include <arpa/nameser.h>
45 #include <cassert>
46 #include <netinet/in.h>
47 #include <sys/select.h>
48 #include <sys/socket.h>
49 #include <unistd.h>
50 
51 #include <fstream>
52 #include <string>
53 
54 #if OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
55 
namespace {
// System resolver configuration file that is scanned for upstream DNS servers.
constexpr char kResolvConfFullPath[]{"/etc/resolv.conf"};
// Keyword introducing a DNS server entry in resolv.conf. Kept as a char array (not a pointer) so that
// sizeof() yields the keyword length plus the terminating NUL.
constexpr char kNameserverItem[]{"nameserver"};
} // namespace
60 
// Process-wide resolver instance; the C platform entry points at the end of this file forward to it.
ot::Posix::Resolver gResolver;
62 
63 namespace ot {
64 namespace Posix {
65 
// Module name for this file's log output (presumably consumed by the LogInfo/LogCrit macros — confirm in the
// logging helpers).
const char Resolver::kLogModuleName[] = "Resolver";
67 
Init(void)68 void Resolver::Init(void)
69 {
70     memset(mUpstreamTransaction, 0, sizeof(mUpstreamTransaction));
71     LoadDnsServerListFromConf();
72 }
73 
TryRefreshDnsServerList(void)74 void Resolver::TryRefreshDnsServerList(void)
75 {
76     uint64_t now = otPlatTimeGet();
77 
78     if (now > mUpstreamDnsServerListFreshness + kDnsServerListCacheTimeoutMs ||
79         (mUpstreamDnsServerCount == 0 && now > mUpstreamDnsServerListFreshness + kDnsServerListNullCacheTimeoutMs))
80     {
81         LoadDnsServerListFromConf();
82     }
83 }
84 
LoadDnsServerListFromConf(void)85 void Resolver::LoadDnsServerListFromConf(void)
86 {
87     std::string   line;
88     std::ifstream fp;
89 
90     VerifyOrExit(mIsResolvConfEnabled);
91 
92     mUpstreamDnsServerCount = 0;
93 
94     fp.open(kResolvConfFullPath);
95 
96     while (fp.good() && std::getline(fp, line) && mUpstreamDnsServerCount < kMaxUpstreamServerCount)
97     {
98         if (line.find(kNameserverItem, 0) == 0)
99         {
100             in_addr_t addr;
101 
102             if (inet_pton(AF_INET, &line.c_str()[sizeof(kNameserverItem)], &addr) == 1)
103             {
104                 LogInfo("Got nameserver #%d: %s", mUpstreamDnsServerCount, &line.c_str()[sizeof(kNameserverItem)]);
105                 mUpstreamDnsServerList[mUpstreamDnsServerCount] = addr;
106                 mUpstreamDnsServerCount++;
107             }
108         }
109     }
110 
111     if (mUpstreamDnsServerCount == 0)
112     {
113         LogCrit("No domain name servers found in %s, default to 127.0.0.1", kResolvConfFullPath);
114     }
115 
116     mUpstreamDnsServerListFreshness = otPlatTimeGet();
117 exit:
118     return;
119 }
120 
Query(otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)121 void Resolver::Query(otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
122 {
123     char        packet[kMaxDnsMessageSize];
124     otError     error  = OT_ERROR_NONE;
125     uint16_t    length = otMessageGetLength(aQuery);
126     sockaddr_in serverAddr;
127 
128     Transaction *txn = nullptr;
129 
130     VerifyOrExit(length <= kMaxDnsMessageSize, error = OT_ERROR_NO_BUFS);
131     VerifyOrExit(otMessageRead(aQuery, 0, &packet, sizeof(packet)) == length, error = OT_ERROR_NO_BUFS);
132 
133     txn = AllocateTransaction(aTxn);
134     VerifyOrExit(txn != nullptr, error = OT_ERROR_NO_BUFS);
135 
136     TryRefreshDnsServerList();
137 
138     serverAddr.sin_family = AF_INET;
139     serverAddr.sin_port   = htons(53);
140     for (int i = 0; i < mUpstreamDnsServerCount; i++)
141     {
142         serverAddr.sin_addr.s_addr = mUpstreamDnsServerList[i];
143         VerifyOrExit(
144             sendto(txn->mUdpFd, packet, length, MSG_DONTWAIT, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) > 0,
145             error = OT_ERROR_NO_ROUTE);
146     }
147     LogInfo("Forwarded DNS query %p to %d server(s).", static_cast<void *>(aTxn), mUpstreamDnsServerCount);
148 
149 exit:
150     if (error != OT_ERROR_NONE)
151     {
152         LogCrit("Failed to forward DNS query %p to server: %d", static_cast<void *>(aTxn), error);
153     }
154     return;
155 }
156 
Cancel(otPlatDnsUpstreamQuery * aTxn)157 void Resolver::Cancel(otPlatDnsUpstreamQuery *aTxn)
158 {
159     Transaction *txn = GetTransaction(aTxn);
160 
161     if (txn != nullptr)
162     {
163         CloseTransaction(txn);
164     }
165 
166     otPlatDnsUpstreamQueryDone(gInstance, aTxn, nullptr);
167 }
168 
AllocateTransaction(otPlatDnsUpstreamQuery * aThreadTxn)169 Resolver::Transaction *Resolver::AllocateTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
170 {
171     int          fdOrError = 0;
172     Transaction *ret       = nullptr;
173 
174     for (Transaction &txn : mUpstreamTransaction)
175     {
176         if (txn.mThreadTxn == nullptr)
177         {
178             fdOrError = socket(AF_INET, SOCK_DGRAM, 0);
179             if (fdOrError < 0)
180             {
181                 LogInfo("Failed to create socket for upstream resolver: %d", fdOrError);
182                 break;
183             }
184             ret             = &txn;
185             ret->mUdpFd     = fdOrError;
186             ret->mThreadTxn = aThreadTxn;
187             break;
188         }
189     }
190 
191     return ret;
192 }
193 
ForwardResponse(Transaction * aTxn)194 void Resolver::ForwardResponse(Transaction *aTxn)
195 {
196     char       response[kMaxDnsMessageSize];
197     ssize_t    readSize;
198     otError    error   = OT_ERROR_NONE;
199     otMessage *message = nullptr;
200 
201     VerifyOrExit((readSize = read(aTxn->mUdpFd, response, sizeof(response))) > 0);
202 
203     message = otUdpNewMessage(gInstance, nullptr);
204     VerifyOrExit(message != nullptr, error = OT_ERROR_NO_BUFS);
205     SuccessOrExit(error = otMessageAppend(message, response, readSize));
206 
207     otPlatDnsUpstreamQueryDone(gInstance, aTxn->mThreadTxn, message);
208     message = nullptr;
209 
210 exit:
211     if (readSize < 0)
212     {
213         LogInfo("Failed to read response from upstream resolver socket: %d", errno);
214     }
215     if (error != OT_ERROR_NONE)
216     {
217         LogInfo("Failed to forward upstream DNS response: %s", otThreadErrorToString(error));
218     }
219     if (message != nullptr)
220     {
221         otMessageFree(message);
222     }
223 }
224 
GetTransaction(int aFd)225 Resolver::Transaction *Resolver::GetTransaction(int aFd)
226 {
227     Transaction *ret = nullptr;
228 
229     for (Transaction &txn : mUpstreamTransaction)
230     {
231         if (txn.mThreadTxn != nullptr && txn.mUdpFd == aFd)
232         {
233             ret = &txn;
234             break;
235         }
236     }
237 
238     return ret;
239 }
240 
GetTransaction(otPlatDnsUpstreamQuery * aThreadTxn)241 Resolver::Transaction *Resolver::GetTransaction(otPlatDnsUpstreamQuery *aThreadTxn)
242 {
243     Transaction *ret = nullptr;
244 
245     for (Transaction &txn : mUpstreamTransaction)
246     {
247         if (txn.mThreadTxn == aThreadTxn)
248         {
249             ret = &txn;
250             break;
251         }
252     }
253 
254     return ret;
255 }
256 
CloseTransaction(Transaction * aTxn)257 void Resolver::CloseTransaction(Transaction *aTxn)
258 {
259     if (aTxn->mUdpFd >= 0)
260     {
261         close(aTxn->mUdpFd);
262         aTxn->mUdpFd = -1;
263     }
264     aTxn->mThreadTxn = nullptr;
265 }
266 
UpdateFdSet(otSysMainloopContext & aContext)267 void Resolver::UpdateFdSet(otSysMainloopContext &aContext)
268 {
269     for (Transaction &txn : mUpstreamTransaction)
270     {
271         if (txn.mThreadTxn != nullptr)
272         {
273             FD_SET(txn.mUdpFd, &aContext.mReadFdSet);
274             FD_SET(txn.mUdpFd, &aContext.mErrorFdSet);
275             if (txn.mUdpFd > aContext.mMaxFd)
276             {
277                 aContext.mMaxFd = txn.mUdpFd;
278             }
279         }
280     }
281 }
282 
Process(const otSysMainloopContext & aContext)283 void Resolver::Process(const otSysMainloopContext &aContext)
284 {
285     for (Transaction &txn : mUpstreamTransaction)
286     {
287         if (txn.mThreadTxn != nullptr)
288         {
289             // Note: On Linux, we can only get the error via read, so they should share the same logic.
290             if (FD_ISSET(txn.mUdpFd, &aContext.mErrorFdSet) || FD_ISSET(txn.mUdpFd, &aContext.mReadFdSet))
291             {
292                 ForwardResponse(&txn);
293                 CloseTransaction(&txn);
294             }
295         }
296     }
297 }
298 
SetUpstreamDnsServers(const otIp6Address * aUpstreamDnsServers,int aNumServers)299 void Resolver::SetUpstreamDnsServers(const otIp6Address *aUpstreamDnsServers, int aNumServers)
300 {
301     mUpstreamDnsServerCount = 0;
302 
303     for (int i = 0; i < aNumServers && i < kMaxUpstreamServerCount; ++i)
304     {
305         otIp4Address ip4Address;
306 
307         // TODO: support DNS servers with IPv6 addresses
308         if (otIp4FromIp4MappedIp6Address(&aUpstreamDnsServers[i], &ip4Address) == OT_ERROR_NONE)
309         {
310             mUpstreamDnsServerList[mUpstreamDnsServerCount] = ip4Address.mFields.m32;
311             mUpstreamDnsServerCount++;
312         }
313     }
314 }
315 
316 } // namespace Posix
317 } // namespace ot
318 
platformResolverProcess(const otSysMainloopContext * aContext)319 void platformResolverProcess(const otSysMainloopContext *aContext) { gResolver.Process(*aContext); }
320 
platformResolverUpdateFdSet(otSysMainloopContext * aContext)321 void platformResolverUpdateFdSet(otSysMainloopContext *aContext) { gResolver.UpdateFdSet(*aContext); }
322 
platformResolverInit(void)323 void platformResolverInit(void) { gResolver.Init(); }
324 
otPlatDnsStartUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn,const otMessage * aQuery)325 void otPlatDnsStartUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn, const otMessage *aQuery)
326 {
327     OT_UNUSED_VARIABLE(aInstance);
328 
329     gResolver.Query(aTxn, aQuery);
330 }
331 
otPlatDnsCancelUpstreamQuery(otInstance * aInstance,otPlatDnsUpstreamQuery * aTxn)332 void otPlatDnsCancelUpstreamQuery(otInstance *aInstance, otPlatDnsUpstreamQuery *aTxn)
333 {
334     OT_UNUSED_VARIABLE(aInstance);
335 
336     gResolver.Cancel(aTxn);
337 }
338 
otSysUpstreamDnsServerSetResolvConfEnabled(bool aEnabled)339 void otSysUpstreamDnsServerSetResolvConfEnabled(bool aEnabled) { gResolver.SetResolvConfEnabled(aEnabled); }
340 
otSysUpstreamDnsSetServerList(const otIp6Address * aUpstreamDnsServers,int aNumServers)341 void otSysUpstreamDnsSetServerList(const otIp6Address *aUpstreamDnsServers, int aNumServers)
342 {
343     gResolver.SetUpstreamDnsServers(aUpstreamDnsServers, aNumServers);
344 }
345 
346 #endif // OPENTHREAD_CONFIG_DNS_UPSTREAM_QUERY_ENABLE
347