1 /*
2  * Copyright (c) 2024 Endress+Hauser AG
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include <string.h>

#include <zephyr/net/dns_resolve.h>
#include <zephyr/net/net_ip.h>
#include "dns_cache.h"
10 
11 LOG_MODULE_REGISTER(net_dns_cache, CONFIG_DNS_RESOLVER_LOG_LEVEL);
12 
13 static void dns_cache_clean(struct dns_cache const *cache);
14 
dns_cache_flush(struct dns_cache * cache)15 int dns_cache_flush(struct dns_cache *cache)
16 {
17 	k_mutex_lock(cache->lock, K_FOREVER);
18 	for (size_t i = 0; i < cache->size; i++) {
19 		cache->entries[i].in_use = false;
20 	}
21 	k_mutex_unlock(cache->lock);
22 
23 	return 0;
24 }
25 
dns_cache_add(struct dns_cache * cache,char const * query,struct dns_addrinfo const * addrinfo,uint32_t ttl)26 int dns_cache_add(struct dns_cache *cache, char const *query, struct dns_addrinfo const *addrinfo,
27 		  uint32_t ttl)
28 {
29 	k_timepoint_t closest_to_expiry = sys_timepoint_calc(K_FOREVER);
30 	size_t index_to_replace = 0;
31 	bool found_empty = false;
32 
33 	if (cache == NULL || query == NULL || addrinfo == NULL || ttl == 0) {
34 		return -EINVAL;
35 	}
36 
37 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
38 		NET_WARN("Query string to big to be processed %u >= "
39 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
40 			 strlen(query));
41 		return -EINVAL;
42 	}
43 
44 	k_mutex_lock(cache->lock, K_FOREVER);
45 
46 	NET_DBG("Add \"%s\" with TTL %" PRIu32, query, ttl);
47 
48 	dns_cache_clean(cache);
49 
50 	for (size_t i = 0; i < cache->size; i++) {
51 		if (!cache->entries[i].in_use) {
52 			index_to_replace = i;
53 			found_empty = true;
54 			break;
55 		} else if (sys_timepoint_cmp(closest_to_expiry, cache->entries[i].expiry) > 0) {
56 			index_to_replace = i;
57 			closest_to_expiry = cache->entries[i].expiry;
58 		}
59 	}
60 
61 	if (!found_empty) {
62 		NET_DBG("Overwrite \"%s\"", cache->entries[index_to_replace].query);
63 	}
64 
65 	strncpy(cache->entries[index_to_replace].query, query,
66 		CONFIG_DNS_RESOLVER_MAX_QUERY_LEN - 1);
67 	cache->entries[index_to_replace].data = *addrinfo;
68 	cache->entries[index_to_replace].expiry = sys_timepoint_calc(K_SECONDS(ttl));
69 	cache->entries[index_to_replace].in_use = true;
70 
71 	k_mutex_unlock(cache->lock);
72 
73 	return 0;
74 }
75 
dns_cache_remove(struct dns_cache * cache,char const * query)76 int dns_cache_remove(struct dns_cache *cache, char const *query)
77 {
78 	NET_DBG("Remove all entries with query \"%s\"", query);
79 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
80 		NET_WARN("Query string to big to be processed %u >= "
81 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
82 			 strlen(query));
83 		return -EINVAL;
84 	}
85 
86 	k_mutex_lock(cache->lock, K_FOREVER);
87 
88 	dns_cache_clean(cache);
89 
90 	for (size_t i = 0; i < cache->size; i++) {
91 		if (cache->entries[i].in_use && strcmp(cache->entries[i].query, query) == 0) {
92 			cache->entries[i].in_use = false;
93 		}
94 	}
95 
96 	k_mutex_unlock(cache->lock);
97 
98 	return 0;
99 }
100 
dns_cache_find(struct dns_cache const * cache,const char * query,enum dns_query_type type,struct dns_addrinfo * addrinfo,size_t addrinfo_array_len)101 int dns_cache_find(struct dns_cache const *cache, const char *query, enum dns_query_type type,
102 		   struct dns_addrinfo *addrinfo, size_t addrinfo_array_len)
103 {
104 	size_t found = 0;
105 	sa_family_t family;
106 
107 	NET_DBG("Find \"%s\"", query);
108 	if (cache == NULL || query == NULL || addrinfo == NULL || addrinfo_array_len <= 0) {
109 		return -EINVAL;
110 	}
111 	if (type == DNS_QUERY_TYPE_A) {
112 		family = AF_INET;
113 	} else if (type == DNS_QUERY_TYPE_AAAA) {
114 		family = AF_INET6;
115 	} else {
116 		return -EINVAL;
117 	}
118 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
119 		NET_WARN("Query string to big to be processed %u >= "
120 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
121 			 strlen(query));
122 		return -EINVAL;
123 	}
124 
125 	k_mutex_lock(cache->lock, K_FOREVER);
126 
127 	dns_cache_clean(cache);
128 
129 	for (size_t i = 0; i < cache->size; i++) {
130 		if (!cache->entries[i].in_use) {
131 			continue;
132 		}
133 		if (strcmp(cache->entries[i].query, query) != 0) {
134 			continue;
135 		}
136 		if (cache->entries[i].data.ai_family != family) {
137 			continue;
138 		}
139 		if (found >= addrinfo_array_len) {
140 			NET_WARN("Found \"%s\" but not enough space in provided buffer.", query);
141 			found++;
142 		} else {
143 			addrinfo[found] = cache->entries[i].data;
144 			found++;
145 			NET_DBG("Found \"%s\"", query);
146 		}
147 	}
148 
149 	k_mutex_unlock(cache->lock);
150 
151 	if (found > addrinfo_array_len) {
152 		return -ENOSR;
153 	}
154 
155 	if (found == 0) {
156 		NET_DBG("Could not find \"%s\"", query);
157 	}
158 	return found;
159 }
160 
161 /* Needs to be called when lock is already acquired */
dns_cache_clean(struct dns_cache const * cache)162 static void dns_cache_clean(struct dns_cache const *cache)
163 {
164 	for (size_t i = 0; i < cache->size; i++) {
165 		if (!cache->entries[i].in_use) {
166 			continue;
167 		}
168 
169 		if (sys_timepoint_expired(cache->entries[i].expiry)) {
170 			NET_DBG("Remove \"%s\"", cache->entries[i].query);
171 			cache->entries[i].in_use = false;
172 		}
173 	}
174 }
175