1 /*
2  * Copyright (c) 2022 Meta
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <memory>
8 #include <unordered_map>
9 
10 #include <zephyr/sys/hash_map.h>
11 #include <zephyr/sys/hash_map_cxx.h>
12 
13 using cxx_map = std::unordered_map<uint64_t, uint64_t>;
14 
sys_hashmap_cxx_iter_next(struct sys_hashmap_iterator * it)15 static void sys_hashmap_cxx_iter_next(struct sys_hashmap_iterator *it)
16 {
17 	cxx_map *umap = static_cast<cxx_map *>(it->map->data->buckets);
18 
19 	__ASSERT_NO_MSG(umap != nullptr);
20 
21 	__ASSERT(it->size == it->map->data->size, "Concurrent modification!");
22 	__ASSERT(sys_hashmap_iterator_has_next(it), "Attempt to access beyond current bound!");
23 
24 	auto it2 = umap->begin();
25 	for (size_t i = 0; i < it->pos; ++i, it2++)
26 		;
27 
28 	it->key = it2->first;
29 	it->value = it2->second;
30 	++it->pos;
31 }
32 
33 /*
34  * C++ Wrapped Hashmap API
35  */
36 
sys_hashmap_cxx_iter(const struct sys_hashmap * map,struct sys_hashmap_iterator * it)37 static void sys_hashmap_cxx_iter(const struct sys_hashmap *map, struct sys_hashmap_iterator *it)
38 {
39 	it->map = map;
40 	it->next = sys_hashmap_cxx_iter_next;
41 	it->state = nullptr;
42 	it->key = 0;
43 	it->value = 0;
44 	it->pos = 0;
45 	*((size_t *)&it->size) = map->data->size;
46 }
47 
sys_hashmap_cxx_clear(struct sys_hashmap * map,sys_hashmap_callback_t cb,void * cookie)48 static void sys_hashmap_cxx_clear(struct sys_hashmap *map, sys_hashmap_callback_t cb, void *cookie)
49 {
50 	cxx_map *umap = static_cast<cxx_map *>(map->data->buckets);
51 
52 	if (umap == nullptr) {
53 		return;
54 	}
55 
56 	if (cb != nullptr) {
57 		for (auto &kv : *umap) {
58 			cb(kv.first, kv.second, cookie);
59 		}
60 	}
61 
62 	delete umap;
63 
64 	map->data->buckets = nullptr;
65 	map->data->n_buckets = 0;
66 	map->data->size = 0;
67 }
68 
sys_hashmap_cxx_insert(struct sys_hashmap * map,uint64_t key,uint64_t value,uint64_t * old_value)69 static int sys_hashmap_cxx_insert(struct sys_hashmap *map, uint64_t key, uint64_t value,
70 				  uint64_t *old_value)
71 {
72 	cxx_map *umap = static_cast<cxx_map *>(map->data->buckets);
73 
74 	if (umap == nullptr) {
75 		umap = new cxx_map;
76 		umap->max_load_factor(map->config->load_factor / 100.0f);
77 		map->data->buckets = umap;
78 	}
79 
80 	auto it = umap->find(key);
81 	if (it != umap->end() && old_value != nullptr) {
82 		*old_value = it->second;
83 		it->second = value;
84 		return 0;
85 	}
86 
87 	try {
88 		(*umap)[key] = value;
89 	} catch(...) {
90 		return -ENOMEM;
91 	}
92 
93 	++map->data->size;
94 	map->data->n_buckets = umap->bucket_count();
95 
96 	return 1;
97 }
98 
sys_hashmap_cxx_remove(struct sys_hashmap * map,uint64_t key,uint64_t * value)99 static bool sys_hashmap_cxx_remove(struct sys_hashmap *map, uint64_t key, uint64_t *value)
100 {
101 	cxx_map *umap = static_cast<cxx_map *>(map->data->buckets);
102 
103 	if (umap == nullptr) {
104 		return false;
105 	}
106 
107 	auto it = umap->find(key);
108 	if (it == umap->end()) {
109 		return false;
110 	}
111 
112 	if (value != nullptr) {
113 		*value = it->second;
114 	}
115 
116 	umap->erase(key);
117 	--map->data->size;
118 	map->data->n_buckets = umap->bucket_count();
119 
120 	if (map->data->size == 0) {
121 		delete umap;
122 		map->data->n_buckets = 0;
123 		map->data->buckets = nullptr;
124 	}
125 
126 	return true;
127 }
128 
sys_hashmap_cxx_get(const struct sys_hashmap * map,uint64_t key,uint64_t * value)129 static bool sys_hashmap_cxx_get(const struct sys_hashmap *map, uint64_t key, uint64_t *value)
130 {
131 	cxx_map *umap = static_cast<cxx_map *>(map->data->buckets);
132 
133 	if (umap == nullptr) {
134 		return false;
135 	}
136 
137 	auto it = umap->find(key);
138 	if (it == umap->end()) {
139 		return false;
140 	}
141 
142 	if (value != nullptr) {
143 		*value = it->second;
144 	}
145 
146 	return true;
147 }
148 
/*
 * Operations table for the std::unordered_map-backed hashmap: each member is
 * bound to the corresponding sys_hashmap_cxx_* function defined in this file.
 * The definition sits inside an extern "C" block, giving the symbol C
 * language linkage.
 */
extern "C" {
const struct sys_hashmap_api sys_hashmap_cxx_api = {
	.iter = sys_hashmap_cxx_iter,
	.clear = sys_hashmap_cxx_clear,
	.insert = sys_hashmap_cxx_insert,
	.remove = sys_hashmap_cxx_remove,
	.get = sys_hashmap_cxx_get,
};
}
158