1 /*
2  * Copyright (c) 2024 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(net_dns_dispatcher, CONFIG_DNS_SOCKET_DISPATCHER_LOG_LEVEL);
9 
10 #include <zephyr/kernel.h>
11 #include <zephyr/sys/check.h>
12 #include <zephyr/sys/slist.h>
13 #include <zephyr/net_buf.h>
14 #include <zephyr/net/net_if.h>
15 #include <zephyr/net/dns_resolve.h>
16 #include <zephyr/net/socket_service.h>
17 
18 #include "../../ip/net_stats.h"
19 #include "dns_pack.h"
20 
/* Taken by dns_dispatcher_register() and dns_dispatcher_unregister() while
 * they modify the dispatcher list, pairings and the dispatch table below.
 */
static K_MUTEX_DEFINE(lock);

/* Dispatchers that bound their own socket and registered a socket service.
 * A context attached as the pair of an existing entry is not added here.
 */
static sys_slist_t sockets;

/* Receive buffer pool: one buffer minimum plus
 * CONFIG_DNS_RESOLVER_ADDITIONAL_BUF_CTR extra buffers, each
 * DNS_RESOLVER_MAX_BUF_SIZE bytes. recv_data() takes one buffer per received
 * datagram and releases it after dispatching.
 */
#define DNS_RESOLVER_MIN_BUF	1
#define DNS_RESOLVER_BUF_CTR	(DNS_RESOLVER_MIN_BUF + \
				 CONFIG_DNS_RESOLVER_ADDITIONAL_BUF_CTR)

NET_BUF_POOL_DEFINE(dns_msg_pool, DNS_RESOLVER_BUF_CTR,
		    DNS_RESOLVER_MAX_BUF_SIZE, 0, NULL);

/* Maps a socket descriptor (used directly as the array index) to the
 * dispatcher that handles events on it. The whole array is handed to the
 * socket service as user data. Because it is sized by CONFIG_ZVFS_OPEN_MAX,
 * descriptors at or above that value cannot be registered.
 */
static struct socket_dispatch_table {
	struct dns_socket_dispatcher *ctx;
} dispatch_table[CONFIG_ZVFS_OPEN_MAX];
35 
dns_dispatch(struct dns_socket_dispatcher * dispatcher,int sock,struct sockaddr * addr,size_t addrlen,struct net_buf * dns_data,size_t buf_len)36 static int dns_dispatch(struct dns_socket_dispatcher *dispatcher,
37 			int sock, struct sockaddr *addr, size_t addrlen,
38 			struct net_buf *dns_data, size_t buf_len)
39 {
40 	/* Helper struct to track the dns msg received from the server */
41 	struct dns_msg_t dns_msg;
42 	bool is_query;
43 	int data_len;
44 	int ret;
45 
46 	data_len = MIN(buf_len, DNS_RESOLVER_MAX_BUF_SIZE);
47 
48 	dns_msg.msg = dns_data->data;
49 	dns_msg.msg_size = data_len;
50 
51 	/* Make sure that we can read DNS id, flags and rcode */
52 	if (dns_msg.msg_size < (sizeof(uint16_t) + sizeof(uint16_t))) {
53 		ret = -EINVAL;
54 		goto done;
55 	}
56 
57 	if (dns_header_rcode(dns_msg.msg) == DNS_HEADER_REFUSED) {
58 		ret = -EINVAL;
59 		goto done;
60 	}
61 
62 	is_query = (dns_header_qr(dns_msg.msg) == DNS_QUERY);
63 	if (is_query) {
64 		if (dispatcher->type == DNS_SOCKET_RESPONDER) {
65 			/* Call the responder callback */
66 			ret = dispatcher->cb(dispatcher->ctx, sock,
67 					     addr, addrlen,
68 					     dns_data, data_len);
69 		} else if (dispatcher->pair) {
70 			ret = dispatcher->pair->cb(dispatcher->pair->ctx, sock,
71 						   addr, addrlen,
72 						   dns_data, data_len);
73 		} else {
74 			/* Discard the message as it was a query and there are none
75 			 * expecting a query.
76 			 */
77 			ret = -ENOENT;
78 		}
79 	} else {
80 		/* So this was an answer to a query that was made by resolver.
81 		 */
82 		if (dispatcher->type == DNS_SOCKET_RESOLVER) {
83 			/* Call the resolver callback */
84 			ret = dispatcher->cb(dispatcher->ctx, sock,
85 					     addr, addrlen,
86 					     dns_data, data_len);
87 		} else if (dispatcher->pair) {
88 			ret = dispatcher->pair->cb(dispatcher->pair->ctx, sock,
89 						   addr, addrlen,
90 						   dns_data, data_len);
91 		} else {
92 			/* Discard the message as it was not a query reply and
93 			 * we were a reply.
94 			 */
95 			ret = -ENOENT;
96 		}
97 	}
98 
99 done:
100 	if (IS_ENABLED(CONFIG_NET_STATISTICS_DNS)) {
101 		struct net_if *iface = NULL;
102 
103 		if (IS_ENABLED(CONFIG_NET_IPV6) && addr->sa_family == AF_INET6) {
104 			iface = net_if_ipv6_select_src_iface(&net_sin6(addr)->sin6_addr);
105 		} else if (IS_ENABLED(CONFIG_NET_IPV4) && addr->sa_family == AF_INET) {
106 			iface = net_if_ipv4_select_src_iface(&net_sin(addr)->sin_addr);
107 		}
108 
109 		if (iface != NULL) {
110 			if (ret < 0) {
111 				net_stats_update_dns_drop(iface);
112 			} else {
113 				net_stats_update_dns_recv(iface);
114 			}
115 		}
116 	}
117 
118 	return ret;
119 }
120 
recv_data(struct net_socket_service_event * pev)121 static int recv_data(struct net_socket_service_event *pev)
122 {
123 	struct socket_dispatch_table *table = pev->user_data;
124 	struct dns_socket_dispatcher *dispatcher;
125 	socklen_t optlen = sizeof(int);
126 	struct net_buf *dns_data = NULL;
127 	struct sockaddr addr;
128 	size_t addrlen;
129 	int family, sock_error;
130 	int ret = 0, len;
131 
132 	dispatcher = table[pev->event.fd].ctx;
133 
134 	k_mutex_lock(&dispatcher->lock, K_FOREVER);
135 
136 	(void)zsock_getsockopt(pev->event.fd, SOL_SOCKET,
137 			       SO_DOMAIN, &family, &optlen);
138 
139 	if ((pev->event.revents & ZSOCK_POLLERR) ||
140 	    (pev->event.revents & ZSOCK_POLLNVAL)) {
141 		(void)zsock_getsockopt(pev->event.fd, SOL_SOCKET,
142 				       SO_ERROR, &sock_error, &optlen);
143 		if (sock_error > 0) {
144 			NET_ERR("Receiver IPv%d socket error (%d)",
145 				family == AF_INET ? 4 : 6, sock_error);
146 			ret = DNS_EAI_SYSTEM;
147 		}
148 
149 		goto unlock;
150 	}
151 
152 	addrlen = (family == AF_INET) ? sizeof(struct sockaddr_in) :
153 		sizeof(struct sockaddr_in6);
154 
155 	dns_data = net_buf_alloc(&dns_msg_pool, dispatcher->buf_timeout);
156 	if (!dns_data) {
157 		ret = DNS_EAI_MEMORY;
158 		goto unlock;
159 	}
160 
161 	ret = zsock_recvfrom(pev->event.fd, dns_data->data,
162 			     net_buf_max_len(dns_data), 0,
163 			     (struct sockaddr *)&addr, &addrlen);
164 	if (ret < 0) {
165 		ret = -errno;
166 		NET_ERR("recv failed on IPv%d socket (%d)",
167 			family == AF_INET ? 4 : 6, -ret);
168 		goto free_buf;
169 	}
170 
171 	len = ret;
172 
173 	ret = dns_dispatch(dispatcher, pev->event.fd,
174 			   (struct sockaddr *)&addr, addrlen,
175 			   dns_data, len);
176 free_buf:
177 	if (dns_data) {
178 		net_buf_unref(dns_data);
179 	}
180 
181 unlock:
182 	k_mutex_unlock(&dispatcher->lock);
183 
184 	return ret;
185 }
186 
dns_dispatcher_svc_handler(struct net_socket_service_event * pev)187 void dns_dispatcher_svc_handler(struct net_socket_service_event *pev)
188 {
189 	int ret;
190 
191 	ret = recv_data(pev);
192 	if (ret < 0 && ret != DNS_EAI_ALLDONE && ret != -ENOENT) {
193 		NET_ERR("DNS recv error (%d)", ret);
194 	}
195 }
196 
dns_dispatcher_register(struct dns_socket_dispatcher * ctx)197 int dns_dispatcher_register(struct dns_socket_dispatcher *ctx)
198 {
199 	struct dns_socket_dispatcher *entry, *next, *found = NULL;
200 	sys_snode_t *prev_node = NULL;
201 	bool dup = false;
202 	size_t addrlen;
203 	int ret = 0;
204 
205 	k_mutex_lock(&lock, K_FOREVER);
206 
207 	if (sys_slist_find(&sockets, &ctx->node, &prev_node)) {
208 		ret = -EALREADY;
209 		goto out;
210 	}
211 
212 	SYS_SLIST_FOR_EACH_CONTAINER_SAFE(&sockets, entry, next, node) {
213 		/* Refuse to register context if we have identical context
214 		 * already registered.
215 		 */
216 		if (ctx->type == entry->type &&
217 		    ctx->local_addr.sa_family == entry->local_addr.sa_family &&
218 		    ctx->ifindex == entry->ifindex) {
219 			if (net_sin(&entry->local_addr)->sin_port ==
220 			    net_sin(&ctx->local_addr)->sin_port) {
221 				dup = true;
222 				continue;
223 			}
224 		}
225 
226 		/* Then check if there is an entry with same family and
227 		 * port already in the list. If there is then we can act
228 		 * as a dispatcher for the given socket. Do not break
229 		 * from the loop even if we found an entry so that we
230 		 * can catch possible duplicates.
231 		 */
232 		if (found == NULL && ctx->type != entry->type &&
233 		    ctx->local_addr.sa_family == entry->local_addr.sa_family) {
234 			if (net_sin(&entry->local_addr)->sin_port ==
235 			    net_sin(&ctx->local_addr)->sin_port) {
236 				found = entry;
237 				continue;
238 			}
239 		}
240 	}
241 
242 	if (dup) {
243 		/* Found a duplicate */
244 		ret = -EALREADY;
245 		goto out;
246 	}
247 
248 	if (found != NULL) {
249 		entry = found;
250 
251 		if (entry->pair != NULL) {
252 			NET_DBG("Already paired connection found.");
253 			ret = -EALREADY;
254 			goto out;
255 		}
256 
257 		entry->pair = ctx;
258 
259 		for (int i = 0; i < ctx->fds_len; i++) {
260 			CHECKIF((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
261 				ret = -ERANGE;
262 				goto out;
263 			}
264 
265 			if (ctx->fds[i].fd < 0) {
266 				continue;
267 			}
268 
269 			if (dispatch_table[ctx->fds[i].fd].ctx == NULL) {
270 				dispatch_table[ctx->fds[i].fd].ctx = ctx;
271 			}
272 		}
273 
274 		/* Basically we are now done. If there is incoming data to
275 		 * the socket, the dispatcher will then pass it to the correct
276 		 * recipient.
277 		 */
278 		ret = 0;
279 		goto out;
280 	}
281 
282 	ctx->buf_timeout = DNS_BUF_TIMEOUT;
283 
284 	if (ctx->local_addr.sa_family == AF_INET) {
285 		addrlen = sizeof(struct sockaddr_in);
286 	} else {
287 		addrlen = sizeof(struct sockaddr_in6);
288 	}
289 
290 	/* Bind and then register a socket service with this combo */
291 	ret = zsock_bind(ctx->sock, &ctx->local_addr, addrlen);
292 	if (ret < 0) {
293 		ret = -errno;
294 		NET_DBG("Cannot bind DNS socket %d (%d)", ctx->sock, ret);
295 		goto out;
296 	}
297 
298 	ctx->pair = NULL;
299 
300 	for (int i = 0; i < ctx->fds_len; i++) {
301 		if ((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
302 			ret = -ERANGE;
303 			goto out;
304 		}
305 
306 		if (ctx->fds[i].fd < 0) {
307 			continue;
308 		}
309 
310 		if (dispatch_table[ctx->fds[i].fd].ctx == NULL) {
311 			dispatch_table[ctx->fds[i].fd].ctx = ctx;
312 		}
313 	}
314 
315 	ret = net_socket_service_register(ctx->svc, ctx->fds, ctx->fds_len, &dispatch_table);
316 	if (ret < 0) {
317 		NET_DBG("Cannot register socket service (%d)", ret);
318 		goto out;
319 	}
320 
321 	sys_slist_prepend(&sockets, &ctx->node);
322 
323 out:
324 	k_mutex_unlock(&lock);
325 
326 	return ret;
327 }
328 
dns_dispatcher_unregister(struct dns_socket_dispatcher * ctx)329 int dns_dispatcher_unregister(struct dns_socket_dispatcher *ctx)
330 {
331 	int ret = 0;
332 
333 	k_mutex_lock(&lock, K_FOREVER);
334 
335 	(void)sys_slist_find_and_remove(&sockets, &ctx->node);
336 
337 	(void)net_socket_service_unregister(ctx->svc);
338 
339 	/* Mark the context as unregistered */
340 	ctx->sock = -1;
341 
342 	for (int i = 0; i < ctx->fds_len; i++) {
343 		CHECKIF((int)ctx->fds[i].fd >= (int)ARRAY_SIZE(dispatch_table)) {
344 			ret = -ERANGE;
345 			goto out;
346 		}
347 
348 		dispatch_table[ctx->fds[i].fd].ctx = NULL;
349 	}
350 
351 out:
352 	k_mutex_unlock(&lock);
353 
354 	return ret;
355 }
356 
/*
 * Reset the list of registered dispatchers to empty.
 *
 * NOTE(review): dispatch_table is not touched here. It relies on static
 * zero-initialization, so calling this again after registrations would
 * leave existing slots in place. Confirm this is only called once.
 */
void dns_dispatcher_init(void)
{
	sys_slist_init(&sockets);
}
361