1 /*
2  * Copyright (c) 2022 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 #include <zephyr/net/socket.h>
9 #include <zephyr/sys/iterable_sections.h>
10 
11 #include "sockets_internal.h"
12 
13 LOG_MODULE_REGISTER(net_sock_dispatcher, CONFIG_NET_SOCKETS_LOG_LEVEL);
14 
/*
 * Bookkeeping for one socket whose final implementation has not been chosen
 * yet. It records the fd handed out to the caller together with the socket()
 * arguments, so the real socket can be created later on.
 */
__net_socket struct dispatcher_context {
	/* File descriptor reserved for this entry by sock_dispatch_create(). */
	int fd;
	/* family / type / proto exactly as given to sock_dispatch_create(). */
	int family;
	int type;
	int proto;
	/* True while the entry is claimed; cleared when the entry is zeroed. */
	bool is_used;
};
22 
/* Pool of dispatcher entries; a slot is taken while its is_used flag is set. */
static struct dispatcher_context
	dispatcher_context[CONFIG_NET_SOCKETS_OFFLOAD_DISPATCHER_CONTEXT_MAX];

/* Held while an entry of the pool is being claimed or cleared. */
static K_MUTEX_DEFINE(dispatcher_lock);

/* Forward declaration: sock_dispatch_find() compares handlers against it. */
static int sock_dispatch_create(int family, int type, int proto);
29 
is_tls(int proto)30 static bool is_tls(int proto)
31 {
32 	if ((proto >= IPPROTO_TLS_1_0 && proto <= IPPROTO_TLS_1_2) ||
33 	    (proto >= IPPROTO_DTLS_1_0 && proto <= IPPROTO_DTLS_1_2)) {
34 		return true;
35 	}
36 
37 	return false;
38 }
39 
dispatcher_ctx_free(struct dispatcher_context * ctx)40 static void dispatcher_ctx_free(struct dispatcher_context *ctx)
41 {
42 	(void)k_mutex_lock(&dispatcher_lock, K_FOREVER);
43 
44 	/* Free the dispatcher entry. */
45 	memset(ctx, 0, sizeof(*ctx));
46 
47 	k_mutex_unlock(&dispatcher_lock);
48 }
49 
/*
 * Create the real socket for a dispatcher entry and move it onto the
 * dispatcher's file descriptor.
 *
 * ctx           Dispatcher entry holding the fd and the socket() arguments.
 * socket_create Factory used to create the real socket.
 *
 * Returns the (unchanged) dispatcher fd on success, -1 if the socket could
 * not be created or its object could not be looked up.
 */
static int sock_dispatch_socket(struct dispatcher_context *ctx,
				net_socket_create_t socket_create)
{
	const struct socket_op_vtable *real_vtable;
	void *real_obj;
	int dispatcher_fd;
	int real_fd = socket_create(ctx->family, ctx->type, ctx->proto);

	if (real_fd < 0) {
		LOG_INF("Failed to create socket to dispatch");
		return -1;
	}

	real_obj = zvfs_get_fd_obj_and_vtable(
		real_fd, (const struct fd_op_vtable **)&real_vtable, NULL);
	if (real_obj == NULL) {
		return -1;
	}

	/*
	 * Remember the dispatcher fd before the entry is zeroed below, then
	 * reassign that fd to the real socket's object and vtable.
	 */
	dispatcher_fd = ctx->fd;
	zvfs_finalize_typed_fd(dispatcher_fd, real_obj,
			       (const struct fd_op_vtable *)real_vtable,
			       ZVFS_MODE_IFSOCK);

	/* The fd the factory returned is no longer in use; release it. */
	zvfs_free_fd(real_fd);

	dispatcher_ctx_free(ctx);

	return dispatcher_fd;
}
81 
sock_dispatch_find(int family,int type,int proto,bool native_only)82 static struct net_socket_register *sock_dispatch_find(int family, int type,
83 						      int proto, bool native_only)
84 {
85 	STRUCT_SECTION_FOREACH(net_socket_register, sock_family) {
86 		/* Ignore dispatcher itself. */
87 		if (sock_family->handler == sock_dispatch_create) {
88 			continue;
89 		}
90 
91 		if (native_only && sock_family->is_offloaded) {
92 			continue;
93 		}
94 
95 		if (sock_family->family != family &&
96 		    sock_family->family != AF_UNSPEC) {
97 			continue;
98 		}
99 
100 		NET_ASSERT(sock_family->is_supported);
101 
102 		if (!sock_family->is_supported(family, type, proto)) {
103 			continue;
104 		}
105 
106 		return sock_family;
107 	}
108 
109 	return NULL;
110 }
111 
sock_dispatch_native(struct dispatcher_context * ctx)112 static int sock_dispatch_native(struct dispatcher_context *ctx)
113 {
114 	struct net_socket_register *sock_family;
115 
116 	sock_family = sock_dispatch_find(ctx->family, ctx->type,
117 					 ctx->proto, true);
118 	if (sock_family == NULL) {
119 		errno = ENOENT;
120 		return -1;
121 	}
122 
123 	return sock_dispatch_socket(ctx, sock_family->handler);
124 }
125 
sock_dispatch_default(struct dispatcher_context * ctx)126 static int sock_dispatch_default(struct dispatcher_context *ctx)
127 {
128 	struct net_socket_register *sock_family;
129 
130 	sock_family = sock_dispatch_find(ctx->family, ctx->type,
131 					 ctx->proto, false);
132 	if (sock_family == NULL) {
133 		errno = ENOENT;
134 		return -1;
135 	}
136 
137 	return sock_dispatch_socket(ctx, sock_family->handler);
138 }
139 
sock_dispatch_read_vmeth(void * obj,void * buffer,size_t count)140 static ssize_t sock_dispatch_read_vmeth(void *obj, void *buffer, size_t count)
141 {
142 	int fd;
143 	const struct fd_op_vtable *vtable;
144 	void *new_obj;
145 
146 	fd = sock_dispatch_default(obj);
147 	if (fd < 0) {
148 		return -1;
149 	}
150 
151 	new_obj = zvfs_get_fd_obj_and_vtable(fd, &vtable, NULL);
152 	if (new_obj == NULL) {
153 		return -1;
154 	}
155 
156 	return vtable->read(new_obj, buffer, count);
157 }
158 
sock_dispatch_write_vmeth(void * obj,const void * buffer,size_t count)159 static ssize_t sock_dispatch_write_vmeth(void *obj, const void *buffer,
160 					 size_t count)
161 {
162 	int fd;
163 	const struct fd_op_vtable *vtable;
164 	void *new_obj;
165 
166 	fd = sock_dispatch_default(obj);
167 	if (fd < 0) {
168 		return -1;
169 	}
170 
171 	new_obj = zvfs_get_fd_obj_and_vtable(fd, &vtable, NULL);
172 	if (new_obj == NULL) {
173 		return -1;
174 	}
175 
176 	return vtable->write(new_obj, buffer, count);
177 }
178 
sock_dispatch_ioctl_vmeth(void * obj,unsigned int request,va_list args)179 static int sock_dispatch_ioctl_vmeth(void *obj, unsigned int request,
180 				     va_list args)
181 {
182 	int fd;
183 	const struct fd_op_vtable *vtable;
184 	void *new_obj;
185 
186 	if (request == ZFD_IOCTL_SET_LOCK) {
187 		/* Ignore set lock, used by FD logic. */
188 		return 0;
189 	}
190 
191 	fd = sock_dispatch_default(obj);
192 	if (fd < 0) {
193 		return -1;
194 	}
195 
196 	new_obj = zvfs_get_fd_obj_and_vtable(fd, &vtable, NULL);
197 	if (new_obj == NULL) {
198 		return -1;
199 	}
200 
201 	return vtable->ioctl(new_obj, request, args);
202 }
203 
/*
 * shutdown() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_shutdown(). Returns -1 if dispatching fails.
 */
static int sock_dispatch_shutdown_vmeth(void *obj, int how)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_shutdown(sock, how);
}
214 
/*
 * bind() on a dispatcher socket: dispatch to the default implementation and
 * forward the call via zsock_bind(). Returns -1 if dispatching fails.
 */
static int sock_dispatch_bind_vmeth(void *obj, const struct sockaddr *addr,
				    socklen_t addrlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_bind(sock, addr, addrlen);
}
226 
/*
 * connect() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_connect(). Returns -1 if dispatching fails.
 */
static int sock_dispatch_connect_vmeth(void *obj, const struct sockaddr *addr,
				       socklen_t addrlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_connect(sock, addr, addrlen);
}
238 
/*
 * listen() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_listen(). Returns -1 if dispatching fails.
 */
static int sock_dispatch_listen_vmeth(void *obj, int backlog)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_listen(sock, backlog);
}
249 
/*
 * accept() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_accept(). Returns -1 if dispatching fails.
 */
static int sock_dispatch_accept_vmeth(void *obj, struct sockaddr *addr,
				      socklen_t *addrlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_accept(sock, addr, addrlen);
}
261 
/*
 * sendto() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_sendto(). Returns -1 if dispatching fails.
 */
static ssize_t sock_dispatch_sendto_vmeth(void *obj, const void *buf,
					  size_t len, int flags,
					  const struct sockaddr *addr,
					  socklen_t addrlen)
{
	const int sock = sock_dispatch_default(obj);

	if (sock >= 0) {
		return zsock_sendto(sock, buf, len, flags, addr, addrlen);
	}

	return -1;
}
275 
/*
 * sendmsg() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_sendmsg(). Returns -1 if dispatching fails.
 */
static ssize_t sock_dispatch_sendmsg_vmeth(void *obj, const struct msghdr *msg,
					   int flags)
{
	const int sock = sock_dispatch_default(obj);

	if (sock >= 0) {
		return zsock_sendmsg(sock, msg, flags);
	}

	return -1;
}
287 
/*
 * recvfrom() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_recvfrom(). Returns -1 if dispatching fails.
 */
static ssize_t sock_dispatch_recvfrom_vmeth(void *obj, void *buf,
					    size_t max_len, int flags,
					    struct sockaddr *addr,
					    socklen_t *addrlen)
{
	const int sock = sock_dispatch_default(obj);

	if (sock >= 0) {
		return zsock_recvfrom(sock, buf, max_len, flags, addr, addrlen);
	}

	return -1;
}
301 
/*
 * getsockopt() on a dispatcher socket: dispatch to the default implementation
 * and forward the call via zsock_getsockopt(). Returns -1 if dispatching
 * fails.
 */
static int sock_dispatch_getsockopt_vmeth(void *obj, int level, int optname,
					  void *optval, socklen_t *optlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1
			  : zsock_getsockopt(sock, level, optname, optval, optlen);
}
313 
/*
 * setsockopt() on a dispatcher socket. The option decides which
 * implementation the socket is dispatched to:
 *
 *  - SOL_SOCKET / SO_BINDTODEVICE: the interface named in the ifreq is
 *    resolved; if it has a socket offload factory that factory is used,
 *    otherwise the socket is dispatched natively.
 *  - SOL_TLS / TLS_NATIVE: only accepted for TLS/DTLS protocols. A non-zero
 *    value dispatches natively; zero is a no-op that returns 0.
 *  - anything else: dispatched to the default implementation.
 *
 * After a successful dispatch the option is forwarded via zsock_setsockopt().
 *
 * Returns the forwarded result, or -1 on failure. The argument checks set
 * errno to EINVAL (bad optval/optlen), ENODEV (interface or device not
 * found), ENOPROTOOPT (TLS_NATIVE on a non-TLS protocol) or the negated
 * net_if_get_by_name() error.
 */
static int sock_dispatch_setsockopt_vmeth(void *obj, int level, int optname,
					  const void *optval, socklen_t optlen)
{
	int fd;

	if ((level == SOL_SOCKET) && (optname == SO_BINDTODEVICE)) {
		const struct ifreq *req = optval;
		struct net_if *iface = NULL;
		net_socket_create_t offload_create;

		if ((req == NULL) || (optlen != sizeof(*req))) {
			errno = EINVAL;
			return -1;
		}

		if (IS_ENABLED(CONFIG_NET_INTERFACE_NAME)) {
			/* Interface names enabled: name -> index -> iface. */
			int if_index = net_if_get_by_name(req->ifr_name);

			if (if_index < 0) {
				errno = -if_index;
				return -1;
			}

			iface = net_if_get_by_index(if_index);
		} else {
			/* Otherwise the name is looked up as a device name. */
			const struct device *dev =
				device_get_binding(req->ifr_name);

			if (dev != NULL) {
				iface = net_if_lookup_by_dev(dev);
			}
		}

		/* Covers both a missing device and a missing interface. */
		if (iface == NULL) {
			errno = ENODEV;
			return -1;
		}

		offload_create = net_if_socket_offload(iface);
		if (offload_create == NULL) {
			/* Native interface - use native socket implementation. */
			fd = sock_dispatch_native(obj);
		} else {
			/* Offloaded interface - use its socket implementation. */
			fd = sock_dispatch_socket(obj, offload_create);
		}
	} else if ((level == SOL_TLS) && (optname == TLS_NATIVE)) {
		const struct dispatcher_context *ctx = obj;
		const int *want_native = optval;

		if ((want_native == NULL) || (optlen != sizeof(int))) {
			errno = EINVAL;
			return -1;
		}

		if (!is_tls(ctx->proto)) {
			errno = ENOPROTOOPT;
			return -1;
		}

		if (*want_native == 0) {
			/* Nothing to do; the socket stays undispatched. */
			return 0;
		}

		fd = sock_dispatch_native(obj);
	} else {
		fd = sock_dispatch_default(obj);
	}

	return (fd < 0) ? -1
			: zsock_setsockopt(fd, level, optname, optval, optlen);
}
395 
/*
 * close() on a socket that was never dispatched: only the dispatcher entry
 * has to be released. Always returns 0.
 */
static int sock_dispatch_close_vmeth(void *obj)
{
	struct dispatcher_context *entry = obj;

	dispatcher_ctx_free(entry);

	return 0;
}
402 
/*
 * getpeername() on a dispatcher socket: dispatch to the default
 * implementation and forward the call via zsock_getpeername(). Returns -1 if
 * dispatching fails.
 */
static int sock_dispatch_getpeername_vmeth(void *obj, struct sockaddr *addr,
					   socklen_t *addrlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_getpeername(sock, addr, addrlen);
}
414 
/*
 * getsockname() on a dispatcher socket: dispatch to the default
 * implementation and forward the call via zsock_getsockname(). Returns -1 if
 * dispatching fails.
 */
static int sock_dispatch_getsockname_vmeth(void *obj, struct sockaddr *addr,
					   socklen_t *addrlen)
{
	const int sock = sock_dispatch_default(obj);

	return (sock < 0) ? -1 : zsock_getsockname(sock, addr, addrlen);
}
426 
/*
 * Operations attached to a dispatcher fd by sock_dispatch_create(). Every
 * handler except close (and the ioctl set-lock request) first dispatches the
 * socket to a real implementation and then forwards the call.
 */
static const struct socket_op_vtable sock_dispatch_fd_op_vtable = {
	.fd_vtable = {
		.read = sock_dispatch_read_vmeth,
		.write = sock_dispatch_write_vmeth,
		.close = sock_dispatch_close_vmeth,
		.ioctl = sock_dispatch_ioctl_vmeth,
	},
	.shutdown = sock_dispatch_shutdown_vmeth,
	.bind = sock_dispatch_bind_vmeth,
	.connect = sock_dispatch_connect_vmeth,
	.listen = sock_dispatch_listen_vmeth,
	.accept = sock_dispatch_accept_vmeth,
	.sendto = sock_dispatch_sendto_vmeth,
	.sendmsg = sock_dispatch_sendmsg_vmeth,
	.recvfrom = sock_dispatch_recvfrom_vmeth,
	.getsockopt = sock_dispatch_getsockopt_vmeth,
	.setsockopt = sock_dispatch_setsockopt_vmeth,
	.getpeername = sock_dispatch_getpeername_vmeth,
	.getsockname = sock_dispatch_getsockname_vmeth,
};
447 
sock_dispatch_create(int family,int type,int proto)448 static int sock_dispatch_create(int family, int type, int proto)
449 {
450 	struct dispatcher_context *entry = NULL;
451 	int fd = -1;
452 
453 	(void)k_mutex_lock(&dispatcher_lock, K_FOREVER);
454 
455 	for (int i = 0; i < ARRAY_SIZE(dispatcher_context); i++) {
456 		if (dispatcher_context[i].is_used) {
457 			continue;
458 		}
459 
460 		entry = &dispatcher_context[i];
461 		break;
462 	}
463 
464 	if (entry == NULL) {
465 		errno = ENOMEM;
466 		goto out;
467 	}
468 
469 	if (sock_dispatch_find(family, type, proto, false) == NULL) {
470 		errno = EAFNOSUPPORT;
471 		goto out;
472 	}
473 
474 	fd = zvfs_reserve_fd();
475 	if (fd < 0) {
476 		goto out;
477 	}
478 
479 	entry->fd = fd;
480 	entry->family = family;
481 	entry->type = type;
482 	entry->proto = proto;
483 	entry->is_used = true;
484 
485 	zvfs_finalize_typed_fd(fd, entry, (const struct fd_op_vtable *)&sock_dispatch_fd_op_vtable,
486 			    ZVFS_MODE_IFSOCK);
487 
488 out:
489 	k_mutex_unlock(&dispatcher_lock);
490 	return fd;
491 }
492 
/*
 * Support filter used for the dispatcher's registration: accepts every
 * family/type/proto combination.
 */
static bool is_supported(int family, int type, int proto)
{
	(void)family;
	(void)type;
	(void)proto;

	return true;
}
497 
/*
 * Register the dispatcher as a socket implementation for AF_UNSPEC, with
 * is_supported() as its filter and sock_dispatch_create() as its factory.
 * NOTE(review): the second argument (0) is presumably the registration
 * priority - confirm against the NET_SOCKET_REGISTER definition.
 */
NET_SOCKET_REGISTER(sock_dispatch, 0, AF_UNSPEC, is_supported,
		    sock_dispatch_create);
500