1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #ifndef ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
9 #define ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_
10
11 #include <ztest.h>
12 #include <zephyr.h>
13 #include <stdlib.h>
14 #include <arm_math.h>
15 #ifdef CONFIG_CMSIS_DSP_FLOAT16
16 #include <arm_math_f16.h>
17 #endif
18
19 #include "math_helper.h"
20
/*
 * Message strings describing the ways a check can fail. They are only
 * defined here; presumably the tests that include this header pass them to
 * their assertions.
 */
#define ASSERT_MSG_BUFFER_ALLOC_FAILED \
	"buffer allocation failed"
#define ASSERT_MSG_SNR_LIMIT_EXCEED \
	"signal-to-noise ratio error limit exceeded"
#define ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED \
	"absolute error limit exceeded"
#define ASSERT_MSG_REL_ERROR_LIMIT_EXCEED \
	"relative error limit exceeded"
#define ASSERT_MSG_ERROR_LIMIT_EXCEED \
	"error limit exceeded"
#define ASSERT_MSG_INCORRECT_COMP_RESULT \
	"incorrect computation result"
28
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the one fixed argument given here.
 */
#define DEFINE_TEST_VARIANT1(name, variant, arg1) \
	static void test_##name##_##variant(void) { test_##name(arg1); }
34
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the two fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT2(name, variant, arg1, arg2) \
	static void test_##name##_##variant(void) { test_##name(arg1, arg2); }
40
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the three fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT3(name, variant, arg1, arg2, arg3) \
	static void test_##name##_##variant(void) \
	{ test_##name(arg1, arg2, arg3); }
46
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the four fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT4(name, variant, arg1, arg2, arg3, arg4) \
	static void test_##name##_##variant(void) \
	{ test_##name(arg1, arg2, arg3, arg4); }
52
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the five fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT5(name, variant, arg1, arg2, arg3, arg4, arg5) \
	static void test_##name##_##variant(void) \
	{ test_##name(arg1, arg2, arg3, arg4, arg5); }
58
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the six fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT6(name, variant, arg1, arg2, arg3, arg4, arg5, \
			     arg6) \
	static void test_##name##_##variant(void) \
	{ test_##name(arg1, arg2, arg3, arg4, arg5, arg6); }
64
/*
 * Define the static function test_<name>_<variant>(void), which calls
 * test_<name>() with the seven fixed arguments given here.
 */
#define DEFINE_TEST_VARIANT7(name, variant, arg1, arg2, arg3, arg4, arg5, \
			     arg6, arg7) \
	static void test_##name##_##variant(void) \
	{ test_##name(arg1, arg2, arg3, arg4, arg5, arg6, arg7); }
70
71 #pragma GCC diagnostic push
72 #pragma GCC diagnostic ignored "-Wunused-function"
73
/**
 * @brief Check two float64_t buffers for element-wise equality.
 *
 * Elements are compared with the equality operator; the scan stops at the
 * first pair that does not compare equal.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs compare equal (including when
 *         @p length is zero), false otherwise.
 */
static inline bool test_equal_f64(
	size_t length, const float64_t *a, const float64_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
87
/**
 * @brief Check two float32_t buffers for element-wise equality.
 *
 * Elements are compared with the equality operator; the scan stops at the
 * first pair that does not compare equal.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs compare equal (including when
 *         @p length is zero), false otherwise.
 */
static inline bool test_equal_f32(
	size_t length, const float32_t *a, const float32_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
101
102 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check two float16_t buffers for element-wise equality.
 *
 * Elements are compared with the equality operator; the scan stops at the
 * first pair that does not compare equal.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs compare equal (including when
 *         @p length is zero), false otherwise.
 */
static inline bool test_equal_f16(
	size_t length, const float16_t *a, const float16_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
116 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
117
/**
 * @brief Check two q63_t buffers for element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs are equal (including when @p length
 *         is zero), false as soon as one pair differs.
 */
static inline bool test_equal_q63(
	size_t length, const q63_t *a, const q63_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
131
/**
 * @brief Check two q31_t buffers for element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs are equal (including when @p length
 *         is zero), false as soon as one pair differs.
 */
static inline bool test_equal_q31(
	size_t length, const q31_t *a, const q31_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
145
/**
 * @brief Check two q15_t buffers for element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs are equal (including when @p length
 *         is zero), false as soon as one pair differs.
 */
static inline bool test_equal_q15(
	size_t length, const q15_t *a, const q15_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
159
/**
 * @brief Check two q7_t buffers for element-wise equality.
 *
 * @param length Number of elements to compare.
 * @param a      First buffer.
 * @param b      Second buffer.
 *
 * @return true if all @p length pairs are equal (including when @p length
 *         is zero), false as soon as one pair differs.
 */
static inline bool test_equal_q7(
	size_t length, const q7_t *a, const q7_t *b)
{
	size_t pos = 0;

	/* Advance while the current pair matches. */
	while (pos < length && a[pos] == b[pos]) {
		pos++;
	}

	/* Stopping short of the end means a mismatch was found. */
	return pos == length;
}
173
/**
 * @brief Check that two float64_t buffers agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if any compared element is NaN or any pair differs by more
 *         than @p threshold; true otherwise (including when @p length is
 *         zero).
 */
static inline bool test_near_equal_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * "difference > threshold" check alone would accept a NaN
		 * element as being within tolerance. Reject it explicitly.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		if (fabs(a[index] - b[index]) > threshold) {
			return false;
		}
	}

	return true;
}
188
/**
 * @brief Check that two float32_t buffers agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if any compared element is NaN or any pair differs by more
 *         than @p threshold; true otherwise (including when @p length is
 *         zero).
 */
static inline bool test_near_equal_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * "difference > threshold" check alone would accept a NaN
		 * element as being within tolerance. Reject it explicitly.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		if (fabs(a[index] - b[index]) > threshold) {
			return false;
		}
	}

	return true;
}
203
204 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that two float16_t buffers agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of |a[i] - b[i]|.
 *
 * @return false if any compared element is NaN or any pair differs by more
 *         than @p threshold; true otherwise (including when @p length is
 *         zero).
 */
static inline bool test_near_equal_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * "difference > threshold" check alone would accept a NaN
		 * element as being within tolerance. Reject it explicitly.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		if (fabs(a[index] - b[index]) > threshold) {
			return false;
		}
	}

	return true;
}
219 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
220
/**
 * @brief Check that two q63_t buffers agree within an absolute tolerance.
 *
 * The magnitude of each difference is computed in unsigned 64-bit
 * arithmetic. The previous form, llabs(a[i] - b[i]), overflowed the signed
 * subtraction (undefined behaviour) for operands far apart, for example the
 * most positive against the most negative value.
 * Assumes q63_t is a signed 64-bit integer type.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of |a[i] - b[i]|. A negative
 *                  value can never be satisfied by any pair.
 *
 * @return true if every pair differs by at most @p threshold (including
 *         when @p length is zero), false otherwise.
 */
static inline bool test_near_equal_q63(
	size_t length, const q63_t *a, const q63_t *b, q63_t threshold)
{
	size_t index;
	uint64_t magnitude;

	for (index = 0; index < length; index++) {
		/*
		 * Subtract the smaller from the larger after converting to
		 * unsigned; the wrap-around result is the exact distance
		 * even when it does not fit in a signed 64-bit value.
		 */
		if (a[index] >= b[index]) {
			magnitude = (uint64_t)a[index] - (uint64_t)b[index];
		} else {
			magnitude = (uint64_t)b[index] - (uint64_t)a[index];
		}

		/* A distance is never below zero, so a negative limit fails. */
		if (threshold < 0 || magnitude > (uint64_t)threshold) {
			return false;
		}
	}

	return true;
}
234
/**
 * @brief Check that two q31_t buffers agree within an absolute tolerance.
 *
 * Each difference is computed in 64-bit arithmetic. The previous form,
 * abs(a[i] - b[i]), overflowed the int subtraction (undefined behaviour)
 * for operands far apart, and abs() itself is undefined for the most
 * negative int.
 * Assumes q31_t is a signed integer type no wider than 32 bits.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of |a[i] - b[i]|. A negative
 *                  value can never be satisfied by any pair.
 *
 * @return true if every pair differs by at most @p threshold (including
 *         when @p length is zero), false otherwise.
 */
static inline bool test_near_equal_q31(
	size_t length, const q31_t *a, const q31_t *b, q31_t threshold)
{
	size_t index;
	int64_t distance;

	for (index = 0; index < length; index++) {
		/* Widen before subtracting so the result always fits. */
		distance = (int64_t)a[index] - (int64_t)b[index];

		if (distance < 0) {
			distance = -distance;
		}

		if (distance > threshold) {
			return false;
		}
	}

	return true;
}
248
/**
 * @brief Check that two q15_t buffers agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of abs(a[i] - b[i]).
 *
 * @return true if every pair differs by at most @p threshold (including
 *         when @p length is zero), false as soon as one pair exceeds it.
 */
static inline bool test_near_equal_q15(
	size_t length, const q15_t *a, const q15_t *b, q15_t threshold)
{
	for (size_t i = 0; i != length; ++i) {
		const int distance = abs(a[i] - b[i]);

		if (distance > threshold) {
			return false;
		}
	}

	return true;
}
262
/**
 * @brief Check that two q7_t buffers agree within an absolute tolerance.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted value of abs(a[i] - b[i]).
 *
 * @return true if every pair differs by at most @p threshold (including
 *         when @p length is zero), false as soon as one pair exceeds it.
 */
static inline bool test_near_equal_q7(
	size_t length, const q7_t *a, const q7_t *b, q7_t threshold)
{
	for (size_t i = 0; i != length; ++i) {
		const int distance = abs(a[i] - b[i]);

		if (distance > threshold) {
			return false;
		}
	}

	return true;
}
276
/**
 * @brief Check that two float64_t buffers agree within a relative tolerance.
 *
 * For each pair the relative error is |a[i] - b[i]| divided by the mean of
 * |a[i]| and |b[i]|. Pairs whose mean magnitude is zero are skipped.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted relative error.
 *
 * @return false if any compared element is NaN or any relative error
 *         exceeds @p threshold; true otherwise (including when @p length
 *         is zero).
 */
static inline bool test_rel_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	size_t index;
	float64_t rel, delta, average;

	for (index = 0; index < length; index++) {
		/*
		 * A NaN operand produces a NaN ratio, and "NaN > threshold"
		 * is false, so without this check NaN would be accepted.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		delta = fabs(a[index] - b[index]);
		average = (fabs(a[index]) + fabs(b[index])) / 2.0;

		/* Avoid dividing by zero when both magnitudes are zero. */
		if (average != 0) {
			rel = delta / average;

			if (rel > threshold) {
				return false;
			}
		}
	}

	return true;
}
299
/**
 * @brief Check that two float32_t buffers agree within a relative tolerance.
 *
 * For each pair the relative error is |a[i] - b[i]| divided by the mean of
 * |a[i]| and |b[i]|. Pairs whose mean magnitude is zero are skipped.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted relative error.
 *
 * @return false if any compared element is NaN or any relative error
 *         exceeds @p threshold; true otherwise (including when @p length
 *         is zero).
 */
static inline bool test_rel_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	size_t index;
	float32_t rel, delta, average;

	for (index = 0; index < length; index++) {
		/*
		 * A NaN operand produces a NaN ratio, and "NaN > threshold"
		 * is false, so without this check NaN would be accepted.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		delta = fabs(a[index] - b[index]);
		average = (fabs(a[index]) + fabs(b[index])) / 2.0f;

		/* Avoid dividing by zero when both magnitudes are zero. */
		if (average != 0) {
			rel = delta / average;

			if (rel > threshold) {
				return false;
			}
		}
	}

	return true;
}
322
323 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check that two float16_t buffers agree within a relative tolerance.
 *
 * For each pair the relative error is |a[i] - b[i]| divided by the mean of
 * |a[i]| and |b[i]|; the intermediate values are held as float32_t. Pairs
 * whose mean magnitude is zero are skipped.
 *
 * @param length    Number of elements to compare.
 * @param a         First buffer.
 * @param b         Second buffer.
 * @param threshold Largest accepted relative error.
 *
 * @return false if any compared element is NaN or any relative error
 *         exceeds @p threshold; true otherwise (including when @p length
 *         is zero).
 */
static inline bool test_rel_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float16_t threshold)
{
	size_t index;
	float32_t rel, delta, average;

	for (index = 0; index < length; index++) {
		/*
		 * A NaN operand produces a NaN ratio, and "NaN > threshold"
		 * is false, so without this check NaN would be accepted.
		 */
		if (isnan(a[index]) || isnan(b[index])) {
			return false;
		}

		delta = fabs(a[index] - b[index]);
		average = (fabs(a[index]) + fabs(b[index])) / 2.0f;

		/* Avoid dividing by zero when both magnitudes are zero. */
		if (average != 0) {
			rel = delta / average;

			if (rel > threshold) {
				return false;
			}
		}
	}

	return true;
}
346 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
347
/**
 * @brief Check float64_t values against a reference with a combined
 *        absolute and relative tolerance.
 *
 * A pair is accepted when
 * |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part of the tolerance, scaled by |ref[i]|.
 *
 * @return false if any compared element is NaN or any pair is outside the
 *         tolerance; true otherwise (including when @p length is zero).
 */
static inline bool test_close_error_f64(
	size_t length, const float64_t *ref, const float64_t *val,
	float64_t abs_threshold, float64_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * tolerance check alone would accept a NaN element.
		 */
		if (isnan(ref[index]) || isnan(val[index])) {
			return false;
		}

		if (fabs(val[index] - ref[index]) >
		    (abs_threshold + rel_threshold * fabs(ref[index]))) {
			return false;
		}
	}

	return true;
}
363
/**
 * @brief Check float32_t values against a reference with a combined
 *        absolute and relative tolerance.
 *
 * A pair is accepted when
 * |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance.
 * @param rel_threshold Relative part of the tolerance, scaled by |ref[i]|.
 *
 * @return false if any compared element is NaN or any pair is outside the
 *         tolerance; true otherwise (including when @p length is zero).
 */
static inline bool test_close_error_f32(
	size_t length, const float32_t *ref, const float32_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * tolerance check alone would accept a NaN element.
		 */
		if (isnan(ref[index]) || isnan(val[index])) {
			return false;
		}

		if (fabs(val[index] - ref[index]) >
		    (abs_threshold + rel_threshold * fabs(ref[index]))) {
			return false;
		}
	}

	return true;
}
379
380 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check float16_t values against a reference with a combined
 *        absolute and relative tolerance.
 *
 * A pair is accepted when
 * |val[i] - ref[i]| <= abs_threshold + rel_threshold * |ref[i]|.
 *
 * @param length        Number of elements to compare.
 * @param ref           Reference buffer.
 * @param val           Buffer under test.
 * @param abs_threshold Absolute part of the tolerance (float32_t).
 * @param rel_threshold Relative part of the tolerance, scaled by |ref[i]|
 *                      (float32_t).
 *
 * @return false if any compared element is NaN or any pair is outside the
 *         tolerance; true otherwise (including when @p length is zero).
 */
static inline bool test_close_error_f16(
	size_t length, const float16_t *ref, const float16_t *val,
	float32_t abs_threshold, float32_t rel_threshold)
{
	size_t index;

	for (index = 0; index < length; index++) {
		/*
		 * Ordered comparisons involving NaN are always false, so the
		 * tolerance check alone would accept a NaN element.
		 */
		if (isnan(ref[index]) || isnan(val[index])) {
			return false;
		}

		if (fabs(val[index] - ref[index]) >
		    (abs_threshold + rel_threshold * fabs(ref[index]))) {
			return false;
		}
	}

	return true;
}
396 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
397
/**
 * @brief Check the value returned by arm_snr_f64() against a lower bound.
 *
 * @param length    Number of elements, forwarded to arm_snr_f64().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_f64(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_f64(
	size_t length, const float64_t *a, const float64_t *b,
	float64_t threshold)
{
	const float64_t measured = arm_snr_f64(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
407
/**
 * @brief Check the value returned by arm_snr_f32() against a lower bound.
 *
 * @param length    Number of elements, forwarded to arm_snr_f32().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_f32(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_f32(
	size_t length, const float32_t *a, const float32_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f32(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
417
418 #ifdef CONFIG_CMSIS_DSP_FLOAT16
/**
 * @brief Check the value returned by arm_snr_f16() against a lower bound.
 *
 * The result is held as float32_t before it is compared.
 *
 * @param length    Number of elements, forwarded to arm_snr_f16().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_f16(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_f16(
	size_t length, const float16_t *a, const float16_t *b,
	float32_t threshold)
{
	const float32_t measured = arm_snr_f16(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
428 #endif /* CONFIG_CMSIS_DSP_FLOAT16 */
429
/**
 * @brief Check the value returned by arm_snr_q63() against a lower bound.
 *
 * The result is held as float32_t before it is compared.
 *
 * @param length    Number of elements, forwarded to arm_snr_q63().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_q63(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_q63(
	size_t length, const q63_t *a, const q63_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q63(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
438
/**
 * @brief Check the value returned by arm_snr_q31() against a lower bound.
 *
 * The result is held as float32_t before it is compared.
 *
 * @param length    Number of elements, forwarded to arm_snr_q31().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_q31(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_q31(
	size_t length, const q31_t *a, const q31_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q31(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
447
/**
 * @brief Check the value returned by arm_snr_q15() against a lower bound.
 *
 * The result is held as float32_t before it is compared.
 *
 * @param length    Number of elements, forwarded to arm_snr_q15().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_q15(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_q15(
	size_t length, const q15_t *a, const q15_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q15(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
456
/**
 * @brief Check the value returned by arm_snr_q7() against a lower bound.
 *
 * The result is held as float32_t before it is compared.
 *
 * @param length    Number of elements, forwarded to arm_snr_q7().
 * @param a         First buffer, passed as the first argument.
 * @param b         Second buffer, passed as the second argument.
 * @param threshold Smallest accepted result.
 *
 * @return true if arm_snr_q7(a, b, length) is greater than or equal to
 *         @p threshold, false otherwise.
 */
static inline bool test_snr_error_q7(
	size_t length, const q7_t *a, const q7_t *b, float32_t threshold)
{
	const float32_t measured = arm_snr_q7(a, b, length);

	if (measured >= threshold) {
		return true;
	}

	return false;
}
465
466 #pragma GCC diagnostic pop
467
468 #endif /* ZEPHYR_TESTS_LIB_CMSIS_DSP_COMMON_TEST_COMMON_H_ */
469