1 /*
2 * Copyright (c) 2024 Nordic Semiconductor ASA
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_dns_dispatcher, CONFIG_DNS_SOCKET_DISPATCHER_LOG_LEVEL);
9
10 #include <zephyr/kernel.h>
11 #include <zephyr/sys/check.h>
12 #include <zephyr/sys/slist.h>
13 #include <zephyr/net_buf.h>
14 #include <zephyr/net/net_if.h>
15 #include <zephyr/net/dns_resolve.h>
16 #include <zephyr/net/socket_service.h>
17
18 #include "../../ip/net_stats.h"
19 #include "dns_pack.h"
20
/* Taken by dns_dispatcher_register()/dns_dispatcher_unregister() while they
 * update the `sockets` list and `dispatch_table`. recv_data() does not take
 * it; it only takes the per-dispatcher lock.
 */
static K_MUTEX_DEFINE(lock);

/* Dispatchers that bound their own socket and registered a socket service.
 * Contexts attached to another dispatcher as its `pair` are not added here.
 */
static sys_slist_t sockets;

/* Receive buffers: one mandatory buffer plus a configurable number of
 * extra ones.
 */
#define DNS_RESOLVER_MIN_BUF 1
#define DNS_RESOLVER_BUF_CTR (DNS_RESOLVER_MIN_BUF + \
			      CONFIG_DNS_RESOLVER_ADDITIONAL_BUF_CTR)

/* Pool from which recv_data() allocates one buffer per received datagram;
 * each buffer holds up to DNS_RESOLVER_MAX_BUF_SIZE bytes.
 */
NET_BUF_POOL_DEFINE(dns_msg_pool, DNS_RESOLVER_BUF_CTR,
		    DNS_RESOLVER_MAX_BUF_SIZE, 0, NULL);

/* Maps a socket descriptor number (used directly as the index) to the
 * dispatcher that handles traffic on it. Its size therefore bounds which
 * descriptors can be registered. The table is handed to the socket service
 * as user_data and read back in recv_data().
 */
static struct socket_dispatch_table {
	struct dns_socket_dispatcher *ctx;
} dispatch_table[CONFIG_ZVFS_OPEN_MAX];
35
dns_dispatch(struct dns_socket_dispatcher * dispatcher,int sock,struct sockaddr * addr,size_t addrlen,struct net_buf * dns_data,size_t buf_len)36 static int dns_dispatch(struct dns_socket_dispatcher *dispatcher,
37 int sock, struct sockaddr *addr, size_t addrlen,
38 struct net_buf *dns_data, size_t buf_len)
39 {
40 /* Helper struct to track the dns msg received from the server */
41 struct dns_msg_t dns_msg;
42 bool is_query;
43 int data_len;
44 int ret;
45
46 data_len = MIN(buf_len, DNS_RESOLVER_MAX_BUF_SIZE);
47
48 dns_msg.msg = dns_data->data;
49 dns_msg.msg_size = data_len;
50
51 /* Make sure that we can read DNS id, flags and rcode */
52 if (dns_msg.msg_size < (sizeof(uint16_t) + sizeof(uint16_t))) {
53 ret = -EINVAL;
54 goto done;
55 }
56
57 if (dns_header_rcode(dns_msg.msg) == DNS_HEADER_REFUSED) {
58 ret = -EINVAL;
59 goto done;
60 }
61
62 is_query = (dns_header_qr(dns_msg.msg) == DNS_QUERY);
63 if (is_query) {
64 if (dispatcher->type == DNS_SOCKET_RESPONDER) {
65 /* Call the responder callback */
66 ret = dispatcher->cb(dispatcher->ctx, sock,
67 addr, addrlen,
68 dns_data, data_len);
69 } else if (dispatcher->pair) {
70 ret = dispatcher->pair->cb(dispatcher->pair->ctx, sock,
71 addr, addrlen,
72 dns_data, data_len);
73 } else {
74 /* Discard the message as it was a query and there are none
75 * expecting a query.
76 */
77 ret = -ENOENT;
78 }
79 } else {
80 /* So this was an answer to a query that was made by resolver.
81 */
82 if (dispatcher->type == DNS_SOCKET_RESOLVER) {
83 /* Call the resolver callback */
84 ret = dispatcher->cb(dispatcher->ctx, sock,
85 addr, addrlen,
86 dns_data, data_len);
87 } else if (dispatcher->pair) {
88 ret = dispatcher->pair->cb(dispatcher->pair->ctx, sock,
89 addr, addrlen,
90 dns_data, data_len);
91 } else {
92 /* Discard the message as it was not a query reply and
93 * we were a reply.
94 */
95 ret = -ENOENT;
96 }
97 }
98
99 done:
100 if (IS_ENABLED(CONFIG_NET_STATISTICS_DNS)) {
101 struct net_if *iface = NULL;
102
103 if (IS_ENABLED(CONFIG_NET_IPV6) && addr->sa_family == AF_INET6) {
104 iface = net_if_ipv6_select_src_iface(&net_sin6(addr)->sin6_addr);
105 } else if (IS_ENABLED(CONFIG_NET_IPV4) && addr->sa_family == AF_INET) {
106 iface = net_if_ipv4_select_src_iface(&net_sin(addr)->sin_addr);
107 }
108
109 if (iface != NULL) {
110 if (ret < 0) {
111 net_stats_update_dns_drop(iface);
112 } else {
113 net_stats_update_dns_recv(iface);
114 }
115 }
116 }
117
118 return ret;
119 }
120
recv_data(struct net_socket_service_event * pev)121 static int recv_data(struct net_socket_service_event *pev)
122 {
123 struct socket_dispatch_table *table = pev->user_data;
124 struct dns_socket_dispatcher *dispatcher;
125 socklen_t optlen = sizeof(int);
126 struct net_buf *dns_data = NULL;
127 struct sockaddr addr;
128 size_t addrlen;
129 int family, sock_error;
130 int ret = 0, len;
131
132 dispatcher = table[pev->event.fd].ctx;
133
134 k_mutex_lock(&dispatcher->lock, K_FOREVER);
135
136 (void)zsock_getsockopt(pev->event.fd, SOL_SOCKET,
137 SO_DOMAIN, &family, &optlen);
138
139 if ((pev->event.revents & ZSOCK_POLLERR) ||
140 (pev->event.revents & ZSOCK_POLLNVAL)) {
141 (void)zsock_getsockopt(pev->event.fd, SOL_SOCKET,
142 SO_ERROR, &sock_error, &optlen);
143 if (sock_error > 0) {
144 NET_ERR("Receiver IPv%d socket error (%d)",
145 family == AF_INET ? 4 : 6, sock_error);
146 ret = DNS_EAI_SYSTEM;
147 }
148
149 goto unlock;
150 }
151
152 addrlen = (family == AF_INET) ? sizeof(struct sockaddr_in) :
153 sizeof(struct sockaddr_in6);
154
155 dns_data = net_buf_alloc(&dns_msg_pool, dispatcher->buf_timeout);
156 if (!dns_data) {
157 ret = DNS_EAI_MEMORY;
158 goto unlock;
159 }
160
161 ret = zsock_recvfrom(pev->event.fd, dns_data->data,
162 net_buf_max_len(dns_data), 0,
163 (struct sockaddr *)&addr, &addrlen);
164 if (ret < 0) {
165 ret = -errno;
166 NET_ERR("recv failed on IPv%d socket (%d)",
167 family == AF_INET ? 4 : 6, -ret);
168 goto free_buf;
169 }
170
171 len = ret;
172
173 ret = dns_dispatch(dispatcher, pev->event.fd,
174 (struct sockaddr *)&addr, addrlen,
175 dns_data, len);
176 free_buf:
177 if (dns_data) {
178 net_buf_unref(dns_data);
179 }
180
181 unlock:
182 k_mutex_unlock(&dispatcher->lock);
183
184 return ret;
185 }
186
dns_dispatcher_svc_handler(struct net_socket_service_event * pev)187 void dns_dispatcher_svc_handler(struct net_socket_service_event *pev)
188 {
189 int ret;
190
191 ret = recv_data(pev);
192 if (ret < 0 && ret != DNS_EAI_ALLDONE && ret != -ENOENT) {
193 NET_ERR("DNS recv error (%d)", ret);
194 }
195 }
196
dns_dispatcher_register(struct dns_socket_dispatcher * ctx)197 int dns_dispatcher_register(struct dns_socket_dispatcher *ctx)
198 {
199 struct dns_socket_dispatcher *entry, *next, *found = NULL;
200 sys_snode_t *prev_node = NULL;
201 bool dup = false;
202 size_t addrlen;
203 int ret = 0;
204
205 k_mutex_lock(&lock, K_FOREVER);
206
207 if (sys_slist_find(&sockets, &ctx->node, &prev_node)) {
208 ret = -EALREADY;
209 goto out;
210 }
211
212 SYS_SLIST_FOR_EACH_CONTAINER_SAFE(&sockets, entry, next, node) {
213 /* Refuse to register context if we have identical context
214 * already registered.
215 */
216 if (ctx->type == entry->type &&
217 ctx->local_addr.sa_family == entry->local_addr.sa_family &&
218 ctx->ifindex == entry->ifindex) {
219 if (net_sin(&entry->local_addr)->sin_port ==
220 net_sin(&ctx->local_addr)->sin_port) {
221 dup = true;
222 continue;
223 }
224 }
225
226 /* Then check if there is an entry with same family and
227 * port already in the list. If there is then we can act
228 * as a dispatcher for the given socket. Do not break
229 * from the loop even if we found an entry so that we
230 * can catch possible duplicates.
231 */
232 if (found == NULL && ctx->type != entry->type &&
233 ctx->local_addr.sa_family == entry->local_addr.sa_family) {
234 if (net_sin(&entry->local_addr)->sin_port ==
235 net_sin(&ctx->local_addr)->sin_port) {
236 found = entry;
237 continue;
238 }
239 }
240 }
241
242 if (dup) {
243 /* Found a duplicate */
244 ret = -EALREADY;
245 goto out;
246 }
247
248 if (found != NULL) {
249 entry = found;
250
251 if (entry->pair != NULL) {
252 NET_DBG("Already paired connection found.");
253 ret = -EALREADY;
254 goto out;
255 }
256
257 entry->pair = ctx;
258
259 for (int i = 0; i < ctx->fds_len; i++) {
260 CHECKIF((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
261 ret = -ERANGE;
262 goto out;
263 }
264
265 if (ctx->fds[i].fd < 0) {
266 continue;
267 }
268
269 if (dispatch_table[ctx->fds[i].fd].ctx == NULL) {
270 dispatch_table[ctx->fds[i].fd].ctx = ctx;
271 }
272 }
273
274 /* Basically we are now done. If there is incoming data to
275 * the socket, the dispatcher will then pass it to the correct
276 * recipient.
277 */
278 ret = 0;
279 goto out;
280 }
281
282 ctx->buf_timeout = DNS_BUF_TIMEOUT;
283
284 if (ctx->local_addr.sa_family == AF_INET) {
285 addrlen = sizeof(struct sockaddr_in);
286 } else {
287 addrlen = sizeof(struct sockaddr_in6);
288 }
289
290 /* Bind and then register a socket service with this combo */
291 ret = zsock_bind(ctx->sock, &ctx->local_addr, addrlen);
292 if (ret < 0) {
293 ret = -errno;
294 NET_DBG("Cannot bind DNS socket %d (%d)", ctx->sock, ret);
295 goto out;
296 }
297
298 ctx->pair = NULL;
299
300 for (int i = 0; i < ctx->fds_len; i++) {
301 if ((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
302 ret = -ERANGE;
303 goto out;
304 }
305
306 if (ctx->fds[i].fd < 0) {
307 continue;
308 }
309
310 if (dispatch_table[ctx->fds[i].fd].ctx == NULL) {
311 dispatch_table[ctx->fds[i].fd].ctx = ctx;
312 }
313 }
314
315 ret = net_socket_service_register(ctx->svc, ctx->fds, ctx->fds_len, &dispatch_table);
316 if (ret < 0) {
317 NET_DBG("Cannot register socket service (%d)", ret);
318 goto out;
319 }
320
321 sys_slist_prepend(&sockets, &ctx->node);
322
323 out:
324 k_mutex_unlock(&lock);
325
326 return ret;
327 }
328
dns_dispatcher_unregister(struct dns_socket_dispatcher * ctx)329 int dns_dispatcher_unregister(struct dns_socket_dispatcher *ctx)
330 {
331 int ret = 0;
332
333 k_mutex_lock(&lock, K_FOREVER);
334
335 (void)sys_slist_find_and_remove(&sockets, &ctx->node);
336
337 (void)net_socket_service_unregister(ctx->svc);
338
339 /* Mark the context as unregistered */
340 ctx->sock = -1;
341
342 for (int i = 0; i < ctx->fds_len; i++) {
343 CHECKIF((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
344 ret = -ERANGE;
345 goto out;
346 }
347
348 dispatch_table[ctx->fds[i].fd].ctx = NULL;
349 }
350
351 out:
352 k_mutex_unlock(&lock);
353
354 return ret;
355 }
356
dns_dispatcher_init(void)357 void dns_dispatcher_init(void)
358 {
359 sys_slist_init(&sockets);
360 }
361