1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #ifndef ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
9 #define ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
10 
11 #include <zephyr/ztest.h>
12 #include <zephyr/kernel.h>
13 #include <stdlib.h>
14 #include <arm_math.h>
15 #ifdef CONFIG_CMSIS_DSP_FLOAT16
16 #include <arm_math_f16.h>
17 #endif
18 
19 #include "math_helper.h"
20 
/*
 * Common assertion failure messages shared by the CMSIS-DSP test suites that
 * include this header (presumably passed as the message argument of the
 * zassert_*() checks — confirm at call sites).
 */
#define ASSERT_MSG_BUFFER_ALLOC_FAILED		"buffer allocation failed"
#define ASSERT_MSG_SNR_LIMIT_EXCEED		"signal-to-noise ratio " \
						"error limit exceeded"
#define ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED	"absolute error limit exceeded"
#define ASSERT_MSG_REL_ERROR_LIMIT_EXCEED	"relative error limit exceeded"
#define ASSERT_MSG_ERROR_LIMIT_EXCEED		"error limit exceeded"
#define ASSERT_MSG_INCORRECT_COMP_RESULT	"incorrect computation result"
28 
/*
 * DEFINE_TEST_VARIANTn(suite, name, variant, a1, ..., an)
 *
 * Each macro defines a ztest case named test_<name>_<variant> in @p suite.
 * The case body only forwards its n arguments to the parameterised helper
 * test_<name>(a1, ..., an), which must be declared before the macro is
 * expanded. This lets one helper implementation be registered as several
 * independently named test cases, each with its own argument set.
 *
 * The numeric suffix n (1..7) is the number of arguments forwarded.
 */
#define DEFINE_TEST_VARIANT1(suite, name, variant, a1)                                             \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1);                                                                   \
	}

#define DEFINE_TEST_VARIANT2(suite, name, variant, a1, a2)                                         \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2);                                                               \
	}

#define DEFINE_TEST_VARIANT3(suite, name, variant, a1, a2, a3)                                     \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3);                                                           \
	}

#define DEFINE_TEST_VARIANT4(suite, name, variant, a1, a2, a3, a4)                                 \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4);                                                       \
	}

#define DEFINE_TEST_VARIANT5(suite, name, variant, a1, a2, a3, a4, a5)                             \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5);                                                   \
	}

#define DEFINE_TEST_VARIANT6(suite, name, variant, a1, a2, a3, a4, a5, a6)                         \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5, a6);                                               \
	}

#define DEFINE_TEST_VARIANT7(suite, name, variant, a1, a2, a3, a4, a5, a6, a7)                     \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5, a6, a7);                                           \
	}
70 
71 #pragma GCC diagnostic push
72 #pragma GCC diagnostic ignored "-Wunused-function"
73 
/**
 * @brief Check two float64 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i]; false at the first mismatch. A NaN
 *         element never compares equal, so it always produces false.
 *         A zero @p length returns true.
 */
static inline bool test_equal_f64(
	size_t length, const float64_t *a, const float64_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
87 
/**
 * @brief Check two float32 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i]; false at the first mismatch. A NaN
 *         element never compares equal, so it always produces false.
 *         A zero @p length returns true.
 */
static inline bool test_equal_f32(
	size_t length, const float32_t *a, const float32_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
101 
102 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check two float16 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i]; false at the first mismatch. A NaN
 *         element never compares equal, so it always produces false.
 *         A zero @p length returns true.
 */
static inline bool test_equal_f16(
	size_t length, const float16_t *a, const float16_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
116 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
117 
/**
 * @brief Check two q63 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i] (or @p length is zero), false otherwise.
 */
static inline bool test_equal_q63(
	size_t length, const q63_t *a, const q63_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
131 
/**
 * @brief Check two q31 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i] (or @p length is zero), false otherwise.
 */
static inline bool test_equal_q31(
	size_t length, const q31_t *a, const q31_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
145 
/**
 * @brief Check two q15 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i] (or @p length is zero), false otherwise.
 */
static inline bool test_equal_q15(
	size_t length, const q15_t *a, const q15_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
159 
/**
 * @brief Check two q7 vectors for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First vector (at least @p length elements).
 * @param b      Second vector (at least @p length elements).
 *
 * @return true if every a[i] == b[i] (or @p length is zero), false otherwise.
 */
static inline bool test_equal_q7(
	size_t length, const q7_t *a, const q7_t *b)
{
	size_t pos = 0;

	/* Walk the longest matching prefix; stopping early means a mismatch. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	return pos == length;
}
173 
/**
 * @brief Check that two float64 vectors agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise.
 *
 * NOTE(review): because the rejection test is "diff > threshold", a NaN
 * difference (e.g. a NaN on either side) is not rejected. Confirm that this
 * is intended for these tests.
 */
static inline bool test_near_equal_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	size_t i;

	for (i = 0; i < length; i++) {
		if (fabs(a[i] - b[i]) > threshold) {
			break;
		}
	}

	/* Only a full pass without an early break means every pair was close. */
	return i == length;
}
188 
/**
 * @brief Check that two float32 vectors agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise.
 *
 * NOTE(review): because the rejection test is "diff > threshold", a NaN
 * difference (e.g. a NaN on either side) is not rejected. Confirm that this
 * is intended for these tests.
 */
static inline bool test_near_equal_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	size_t i;

	for (i = 0; i < length; i++) {
		if (fabsf(a[i] - b[i]) > threshold) {
			break;
		}
	}

	/* Only a full pass without an early break means every pair was close. */
	return i == length;
}
203 
204 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that two float16 vectors agree within an absolute tolerance.
 *
 * Elements and the threshold are converted to float before comparing.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise.
 *
 * NOTE(review): because the rejection test is "diff > threshold", a NaN
 * difference (e.g. a NaN on either side) is not rejected. Confirm that this
 * is intended for these tests.
 */
static inline bool test_near_equal_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	size_t i;

	for (i = 0; i < length; i++) {
		if (fabsf((float)a[i] - (float)b[i]) > (float)threshold) {
			break;
		}
	}

	/* Only a full pass without an early break means every pair was close. */
	return i == length;
}
219 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
220 
/**
 * @brief Check that two q63 vectors agree within an absolute tolerance.
 *
 * The magnitude of each difference is computed in unsigned 64-bit
 * arithmetic. Computing it as llabs(a[i] - b[i]) is undefined behaviour
 * in two cases:
 * - the signed subtraction overflows when the operands have opposite signs
 *   and large magnitudes (e.g. INT64_MAX and -1);
 * - llabs(INT64_MIN) is not representable.
 *
 * The unsigned form gives the exact |a[i] - b[i]| for every input pair.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted |a[i] - b[i]|. A negative threshold can
 *                  never be met, so any non-empty input then fails.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise
 *         (including when @p length is zero).
 */
static inline bool test_near_equal_q63(
	size_t length, const q63_t *a, const q63_t *b, q63_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		uint64_t delta;

		/* Subtract the smaller value from the larger one modulo 2^64;
		 * the result is the exact non-negative difference.
		 */
		if (a[index] >= b[index]) {
			delta = (uint64_t)a[index] - (uint64_t)b[index];
		} else {
			delta = (uint64_t)b[index] - (uint64_t)a[index];
		}

		if (threshold < 0 || delta > (uint64_t)threshold) {
			return false;
		}
	}

	return true;
}
234 
/**
 * @brief Check that two q31 vectors agree within an absolute tolerance.
 *
 * The difference is computed in 64-bit arithmetic. With a 32-bit int,
 * abs(a[i] - b[i]) is undefined behaviour in two cases:
 * - the subtraction overflows when the operands have opposite signs and
 *   large magnitudes (e.g. INT32_MAX and -1);
 * - abs(INT32_MIN) is not representable.
 *
 * A 64-bit difference of two 32-bit values can represent both cases exactly.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted |a[i] - b[i]|. A negative threshold can
 *                  never be met, so any non-empty input then fails.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise
 *         (including when @p length is zero).
 */
static inline bool test_near_equal_q31(
	size_t length, const q31_t *a, const q31_t *b, q31_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/* Widen before subtracting so the difference cannot overflow. */
		int64_t delta = (int64_t)a[index] - (int64_t)b[index];

		if (delta < 0) {
			delta = -delta;
		}

		if (delta > threshold) {
			return false;
		}
	}

	return true;
}
248 
/**
 * @brief Check that two q15 vectors agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted |a[i] - b[i]|.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise.
 */
static inline bool test_near_equal_q15(
	size_t length, const q15_t *a, const q15_t *b, q15_t threshold)
{
	for (size_t i = 0; i < length; i++) {
		/* q15 operands promote to int; with a 32-bit int the difference
		 * (at most 65535 in magnitude) and its negation cannot overflow.
		 */
		int delta = a[i] - b[i];

		if (delta < 0) {
			delta = -delta;
		}

		if (delta > threshold) {
			return false;
		}
	}

	return true;
}
262 
/**
 * @brief Check that two q7 vectors agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted |a[i] - b[i]|.
 *
 * @return false if some |a[i] - b[i]| > @p threshold, true otherwise.
 */
static inline bool test_near_equal_q7(
	size_t length, const q7_t *a, const q7_t *b, q7_t threshold)
{
	for (size_t i = 0; i < length; i++) {
		/* q7 operands promote to int, so the difference (at most 255 in
		 * magnitude) and its negation cannot overflow.
		 */
		int delta = a[i] - b[i];

		if (delta < 0) {
			delta = -delta;
		}

		if (delta > threshold) {
			return false;
		}
	}

	return true;
}
276 
/**
 * @brief Check the symmetric relative error of two float64 vectors.
 *
 * For each element the relative error is
 *   |a[i] - b[i]| / ((|a[i]| + |b[i]|) / 2).
 * Elements whose average magnitude is exactly zero (e.g. both zero) are
 * skipped, since the ratio is undefined there.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted relative error.
 *
 * @return false if some element's relative error > @p threshold, true
 *         otherwise. A NaN ratio does not compare greater, so it passes.
 */
static inline bool test_rel_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float64_t delta = fabs(a[i] - b[i]);
		const float64_t average = (fabs(a[i]) + fabs(b[i])) / 2.0;
		float64_t rel;

		/* Relative error is undefined for a zero reference magnitude. */
		if (average == 0) {
			continue;
		}

		rel = delta / average;
		if (rel > threshold) {
			return false;
		}
	}

	return true;
}
299 
/**
 * @brief Check the symmetric relative error of two float32 vectors.
 *
 * For each element the relative error is
 *   |a[i] - b[i]| / ((|a[i]| + |b[i]|) / 2).
 * Elements whose average magnitude is exactly zero (e.g. both zero) are
 * skipped, since the ratio is undefined there.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted relative error.
 *
 * @return false if some element's relative error > @p threshold, true
 *         otherwise. A NaN ratio does not compare greater, so it passes.
 */
static inline bool test_rel_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float32_t delta = fabsf(a[i] - b[i]);
		const float32_t average = (fabsf(a[i]) + fabsf(b[i])) / 2.0f;
		float32_t rel;

		/* Relative error is undefined for a zero reference magnitude. */
		if (average == 0) {
			continue;
		}

		rel = delta / average;
		if (rel > threshold) {
			return false;
		}
	}

	return true;
}
322 
323 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check the symmetric relative error of two float16 vectors.
 *
 * The computation is done in float32. For each element the relative error is
 *   |a[i] - b[i]| / ((|a[i]| + |b[i]|) / 2).
 * Elements whose average magnitude is exactly zero (e.g. both zero) are
 * skipped, since the ratio is undefined there.
 *
 * @param length    Number of elements to compare.
 * @param a         First vector (at least @p length elements).
 * @param b         Second vector (at least @p length elements).
 * @param threshold Largest accepted relative error.
 *
 * @return false if some element's relative error > @p threshold, true
 *         otherwise. A NaN ratio does not compare greater, so it passes.
 */
static inline bool test_rel_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float32_t delta = fabsf((float)a[i] - (float)b[i]);
		const float32_t average =
			(fabsf((float)a[i]) + fabsf((float)b[i])) / 2.0f;
		float32_t rel;

		/* Relative error is undefined for a zero reference magnitude. */
		if (average == 0) {
			continue;
		}

		rel = delta / average;
		if (rel > threshold) {
			return false;
		}
	}

	return true;
}
346 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
347 
/**
 * @brief Check float64 results against a reference with mixed tolerance.
 *
 * Each element must satisfy
 *   |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|
 * (the same form as numpy.isclose, with @p ref as the reference).
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference vector (at least @p length elements).
 * @param val           Vector under test (at least @p length elements).
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part, scaled by |ref[i]|.
 *
 * @return false if some element's error exceeds its tolerance, true
 *         otherwise. A NaN error does not compare greater, so it passes.
 */
static inline bool test_close_error_f64(
	size_t length, const float64_t *ref, const float64_t *val,
	float64_t abs_threshold, float64_t rel_threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float64_t error = fabs(val[i] - ref[i]);
		const float64_t limit = abs_threshold + rel_threshold * fabs(ref[i]);

		if (error > limit) {
			return false;
		}
	}

	return true;
}
363 
/**
 * @brief Check float32 results against a reference with mixed tolerance.
 *
 * Each element must satisfy
 *   |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|
 * (the same form as numpy.isclose, with @p ref as the reference).
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference vector (at least @p length elements).
 * @param val           Vector under test (at least @p length elements).
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part, scaled by |ref[i]|.
 *
 * @return false if some element's error exceeds its tolerance, true
 *         otherwise. A NaN error does not compare greater, so it passes.
 */
static inline bool test_close_error_f32(
	size_t length, const float32_t *ref, const float32_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float32_t error = fabsf(val[i] - ref[i]);
		const float32_t limit = abs_threshold + rel_threshold * fabsf(ref[i]);

		if (error > limit) {
			return false;
		}
	}

	return true;
}
379 
380 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check float16 results against a reference with mixed tolerance.
 *
 * Elements are converted to float and each must satisfy
 *   |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|
 * (the same form as numpy.isclose, with @p ref as the reference).
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference vector (at least @p length elements).
 * @param val           Vector under test (at least @p length elements).
 * @param abs_threshold Absolute part of the tolerance (float32).
 * @param rel_threshold Relative part, scaled by |ref[i]| (float32).
 *
 * @return false if some element's error exceeds its tolerance, true
 *         otherwise. A NaN error does not compare greater, so it passes.
 */
static inline bool test_close_error_f16(
	size_t length, const float16_t *ref, const float16_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	for (size_t i = 0; i < length; i++) {
		const float32_t error = fabsf((float)val[i] - (float)ref[i]);
		const float32_t limit =
			abs_threshold + rel_threshold * fabsf((float)ref[i]);

		if (error > limit) {
			return false;
		}
	}

	return true;
}
396 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
397 
/**
 * @brief Check that the SNR computed by arm_snr_f64() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_f64().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_f64() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (units as returned by
 *                  arm_snr_f64(); presumably dB — confirm).
 *
 * @return true if arm_snr_f64(a, b, length) >= @p threshold.
 */
static inline bool test_snr_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	const float64_t measured = arm_snr_f64(a, b, length);

	return measured >= threshold;
}
407 
/**
 * @brief Check that the SNR computed by arm_snr_f32() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_f32().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_f32() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (units as returned by
 *                  arm_snr_f32(); presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f32(a, b, length);

	return measured >= threshold;
}
417 
418 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that the SNR computed by arm_snr_f16() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_f16().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_f16() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR as float32 (units as returned by
 *                  arm_snr_f16(); presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f16(a, b, length);

	return measured >= threshold;
}
428 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
429 
/**
 * @brief Check that the SNR computed by arm_snr_q63() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_q63().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_q63() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_q63(
	size_t length, const q63_t *a, const q63_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q63(a, b, length);

	return measured >= threshold;
}
438 
/**
 * @brief Check that the SNR computed by arm_snr_q31() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_q31().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_q31() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_q31(
	size_t length, const q31_t *a, const q31_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q31(a, b, length);

	return measured >= threshold;
}
447 
/**
 * @brief Check that the SNR computed by arm_snr_q15() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_q15().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_q15() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_q15(
	size_t length, const q15_t *a, const q15_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q15(a, b, length);

	return measured >= threshold;
}
456 
/**
 * @brief Check that the SNR computed by arm_snr_q7() meets a minimum.
 *
 * @param length    Number of samples, forwarded to arm_snr_q7().
 * @param a         First buffer (presumably the reference signal — see
 *                  arm_snr_q7() in math_helper.h).
 * @param b         Second buffer (presumably the signal under test).
 * @param threshold Minimum accepted SNR (presumably dB — confirm).
 *
 * @return true if the SNR, stored as float32, is >= @p threshold.
 */
static inline bool test_snr_error_q7(
	size_t length, const q7_t *a, const q7_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q7(a, b, length);

	return measured >= threshold;
}
465 
466 #pragma GCC diagnostic pop
467 
468 #endif /* ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_ */
469