1 /*
2  * Copyright (c) 2024 Endress+Hauser AG
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/net/dns_resolve.h>
8 #include "dns_cache.h"
9 
LOG_MODULE_REGISTER(net_dns_cache, CONFIG_DNS_RESOLVER_LOG_LEVEL);

/* Purges expired entries; defined at the end of this file.
 * Callers must already hold cache->lock.
 */
static void dns_cache_clean(const struct dns_cache *cache);
13 
dns_cache_flush(struct dns_cache * cache)14 int dns_cache_flush(struct dns_cache *cache)
15 {
16 	k_mutex_lock(cache->lock, K_FOREVER);
17 	for (size_t i = 0; i < cache->size; i++) {
18 		cache->entries[i].in_use = false;
19 	}
20 	k_mutex_unlock(cache->lock);
21 
22 	return 0;
23 }
24 
dns_cache_add(struct dns_cache * cache,char const * query,struct dns_addrinfo const * addrinfo,uint32_t ttl)25 int dns_cache_add(struct dns_cache *cache, char const *query, struct dns_addrinfo const *addrinfo,
26 		  uint32_t ttl)
27 {
28 	k_timepoint_t closest_to_expiry = sys_timepoint_calc(K_FOREVER);
29 	size_t index_to_replace = 0;
30 	bool found_empty = false;
31 
32 	if (cache == NULL || query == NULL || addrinfo == NULL || ttl == 0) {
33 		return -EINVAL;
34 	}
35 
36 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
37 		NET_WARN("Query string to big to be processed %u >= "
38 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
39 			 strlen(query));
40 		return -EINVAL;
41 	}
42 
43 	k_mutex_lock(cache->lock, K_FOREVER);
44 
45 	NET_DBG("Add \"%s\" with TTL %" PRIu32, query, ttl);
46 
47 	dns_cache_clean(cache);
48 
49 	for (size_t i = 0; i < cache->size; i++) {
50 		if (!cache->entries[i].in_use) {
51 			index_to_replace = i;
52 			found_empty = true;
53 			break;
54 		} else if (sys_timepoint_cmp(closest_to_expiry, cache->entries[i].expiry) > 0) {
55 			index_to_replace = i;
56 			closest_to_expiry = cache->entries[i].expiry;
57 		}
58 	}
59 
60 	if (!found_empty) {
61 		NET_DBG("Overwrite \"%s\"", cache->entries[index_to_replace].query);
62 	}
63 
64 	strncpy(cache->entries[index_to_replace].query, query,
65 		CONFIG_DNS_RESOLVER_MAX_QUERY_LEN - 1);
66 	cache->entries[index_to_replace].data = *addrinfo;
67 	cache->entries[index_to_replace].expiry = sys_timepoint_calc(K_SECONDS(ttl));
68 	cache->entries[index_to_replace].in_use = true;
69 
70 	k_mutex_unlock(cache->lock);
71 
72 	return 0;
73 }
74 
dns_cache_remove(struct dns_cache * cache,char const * query)75 int dns_cache_remove(struct dns_cache *cache, char const *query)
76 {
77 	NET_DBG("Remove all entries with query \"%s\"", query);
78 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
79 		NET_WARN("Query string to big to be processed %u >= "
80 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
81 			 strlen(query));
82 		return -EINVAL;
83 	}
84 
85 	k_mutex_lock(cache->lock, K_FOREVER);
86 
87 	dns_cache_clean(cache);
88 
89 	for (size_t i = 0; i < cache->size; i++) {
90 		if (cache->entries[i].in_use && strcmp(cache->entries[i].query, query) == 0) {
91 			cache->entries[i].in_use = false;
92 		}
93 	}
94 
95 	k_mutex_unlock(cache->lock);
96 
97 	return 0;
98 }
99 
dns_cache_find(struct dns_cache const * cache,const char * query,struct dns_addrinfo * addrinfo,size_t addrinfo_array_len)100 int dns_cache_find(struct dns_cache const *cache, const char *query, struct dns_addrinfo *addrinfo,
101 		   size_t addrinfo_array_len)
102 {
103 	size_t found = 0;
104 
105 	NET_DBG("Find \"%s\"", query);
106 	if (cache == NULL || query == NULL || addrinfo == NULL || addrinfo_array_len <= 0) {
107 		return -EINVAL;
108 	}
109 	if (strlen(query) >= CONFIG_DNS_RESOLVER_MAX_QUERY_LEN) {
110 		NET_WARN("Query string to big to be processed %u >= "
111 			 "CONFIG_DNS_RESOLVER_MAX_QUERY_LEN",
112 			 strlen(query));
113 		return -EINVAL;
114 	}
115 
116 	k_mutex_lock(cache->lock, K_FOREVER);
117 
118 	dns_cache_clean(cache);
119 
120 	for (size_t i = 0; i < cache->size; i++) {
121 		if (!cache->entries[i].in_use) {
122 			continue;
123 		}
124 		if (strcmp(cache->entries[i].query, query) != 0) {
125 			continue;
126 		}
127 		if (found >= addrinfo_array_len) {
128 			NET_WARN("Found \"%s\" but not enough space in provided buffer.", query);
129 			found++;
130 		} else {
131 			addrinfo[found] = cache->entries[i].data;
132 			found++;
133 			NET_DBG("Found \"%s\"", query);
134 		}
135 	}
136 
137 	k_mutex_unlock(cache->lock);
138 
139 	if (found > addrinfo_array_len) {
140 		return -ENOSR;
141 	}
142 
143 	if (found == 0) {
144 		NET_DBG("Could not find \"%s\"", query);
145 	}
146 	return found;
147 }
148 
149 /* Needs to be called when lock is already acquired */
dns_cache_clean(struct dns_cache const * cache)150 static void dns_cache_clean(struct dns_cache const *cache)
151 {
152 	for (size_t i = 0; i < cache->size; i++) {
153 		if (!cache->entries[i].in_use) {
154 			continue;
155 		}
156 
157 		if (sys_timepoint_expired(cache->entries[i].expiry)) {
158 			NET_DBG("Remove \"%s\"", cache->entries[i].query);
159 			cache->entries[i].in_use = false;
160 		}
161 	}
162 }
163