1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #ifndef ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
9 #define ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
10 
11 #include <zephyr/ztest.h>
12 #include <zephyr/kernel.h>
13 #include <stdlib.h>
14 #include <arm_math.h>
15 #ifdef CONFIG_CMSIS_DSP_FLOAT16
16 #include <arm_math_f16.h>
17 #endif
18 
19 #include "math_helper.h"
20 
/*
 * Assertion failure messages used by the CMSIS-DSP tests. Each names the
 * kind of check that failed (allocation, SNR, absolute/relative error, or
 * exact result comparison).
 */
#define ASSERT_MSG_BUFFER_ALLOC_FAILED		"buffer allocation failed"
#define ASSERT_MSG_SNR_LIMIT_EXCEED		"signal-to-noise ratio " \
						"error limit exceeded"
#define ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED	"absolute error limit exceeded"
#define ASSERT_MSG_REL_ERROR_LIMIT_EXCEED	"relative error limit exceeded"
#define ASSERT_MSG_ERROR_LIMIT_EXCEED		"error limit exceeded"
#define ASSERT_MSG_INCORRECT_COMP_RESULT	"incorrect computation result"
28 
/*
 * DEFINE_TEST_VARIANTn(...) defines a test case named test_<name>_<variant>
 * whose body simply calls test_<name>() with the n supplied arguments. This
 * lets one parameterised test function be instantiated several times with
 * different argument sets.
 *
 * With CONFIG_ZTEST_NEW_API the case is declared with ZTEST() inside the
 * given suite. Otherwise a plain "static void test_<name>_<variant>(void)"
 * is emitted, which the caller presumably registers with the legacy ztest
 * API (registration is not done here).
 */
#if defined(CONFIG_ZTEST_NEW_API)
#define DEFINE_TEST_VARIANT1(suite, name, variant, a1)                                             \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1);                                                                   \
	}

#define DEFINE_TEST_VARIANT2(suite, name, variant, a1, a2)                                         \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2);                                                               \
	}

#define DEFINE_TEST_VARIANT3(suite, name, variant, a1, a2, a3)                                     \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3);                                                           \
	}

#define DEFINE_TEST_VARIANT4(suite, name, variant, a1, a2, a3, a4)                                 \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4);                                                       \
	}

#define DEFINE_TEST_VARIANT5(suite, name, variant, a1, a2, a3, a4, a5)                             \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5);                                                   \
	}

#define DEFINE_TEST_VARIANT6(suite, name, variant, a1, a2, a3, a4, a5, a6)                         \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5, a6);                                               \
	}

#define DEFINE_TEST_VARIANT7(suite, name, variant, a1, a2, a3, a4, a5, a6, a7)                     \
	ZTEST(suite, test_##name##_##variant)                                                      \
	{                                                                                          \
		test_##name(a1, a2, a3, a4, a5, a6, a7);                                           \
	}
#else /* !defined(CONFIG_ZTEST_NEW_API) */
/* Legacy ztest API: no suite argument; only the test function is defined. */
#define DEFINE_TEST_VARIANT1(name, variant, a1)		\
	static void test_##name##_##variant(void)	\
	{						\
		test_##name(a1);			\
	}

#define DEFINE_TEST_VARIANT2(name, variant, a1, a2)	\
	static void test_##name##_##variant(void)	\
	{						\
		test_##name(a1, a2);			\
	}

#define DEFINE_TEST_VARIANT3(name, variant, a1, a2, a3)	\
	static void test_##name##_##variant(void)	\
	{						\
		test_##name(a1, a2, a3);		\
	}

#define DEFINE_TEST_VARIANT4(name, variant, a1, a2, a3, a4)	\
	static void test_##name##_##variant(void)		\
	{							\
		test_##name(a1, a2, a3, a4);			\
	}

#define DEFINE_TEST_VARIANT5(name, variant, a1, a2, a3, a4, a5)	\
	static void test_##name##_##variant(void)		\
	{							\
		test_##name(a1, a2, a3, a4, a5);		\
	}

#define DEFINE_TEST_VARIANT6(name, variant, a1, a2, a3, a4, a5, a6)	\
	static void test_##name##_##variant(void)			\
	{								\
		test_##name(a1, a2, a3, a4, a5, a6);			\
	}

#define DEFINE_TEST_VARIANT7(name, variant, a1, a2, a3, a4, a5, a6, a7)	\
	static void test_##name##_##variant(void)			\
	{								\
		test_##name(a1, a2, a3, a4, a5, a6, a7);		\
	}
#endif /* !defined(CONFIG_ZTEST_NEW_API) */
114 
/*
 * The comparison helpers below are all defined in this header, and a given
 * test source may use only some of them. -Wunused-function is therefore
 * suppressed here until the matching "#pragma GCC diagnostic pop" at the
 * end of the header.
 */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-function"
117 
test_equal_f64(size_t length,const float64_t * a,const float64_t * b)118 static inline bool test_equal_f64(
119 	size_t length, const float64_t *a, const float64_t *b)
120 {
121 	size_t index;
122 
123 	for (index = 0; index < length; index++) {
124 		if (a[index] != b[index]) {
125 			return false;
126 		}
127 	}
128 
129 	return true;
130 }
131 
test_equal_f32(size_t length,const float32_t * a,const float32_t * b)132 static inline bool test_equal_f32(
133 	size_t length, const float32_t *a, const float32_t *b)
134 {
135 	size_t index;
136 
137 	for (index = 0; index < length; index++) {
138 		if (a[index] != b[index]) {
139 			return false;
140 		}
141 	}
142 
143 	return true;
144 }
145 
146 #ifdef CONFIG_CMSIS_DSP_FLOAT16
test_equal_f16(size_t length,const float16_t * a,const float16_t * b)147 static inline bool test_equal_f16(
148 	size_t length, const float16_t *a, const float16_t *b)
149 {
150 	size_t index;
151 
152 	for (index = 0; index < length; index++) {
153 		if (a[index] != b[index]) {
154 			return false;
155 		}
156 	}
157 
158 	return true;
159 }
160 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
161 
test_equal_q63(size_t length,const q63_t * a,const q63_t * b)162 static inline bool test_equal_q63(
163 	size_t length, const q63_t *a, const q63_t *b)
164 {
165 	size_t index;
166 
167 	for (index = 0; index < length; index++) {
168 		if (a[index] != b[index]) {
169 			return false;
170 		}
171 	}
172 
173 	return true;
174 }
175 
test_equal_q31(size_t length,const q31_t * a,const q31_t * b)176 static inline bool test_equal_q31(
177 	size_t length, const q31_t *a, const q31_t *b)
178 {
179 	size_t index;
180 
181 	for (index = 0; index < length; index++) {
182 		if (a[index] != b[index]) {
183 			return false;
184 		}
185 	}
186 
187 	return true;
188 }
189 
test_equal_q15(size_t length,const q15_t * a,const q15_t * b)190 static inline bool test_equal_q15(
191 	size_t length, const q15_t *a, const q15_t *b)
192 {
193 	size_t index;
194 
195 	for (index = 0; index < length; index++) {
196 		if (a[index] != b[index]) {
197 			return false;
198 		}
199 	}
200 
201 	return true;
202 }
203 
test_equal_q7(size_t length,const q7_t * a,const q7_t * b)204 static inline bool test_equal_q7(
205 	size_t length, const q7_t *a, const q7_t *b)
206 {
207 	size_t index;
208 
209 	for (index = 0; index < length; index++) {
210 		if (a[index] != b[index]) {
211 			return false;
212 		}
213 	}
214 
215 	return true;
216 }
217 
test_near_equal_f64(size_t length,const float64_t * a,const float64_t * b,float64_t threshold)218 static inline bool test_near_equal_f64(
219 	size_t length, const float64_t *a, const float64_t *b,
220 	float64_t threshold)
221 {
222 	size_t index;
223 
224 	for (index = 0; index < length; index++) {
225 		if (fabs(a[index] - b[index]) > threshold) {
226 			return false;
227 		}
228 	}
229 
230 	return true;
231 }
232 
test_near_equal_f32(size_t length,const float32_t * a,const float32_t * b,float32_t threshold)233 static inline bool test_near_equal_f32(
234 	size_t length, const float32_t *a, const float32_t *b,
235 	float32_t threshold)
236 {
237 	size_t index;
238 
239 	for (index = 0; index < length; index++) {
240 		if (fabsf(a[index] - b[index]) > threshold) {
241 			return false;
242 		}
243 	}
244 
245 	return true;
246 }
247 
248 #ifdef CONFIG_CMSIS_DSP_FLOAT16
test_near_equal_f16(size_t length,const float16_t * a,const float16_t * b,float16_t threshold)249 static inline bool test_near_equal_f16(
250 	size_t length, const float16_t *a, const float16_t *b,
251 	float16_t threshold)
252 {
253 	size_t index;
254 
255 	for (index = 0; index < length; index++) {
256 		if (fabsf((float)a[index] - (float)b[index]) > (float)threshold) {
257 			return false;
258 		}
259 	}
260 
261 	return true;
262 }
263 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
264 
test_near_equal_q63(size_t length,const q63_t * a,const q63_t * b,q63_t threshold)265 static inline bool test_near_equal_q63(
266 	size_t length, const q63_t *a, const q63_t *b, q63_t threshold)
267 {
268 	size_t index;
269 
270 	for (index = 0; index < length; index++) {
271 		if (llabs(a[index] - b[index]) > threshold) {
272 			return false;
273 		}
274 	}
275 
276 	return true;
277 }
278 
test_near_equal_q31(size_t length,const q31_t * a,const q31_t * b,q31_t threshold)279 static inline bool test_near_equal_q31(
280 	size_t length, const q31_t *a, const q31_t *b, q31_t threshold)
281 {
282 	size_t index;
283 
284 	for (index = 0; index < length; index++) {
285 		if (abs(a[index] - b[index]) > threshold) {
286 			return false;
287 		}
288 	}
289 
290 	return true;
291 }
292 
test_near_equal_q15(size_t length,const q15_t * a,const q15_t * b,q15_t threshold)293 static inline bool test_near_equal_q15(
294 	size_t length, const q15_t *a, const q15_t *b, q15_t threshold)
295 {
296 	size_t index;
297 
298 	for (index = 0; index < length; index++) {
299 		if (abs(a[index] - b[index]) > threshold) {
300 			return false;
301 		}
302 	}
303 
304 	return true;
305 }
306 
test_near_equal_q7(size_t length,const q7_t * a,const q7_t * b,q7_t threshold)307 static inline bool test_near_equal_q7(
308 	size_t length, const q7_t *a, const q7_t *b, q7_t threshold)
309 {
310 	size_t index;
311 
312 	for (index = 0; index < length; index++) {
313 		if (abs(a[index] - b[index]) > threshold) {
314 			return false;
315 		}
316 	}
317 
318 	return true;
319 }
320 
test_rel_error_f64(size_t length,const float64_t * a,const float64_t * b,float64_t threshold)321 static inline bool test_rel_error_f64(
322 	size_t length, const float64_t *a, const float64_t *b,
323 	float64_t threshold)
324 {
325 	size_t index;
326 	float64_t rel, delta, average;
327 
328 	for (index = 0; index < length; index++) {
329 		delta = fabs(a[index] - b[index]);
330 		average = (fabs(a[index]) + fabs(b[index])) / 2.0;
331 
332 		if (average != 0) {
333 			rel = delta / average;
334 
335 			if (rel > threshold) {
336 				return false;
337 			}
338 		}
339 	}
340 
341 	return true;
342 }
343 
test_rel_error_f32(size_t length,const float32_t * a,const float32_t * b,float32_t threshold)344 static inline bool test_rel_error_f32(
345 	size_t length, const float32_t *a, const float32_t *b,
346 	float32_t threshold)
347 {
348 	size_t index;
349 	float32_t rel, delta, average;
350 
351 	for (index = 0; index < length; index++) {
352 		delta = fabsf(a[index] - b[index]);
353 		average = (fabsf(a[index]) + fabsf(b[index])) / 2.0f;
354 
355 		if (average != 0) {
356 			rel = delta / average;
357 
358 			if (rel > threshold) {
359 				return false;
360 			}
361 		}
362 	}
363 
364 	return true;
365 }
366 
367 #ifdef CONFIG_CMSIS_DSP_FLOAT16
test_rel_error_f16(size_t length,const float16_t * a,const float16_t * b,float16_t threshold)368 static inline bool test_rel_error_f16(
369 	size_t length, const float16_t *a, const float16_t *b,
370 	float16_t threshold)
371 {
372 	size_t index;
373 	float32_t rel, delta, average;
374 
375 	for (index = 0; index < length; index++) {
376 		delta = fabsf((float)a[index] - (float)b[index]);
377 		average = (fabsf((float)a[index]) + fabsf((float)b[index])) / 2.0f;
378 
379 		if (average != 0) {
380 			rel = delta / average;
381 
382 			if (rel > threshold) {
383 				return false;
384 			}
385 		}
386 	}
387 
388 	return true;
389 }
390 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
391 
test_close_error_f64(size_t length,const float64_t * ref,const float64_t * val,float64_t abs_threshold,float64_t rel_threshold)392 static inline bool test_close_error_f64(
393 	size_t length, const float64_t *ref, const float64_t *val,
394 	float64_t abs_threshold, float64_t rel_threshold)
395 {
396 	size_t index;
397 
398 	for (index = 0; index < length; index++) {
399 		if (fabs(val[index] - ref[index]) >
400 			(abs_threshold + rel_threshold * fabs(ref[index]))) {
401 			return false;
402 		}
403 	}
404 
405 	return true;
406 }
407 
test_close_error_f32(size_t length,const float32_t * ref,const float32_t * val,float32_t abs_threshold,float32_t rel_threshold)408 static inline bool test_close_error_f32(
409 	size_t length, const float32_t *ref, const float32_t *val,
410 	float32_t abs_threshold, float32_t rel_threshold)
411 {
412 	size_t index;
413 
414 	for (index = 0; index < length; index++) {
415 		if (fabsf(val[index] - ref[index]) >
416 			(abs_threshold + rel_threshold * fabsf(ref[index]))) {
417 			return false;
418 		}
419 	}
420 
421 	return true;
422 }
423 
424 #ifdef CONFIG_CMSIS_DSP_FLOAT16
test_close_error_f16(size_t length,const float16_t * ref,const float16_t * val,float32_t abs_threshold,float32_t rel_threshold)425 static inline bool test_close_error_f16(
426 	size_t length, const float16_t *ref, const float16_t *val,
427 	float32_t abs_threshold, float32_t rel_threshold)
428 {
429 	size_t index;
430 
431 	for (index = 0; index < length; index++) {
432 		if (fabsf((float)val[index] - (float)ref[index]) >
433 			(abs_threshold + rel_threshold * fabsf((float)ref[index]))) {
434 			return false;
435 		}
436 	}
437 
438 	return true;
439 }
440 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
441 
test_snr_error_f64(size_t length,const float64_t * a,const float64_t * b,float64_t threshold)442 static inline bool test_snr_error_f64(
443 	size_t length, const float64_t *a, const float64_t *b,
444 	float64_t threshold)
445 {
446 	float64_t snr;
447 
448 	snr = arm_snr_f64(a, b, length);
449 	return (snr >= threshold);
450 }
451 
test_snr_error_f32(size_t length,const float32_t * a,const float32_t * b,float32_t threshold)452 static inline bool test_snr_error_f32(
453 	size_t length, const float32_t *a, const float32_t *b,
454 	float32_t threshold)
455 {
456 	float32_t snr;
457 
458 	snr = arm_snr_f32(a, b, length);
459 	return (snr >= threshold);
460 }
461 
462 #ifdef CONFIG_CMSIS_DSP_FLOAT16
test_snr_error_f16(size_t length,const float16_t * a,const float16_t * b,float32_t threshold)463 static inline bool test_snr_error_f16(
464 	size_t length, const float16_t *a, const float16_t *b,
465 	float32_t threshold)
466 {
467 	float32_t snr;
468 
469 	snr = arm_snr_f16(a, b, length);
470 	return (snr >= threshold);
471 }
472 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
473 
test_snr_error_q63(size_t length,const q63_t * a,const q63_t * b,float32_t threshold)474 static inline bool test_snr_error_q63(
475 	size_t length, const q63_t *a, const q63_t *b, float32_t threshold)
476 {
477 	float32_t snr;
478 
479 	snr = arm_snr_q63(a, b, length);
480 	return (snr >= threshold);
481 }
482 
test_snr_error_q31(size_t length,const q31_t * a,const q31_t * b,float32_t threshold)483 static inline bool test_snr_error_q31(
484 	size_t length, const q31_t *a, const q31_t *b, float32_t threshold)
485 {
486 	float32_t snr;
487 
488 	snr = arm_snr_q31(a, b, length);
489 	return (snr >= threshold);
490 }
491 
test_snr_error_q15(size_t length,const q15_t * a,const q15_t * b,float32_t threshold)492 static inline bool test_snr_error_q15(
493 	size_t length, const q15_t *a, const q15_t *b, float32_t threshold)
494 {
495 	float32_t snr;
496 
497 	snr = arm_snr_q15(a, b, length);
498 	return (snr >= threshold);
499 }
500 
test_snr_error_q7(size_t length,const q7_t * a,const q7_t * b,float32_t threshold)501 static inline bool test_snr_error_q7(
502 	size_t length, const q7_t *a, const q7_t *b, float32_t threshold)
503 {
504 	float32_t snr;
505 
506 	snr = arm_snr_q7(a, b, length);
507 	return (snr >= threshold);
508 }
509 
510 #pragma GCC diagnostic pop
511 
512 #endif /* ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_ */
513