1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #ifndef ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
9 #define ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
10
11 #include <zephyr/ztest.h>
12 #include <zephyr/kernel.h>
13 #include <stdlib.h>
14 #include <arm_math.h>
15 #ifdef CONFIG_CMSIS_DSP_FLOAT16
16 #include <arm_math_f16.h>
17 #endif
18
19 #include "math_helper.h"
20
/* Assertion failure messages used by the CMSIS-DSP tests. */
#define ASSERT_MSG_BUFFER_ALLOC_FAILED		"buffer allocation failed"
#define ASSERT_MSG_SNR_LIMIT_EXCEED		"signal-to-noise ratio error limit exceeded"
#define ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED	"absolute error limit exceeded"
#define ASSERT_MSG_REL_ERROR_LIMIT_EXCEED	"relative error limit exceeded"
#define ASSERT_MSG_ERROR_LIMIT_EXCEED		"error limit exceeded"
#define ASSERT_MSG_INCORRECT_COMP_RESULT	"incorrect computation result"
28
/*
 * DEFINE_TEST_VARIANTn() emits a test case named test_<name>_<variant> whose
 * body simply forwards its n extra arguments to test_<name>().
 *
 * With CONFIG_ZTEST_NEW_API the case is declared through ZTEST() in @p suite;
 * otherwise it is emitted as a plain `static void ...(void)` function.
 */
#if defined(CONFIG_ZTEST_NEW_API)
#define DEFINE_TEST_VARIANT1(suite, name, variant, a1) \
	ZTEST(suite, test_##name##_##variant) { test_##name(a1); }

#define DEFINE_TEST_VARIANT2(suite, name, variant, a1, a2) \
	ZTEST(suite, test_##name##_##variant) { test_##name(a1, a2); }

#define DEFINE_TEST_VARIANT3(suite, name, variant, a1, a2, a3) \
	ZTEST(suite, test_##name##_##variant) { test_##name(a1, a2, a3); }

#define DEFINE_TEST_VARIANT4(suite, name, variant, a1, a2, a3, a4) \
	ZTEST(suite, test_##name##_##variant) { test_##name(a1, a2, a3, a4); }

#define DEFINE_TEST_VARIANT5(suite, name, variant, a1, a2, a3, a4, a5) \
	ZTEST(suite, test_##name##_##variant) \
	{ test_##name(a1, a2, a3, a4, a5); }

#define DEFINE_TEST_VARIANT6(suite, name, variant, a1, a2, a3, a4, a5, a6) \
	ZTEST(suite, test_##name##_##variant) \
	{ test_##name(a1, a2, a3, a4, a5, a6); }

#define DEFINE_TEST_VARIANT7(suite, name, variant, a1, a2, a3, a4, a5, a6, a7) \
	ZTEST(suite, test_##name##_##variant) \
	{ test_##name(a1, a2, a3, a4, a5, a6, a7); }
#else /* !defined(CONFIG_ZTEST_NEW_API) */
#define DEFINE_TEST_VARIANT1(name, variant, a1) \
	static void test_##name##_##variant(void) { test_##name(a1); }

#define DEFINE_TEST_VARIANT2(name, variant, a1, a2) \
	static void test_##name##_##variant(void) { test_##name(a1, a2); }

#define DEFINE_TEST_VARIANT3(name, variant, a1, a2, a3) \
	static void test_##name##_##variant(void) { test_##name(a1, a2, a3); }

#define DEFINE_TEST_VARIANT4(name, variant, a1, a2, a3, a4) \
	static void test_##name##_##variant(void) \
	{ test_##name(a1, a2, a3, a4); }

#define DEFINE_TEST_VARIANT5(name, variant, a1, a2, a3, a4, a5) \
	static void test_##name##_##variant(void) \
	{ test_##name(a1, a2, a3, a4, a5); }

#define DEFINE_TEST_VARIANT6(name, variant, a1, a2, a3, a4, a5, a6) \
	static void test_##name##_##variant(void) \
	{ test_##name(a1, a2, a3, a4, a5, a6); }

#define DEFINE_TEST_VARIANT7(name, variant, a1, a2, a3, a4, a5, a6, a7) \
	static void test_##name##_##variant(void) \
	{ test_##name(a1, a2, a3, a4, a5, a6, a7); }
#endif /* !defined(CONFIG_ZTEST_NEW_API) */
114
115 #pragma GCC diagnostic push
116 #pragma GCC diagnostic ignored "-Wunused-function"
117
/**
 * @brief Check two float64 buffers for element-wise equality.
 *
 * Elements are compared with ==, so +0.0 and -0.0 are treated as equal and a
 * NaN element never matches anything.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_f64(
	size_t length, const float64_t *a, const float64_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
131
/**
 * @brief Check two float32 buffers for element-wise equality.
 *
 * Elements are compared with ==, so +0.0f and -0.0f are treated as equal and
 * a NaN element never matches anything.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_f32(
	size_t length, const float32_t *a, const float32_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
145
#ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check two float16 buffers for element-wise equality.
 *
 * Elements are compared with == on the float16_t values, so a NaN element
 * never matches anything.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_f16(
	size_t length, const float16_t *a, const float16_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
#endif /* CONFIG_CMSIS_DSP_FLOAT16 */
161
/**
 * @brief Check two q63 buffers for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_q63(
	size_t length, const q63_t *a, const q63_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
175
/**
 * @brief Check two q31 buffers for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_q31(
	size_t length, const q31_t *a, const q31_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
189
/**
 * @brief Check two q15 buffers for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_q15(
	size_t length, const q15_t *a, const q15_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
203
/**
 * @brief Check two q7 buffers for exact element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 * @return true if all @p length element pairs are equal, false otherwise.
 */
static inline bool test_equal_q7(
	size_t length, const q7_t *a, const q7_t *b)
{
	size_t matched = 0;

	/* Walk the longest equal prefix; the buffers match iff it covers all. */
	while (matched < length && a[matched] == b[matched]) {
		matched++;
	}

	return matched == length;
}
217
/**
 * @brief Check that every float64 element pair differs by at most a threshold.
 *
 * Finite pairs pass when |a - b| <= @p threshold.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously a NaN anywhere made
 * `fabs(a - b) > threshold` evaluate to false, so NaN results silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		if (!isfinite(a[index]) || !isfinite(b[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (a[index] != b[index]) {
				return false;
			}
			continue;
		}

		if (fabs(a[index] - b[index]) > threshold) {
			return false;
		}
	}

	return true;
}
232
/**
 * @brief Check that every float32 element pair differs by at most a threshold.
 *
 * Finite pairs pass when |a - b| <= @p threshold.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously a NaN anywhere made
 * `fabsf(a - b) > threshold` evaluate to false, so NaN results silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		if (!isfinite(a[index]) || !isfinite(b[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (a[index] != b[index]) {
				return false;
			}
			continue;
		}

		if (fabsf(a[index] - b[index]) > threshold) {
			return false;
		}
	}

	return true;
}
247
#ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that every float16 element pair differs by at most a threshold.
 *
 * Values and threshold are converted to float before comparing. Finite pairs
 * pass when |a - b| <= @p threshold.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously a NaN anywhere made
 * the `>` comparison false, so NaN results silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	size_t index;
	const float limit = (float)threshold;

	for (index = 0; index < length; index++) {
		const float x = (float)a[index];
		const float y = (float)b[index];

		if (!isfinite(x) || !isfinite(y)) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (x != y) {
				return false;
			}
			continue;
		}

		if (fabsf(x - y) > limit) {
			return false;
		}
	}

	return true;
}
#endif /* CONFIG_CMSIS_DSP_FLOAT16 */
264
/**
 * @brief Check that every q63 element pair differs by at most a threshold.
 *
 * A pair passes when |a - b| <= @p threshold. The absolute difference is
 * computed in unsigned 64-bit arithmetic. The former `llabs(a - b)` had
 * undefined behaviour: the signed subtraction can overflow (for example
 * INT64_MAX - (-1)), and llabs(INT64_MIN) is not representable.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_q63(
	size_t length, const q63_t *a, const q63_t *b, q63_t threshold)
{
	size_t index;

	/* |a - b| >= 0, so a negative threshold can only be met by an empty set. */
	if (threshold < 0) {
		return length == 0;
	}

	for (index = 0; index < length; index++) {
		/* Subtract the smaller from the larger modulo 2^64; the result is
		 * the exact distance, which always fits in uint64_t.
		 */
		const uint64_t diff = (a[index] > b[index]) ?
			(uint64_t)a[index] - (uint64_t)b[index] :
			(uint64_t)b[index] - (uint64_t)a[index];

		if (diff > (uint64_t)threshold) {
			return false;
		}
	}

	return true;
}
278
/**
 * @brief Check that every q31 element pair differs by at most a threshold.
 *
 * A pair passes when |a - b| <= @p threshold. Operands are widened to 64 bits
 * before subtracting. The former `abs(a - b)` on 32-bit values could overflow
 * (for example INT32_MAX - (-1)), which is undefined behaviour.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_q31(
	size_t length, const q31_t *a, const q31_t *b, q31_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/* |diff| <= 2^32 - 1, so neither the subtraction nor the negation
		 * can overflow int64_t.
		 */
		int64_t diff = (int64_t)a[index] - (int64_t)b[index];

		if (diff < 0) {
			diff = -diff;
		}

		if (diff > threshold) {
			return false;
		}
	}

	return true;
}
292
/**
 * @brief Check that every q15 element pair differs by at most a threshold.
 *
 * The q15 operands are promoted to int before subtracting, and a pair passes
 * when abs(a - b) <= @p threshold.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_q15(
	size_t length, const q15_t *a, const q15_t *b, q15_t threshold)
{
	size_t pos = 0;

	/* Skip over every pair that is within tolerance. */
	while (pos < length && abs(a[pos] - b[pos]) <= threshold) {
		pos++;
	}

	return pos == length;
}
306
/**
 * @brief Check that every q7 element pair differs by at most a threshold.
 *
 * The q7 operands are promoted to int before subtracting, and a pair passes
 * when abs(a - b) <= @p threshold.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed absolute difference.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_near_equal_q7(
	size_t length, const q7_t *a, const q7_t *b, q7_t threshold)
{
	size_t pos = 0;

	/* Skip over every pair that is within tolerance. */
	while (pos < length && abs(a[pos] - b[pos]) <= threshold) {
		pos++;
	}

	return pos == length;
}
320
/**
 * @brief Check that the relative error of every float64 pair is within a limit.
 *
 * For finite pairs the relative error is |a - b| / ((|a| + |b|) / 2). Pairs
 * whose mean magnitude is zero are skipped, since the ratio is undefined there.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously Inf vs. a finite value
 * produced Inf/Inf = NaN, and NaN > threshold is false, so such mismatches (and
 * NaN results) silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed relative error.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_rel_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	size_t index;
	float64_t rel, delta, average;

	for (index = 0; index < length; index++) {
		if (!isfinite(a[index]) || !isfinite(b[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (a[index] != b[index]) {
				return false;
			}
			continue;
		}

		delta = fabs(a[index] - b[index]);
		average = (fabs(a[index]) + fabs(b[index])) / 2.0;

		if (average != 0) {
			rel = delta / average;

			if (rel > threshold) {
				return false;
			}
		}
	}

	return true;
}
343
/**
 * @brief Check that the relative error of every float32 pair is within a limit.
 *
 * For finite pairs the relative error is |a - b| / ((|a| + |b|) / 2). Pairs
 * whose mean magnitude is zero are skipped, since the ratio is undefined there.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously Inf vs. a finite value
 * produced Inf/Inf = NaN, and NaN > threshold is false, so such mismatches (and
 * NaN results) silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed relative error.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_rel_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	size_t index;
	float32_t rel, delta, average;

	for (index = 0; index < length; index++) {
		if (!isfinite(a[index]) || !isfinite(b[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (a[index] != b[index]) {
				return false;
			}
			continue;
		}

		delta = fabsf(a[index] - b[index]);
		average = (fabsf(a[index]) + fabsf(b[index])) / 2.0f;

		if (average != 0) {
			rel = delta / average;

			if (rel > threshold) {
				return false;
			}
		}
	}

	return true;
}
366
#ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that the relative error of every float16 pair is within a limit.
 *
 * Values and threshold are converted to float. For finite pairs the relative
 * error is |a - b| / ((|a| + |b|) / 2). Pairs whose mean magnitude is zero are
 * skipped, since the ratio is undefined there.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously Inf vs. a finite value
 * produced Inf/Inf = NaN, and NaN > threshold is false, so such mismatches (and
 * NaN results) silently passed.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Maximum allowed relative error.
 * @return true if every pair is within @p threshold, false otherwise.
 */
static inline bool test_rel_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	size_t index;
	float32_t rel, delta, average;
	const float32_t limit = (float32_t)threshold;

	for (index = 0; index < length; index++) {
		const float x = (float)a[index];
		const float y = (float)b[index];

		if (!isfinite(x) || !isfinite(y)) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (x != y) {
				return false;
			}
			continue;
		}

		delta = fabsf(x - y);
		average = (fabsf(x) + fabsf(y)) / 2.0f;

		if (average != 0) {
			rel = delta / average;

			if (rel > limit) {
				return false;
			}
		}
	}

	return true;
}
#endif /* CONFIG_CMSIS_DSP_FLOAT16 */
391
/**
 * @brief Check float64 values against a reference with mixed tolerance.
 *
 * A finite pair passes when
 *   |val - ref| <= abs_threshold + rel_threshold * |ref|.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously an infinite @p ref
 * made the tolerance infinite, so `inf > inf` was false and any finite @p val
 * passed; NaN values passed as well.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part of the tolerance (scaled by |ref|).
 * @return true if every element is within tolerance, false otherwise.
 */
static inline bool test_close_error_f64(
	size_t length, const float64_t *ref, const float64_t *val,
	float64_t abs_threshold, float64_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		if (!isfinite(ref[index]) || !isfinite(val[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (val[index] != ref[index]) {
				return false;
			}
			continue;
		}

		if (fabs(val[index] - ref[index]) >
			(abs_threshold + rel_threshold * fabs(ref[index]))) {
			return false;
		}
	}

	return true;
}
407
/**
 * @brief Check float32 values against a reference with mixed tolerance.
 *
 * A finite pair passes when
 *   |val - ref| <= abs_threshold + rel_threshold * |ref|.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously an infinite @p ref
 * made the tolerance infinite, so `inf > inf` was false and any finite @p val
 * passed; NaN values passed as well.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part of the tolerance (scaled by |ref|).
 * @return true if every element is within tolerance, false otherwise.
 */
static inline bool test_close_error_f32(
	size_t length, const float32_t *ref, const float32_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		if (!isfinite(ref[index]) || !isfinite(val[index])) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (val[index] != ref[index]) {
				return false;
			}
			continue;
		}

		if (fabsf(val[index] - ref[index]) >
			(abs_threshold + rel_threshold * fabsf(ref[index]))) {
			return false;
		}
	}

	return true;
}
423
#ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check float16 values against a reference with mixed tolerance.
 *
 * Values are converted to float. A finite pair passes when
 *   |val - ref| <= abs_threshold + rel_threshold * |ref|.
 *
 * Non-finite values (NaN, +/-Inf) must match exactly: an infinity only matches
 * the same infinity, and a NaN never matches. Previously an infinite @p ref
 * made the tolerance infinite, so `inf > inf` was false and any finite @p val
 * passed; NaN values passed as well.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part of the tolerance (scaled by |ref|).
 * @return true if every element is within tolerance, false otherwise.
 */
static inline bool test_close_error_f16(
	size_t length, const float16_t *ref, const float16_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		const float r = (float)ref[index];
		const float v = (float)val[index];

		if (!isfinite(r) || !isfinite(v)) {
			/* NaN != anything; Inf only equals the same Inf. */
			if (v != r) {
				return false;
			}
			continue;
		}

		if (fabsf(v - r) > (abs_threshold + rel_threshold * fabsf(r))) {
			return false;
		}
	}

	return true;
}
#endif /* CONFIG_CMSIS_DSP_FLOAT16 */
441
/**
 * @brief Check that arm_snr_f64(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_f64().
 * @param b         Passed as the second buffer argument of arm_snr_f64().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR is >= @p threshold.
 */
static inline bool test_snr_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	const float64_t measured = arm_snr_f64(a, b, length);

	return measured >= threshold;
}
451
/**
 * @brief Check that arm_snr_f32(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_f32().
 * @param b         Passed as the second buffer argument of arm_snr_f32().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f32(a, b, length);

	return measured >= threshold;
}
461
#ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that arm_snr_f16(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_f16().
 * @param b         Passed as the second buffer argument of arm_snr_f16().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f16(a, b, length);

	return measured >= threshold;
}
#endif /* CONFIG_CMSIS_DSP_FLOAT16 */
473
/**
 * @brief Check that arm_snr_q63(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_q63().
 * @param b         Passed as the second buffer argument of arm_snr_q63().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_q63(
	size_t length, const q63_t *a, const q63_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q63(a, b, length);

	return measured >= threshold;
}
482
/**
 * @brief Check that arm_snr_q31(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_q31().
 * @param b         Passed as the second buffer argument of arm_snr_q31().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_q31(
	size_t length, const q31_t *a, const q31_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q31(a, b, length);

	return measured >= threshold;
}
491
/**
 * @brief Check that arm_snr_q15(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_q15().
 * @param b         Passed as the second buffer argument of arm_snr_q15().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_q15(
	size_t length, const q15_t *a, const q15_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q15(a, b, length);

	return measured >= threshold;
}
500
/**
 * @brief Check that arm_snr_q7(a, b, length) is at least @p threshold.
 *
 * @param length    Number of elements.
 * @param a         Passed as the first buffer argument of arm_snr_q7().
 * @param b         Passed as the second buffer argument of arm_snr_q7().
 * @param threshold Minimum acceptable SNR value.
 * @return true if the computed SNR (stored as float32_t) is >= @p threshold.
 */
static inline bool test_snr_error_q7(
	size_t length, const q7_t *a, const q7_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q7(a, b, length);

	return measured >= threshold;
}
509
510 #pragma GCC diagnostic pop
511
512 #endif /* ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_ */
513