1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13
14 #include "f16.pat"
15
/*
 * Thresholds handed to the test_snr_error_f16(), test_rel_error_f16() and
 * test_near_equal_f16() helpers by the tests in this file.
 */
#define SNR_ERROR_THRESH	((float32_t)48)

/*
 * Thresholds used only by the Kullback-Leibler test (absolute and relative
 * limits are passed together to test_close_error_f16()).
 */
#define REL_ERROR_THRESH_KB	(5.0e-3)
#define ABS_ERROR_THRESH_KB	(5.0e-3)

#ifdef CONFIG_ARMV8_1_M_MVEF
/*
 * NOTE: The MVE vector versions of the statistics functions are slightly
 *       less accurate than the scalar versions, so the limits are relaxed.
 */
#define REL_ERROR_THRESH	(10.0e-3)
#define SNR_ERROR_THRESH_KB	((float32_t)39)
#else
#define REL_ERROR_THRESH	(6.0e-3)
#define SNR_ERROR_THRESH_KB	((float32_t)40)
#endif
34
test_arm_max_f16(const uint16_t * input1,int ref_index,size_t length)35 static void test_arm_max_f16(
36 const uint16_t *input1, int ref_index, size_t length)
37 {
38 float16_t val;
39 uint32_t index;
40
41 /* Run test function */
42 arm_max_f16((float16_t *)input1, length, &val, &index);
43
44 /* Validate output */
45 zassert_equal(val, ((float16_t *)ref_max_val)[ref_index],
46 ASSERT_MSG_INCORRECT_COMP_RESULT);
47
48 zassert_equal(index, ref_max_idx[ref_index],
49 ASSERT_MSG_INCORRECT_COMP_RESULT);
50 }
51
52 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 7, in_com1, 0, 7);
53 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 16, in_com1, 1, 16);
54 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 23, in_com1, 2, 23);
55
test_arm_max_no_idx_f16(const uint16_t * input1,int ref_index,size_t length)56 static void test_arm_max_no_idx_f16(
57 const uint16_t *input1, int ref_index, size_t length)
58 {
59 float16_t val;
60
61 /* Run test function */
62 arm_max_no_idx_f16((float16_t *)input1, length, &val);
63
64 /* Validate output */
65 zassert_equal(val, ((float16_t *)ref_max_val)[ref_index],
66 ASSERT_MSG_INCORRECT_COMP_RESULT);
67 }
68
69 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 7, in_com1, 0, 7);
70 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 16, in_com1, 1, 16);
71 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 23, in_com1, 2, 23);
72
test_arm_min_f16(const uint16_t * input1,int ref_index,size_t length)73 static void test_arm_min_f16(
74 const uint16_t *input1, int ref_index, size_t length)
75 {
76 float16_t val;
77 uint32_t index;
78
79 /* Run test function */
80 arm_min_f16((float16_t *)input1, length, &val, &index);
81
82 /* Validate output */
83 zassert_equal(val, ((float16_t *)ref_min_val)[ref_index],
84 ASSERT_MSG_INCORRECT_COMP_RESULT);
85
86 zassert_equal(index, ref_min_idx[ref_index],
87 ASSERT_MSG_INCORRECT_COMP_RESULT);
88 }
89
90 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 7, in_com1, 0, 7);
91 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 16, in_com1, 1, 16);
92 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 23, in_com1, 2, 23);
93
test_arm_absmax_f16(const uint16_t * input1,int ref_index,size_t length)94 static void test_arm_absmax_f16(
95 const uint16_t *input1, int ref_index, size_t length)
96 {
97 float16_t val;
98 uint32_t index;
99
100 /* Run test function */
101 arm_absmax_f16((float16_t *)input1, length, &val, &index);
102
103 /* Validate output */
104 zassert_equal(val, ((float16_t *)ref_absmax_val)[ref_index],
105 ASSERT_MSG_INCORRECT_COMP_RESULT);
106
107 zassert_equal(index, ref_absmax_idx[ref_index],
108 ASSERT_MSG_INCORRECT_COMP_RESULT);
109 }
110
111 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 7, in_absminmax, 0, 7);
112 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 16, in_absminmax, 1, 16);
113 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 23, in_absminmax, 2, 23);
114
test_arm_absmin_f16(const uint16_t * input1,int ref_index,size_t length)115 static void test_arm_absmin_f16(
116 const uint16_t *input1, int ref_index, size_t length)
117 {
118 float16_t val;
119 uint32_t index;
120
121 /* Run test function */
122 arm_absmin_f16((float16_t *)input1, length, &val, &index);
123
124 /* Validate output */
125 zassert_equal(val, ((float16_t *)ref_absmin_val)[ref_index],
126 ASSERT_MSG_INCORRECT_COMP_RESULT);
127
128 zassert_equal(index, ref_absmin_idx[ref_index],
129 ASSERT_MSG_INCORRECT_COMP_RESULT);
130 }
131
132 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 7, in_absminmax, 0, 7);
133 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 16, in_absminmax, 1, 16);
134 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 23, in_absminmax, 2, 23);
135
test_arm_mean_f16(const uint16_t * input1,int ref_index,size_t length)136 static void test_arm_mean_f16(
137 const uint16_t *input1, int ref_index, size_t length)
138 {
139 float16_t ref[1];
140 float16_t *output;
141
142 /* Load reference */
143 ref[0] = ((float16_t *)ref_mean)[ref_index];
144
145 /* Allocate output buffer */
146 output = malloc(1 * sizeof(float16_t));
147 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
148
149 /* Run test function */
150 arm_mean_f16((float16_t *)input1, length, &output[0]);
151
152 /* Validate output */
153 zassert_true(
154 test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
155 ASSERT_MSG_SNR_LIMIT_EXCEED);
156
157 zassert_true(
158 test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
159 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
160
161 /* Free output buffer */
162 free(output);
163 }
164
165 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 7, in_com2, 0, 7);
166 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 16, in_com2, 1, 16);
167 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 23, in_com2, 2, 23);
168
test_arm_power_f16(const uint16_t * input1,int ref_index,size_t length)169 static void test_arm_power_f16(
170 const uint16_t *input1, int ref_index, size_t length)
171 {
172 float16_t ref[1];
173 float16_t *output;
174
175 /* Load reference */
176 ref[0] = ((float16_t *)ref_power)[ref_index];
177
178 /* Allocate output buffer */
179 output = malloc(1 * sizeof(float16_t));
180 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
181
182 /* Run test function */
183 arm_power_f16((float16_t *)input1, length, &output[0]);
184
185 /* Validate output */
186 zassert_true(
187 test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
188 ASSERT_MSG_SNR_LIMIT_EXCEED);
189
190 zassert_true(
191 test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
192 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
193
194 /* Free output buffer */
195 free(output);
196 }
197
198 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 7, in_com1, 0, 7);
199 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 16, in_com1, 1, 16);
200 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 23, in_com1, 2, 23);
201
test_arm_rms_f16(const uint16_t * input1,int ref_index,size_t length)202 static void test_arm_rms_f16(
203 const uint16_t *input1, int ref_index, size_t length)
204 {
205 float16_t ref[1];
206 float16_t *output;
207
208 /* Load reference */
209 ref[0] = ((float16_t *)ref_rms)[ref_index];
210
211 /* Allocate output buffer */
212 output = malloc(1 * sizeof(float16_t));
213 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
214
215 /* Run test function */
216 arm_rms_f16((float16_t *)input1, length, &output[0]);
217
218 /* Validate output */
219 zassert_true(
220 test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
221 ASSERT_MSG_SNR_LIMIT_EXCEED);
222
223 zassert_true(
224 test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
225 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
226
227 /* Free output buffer */
228 free(output);
229 }
230
231 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 7, in_com1, 0, 7);
232 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 16, in_com1, 1, 16);
233 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 23, in_com1, 2, 23);
234
test_arm_std_f16(const uint16_t * input1,int ref_index,size_t length)235 static void test_arm_std_f16(
236 const uint16_t *input1, int ref_index, size_t length)
237 {
238 float16_t ref[1];
239 float16_t *output;
240
241 /* Load reference */
242 ref[0] = ((float16_t *)ref_std)[ref_index];
243
244 /* Allocate output buffer */
245 output = malloc(1 * sizeof(float16_t));
246 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
247
248 /* Run test function */
249 arm_std_f16((float16_t *)input1, length, &output[0]);
250
251 /* Validate output */
252 zassert_true(
253 test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
254 ASSERT_MSG_SNR_LIMIT_EXCEED);
255
256 zassert_true(
257 test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
258 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
259
260 /* Free output buffer */
261 free(output);
262 }
263
264 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 7, in_com1, 0, 7);
265 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 16, in_com1, 1, 16);
266 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 23, in_com1, 2, 23);
267
test_arm_var_f16(const uint16_t * input1,int ref_index,size_t length)268 static void test_arm_var_f16(
269 const uint16_t *input1, int ref_index, size_t length)
270 {
271 float16_t ref[1];
272 float16_t *output;
273
274 /* Load reference */
275 ref[0] = ((float16_t *)ref_var)[ref_index];
276
277 /* Allocate output buffer */
278 output = malloc(1 * sizeof(float16_t));
279 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
280
281 /* Run test function */
282 arm_var_f16((float16_t *)input1, length, &output[0]);
283
284 /* Validate output */
285 zassert_true(
286 test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
287 ASSERT_MSG_SNR_LIMIT_EXCEED);
288
289 zassert_true(
290 test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
291 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
292
293 /* Free output buffer */
294 free(output);
295 }
296
297 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 7, in_com1, 0, 7);
298 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 16, in_com1, 1, 16);
299 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 23, in_com1, 2, 23);
300
ZTEST(statistics_f16,test_arm_entropy_f16)301 ZTEST(statistics_f16, test_arm_entropy_f16)
302 {
303 size_t index;
304 size_t length = in_entropy_dim[0];
305 const float16_t *ref = (float16_t *)ref_entropy;
306 const float16_t *input = (float16_t *)in_entropy;
307 float16_t *output;
308
309 __ASSERT_NO_MSG(ARRAY_SIZE(in_entropy_dim) > length);
310 __ASSERT_NO_MSG(ARRAY_SIZE(ref_entropy) >= length);
311
312 /* Allocate output buffer */
313 output = malloc(length * sizeof(float16_t));
314 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
315
316 /* Run test function */
317 for (index = 0; index < length; index++) {
318 output[index] =
319 arm_entropy_f16(input, in_entropy_dim[index + 1]);
320 input += in_entropy_dim[index + 1];
321 }
322
323 /* Validate output */
324 zassert_true(
325 test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
326 ASSERT_MSG_SNR_LIMIT_EXCEED);
327
328 zassert_true(
329 test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
330 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
331
332 /* Free output buffer */
333 free(output);
334 }
335
ZTEST(statistics_f16,test_arm_logsumexp_f16)336 ZTEST(statistics_f16, test_arm_logsumexp_f16)
337 {
338 size_t index;
339 size_t length = in_logsumexp_dim[0];
340 const float16_t *ref = (float16_t *)ref_logsumexp;
341 const float16_t *input = (float16_t *)in_logsumexp;
342 float16_t *output;
343
344 __ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dim) > length);
345 __ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp) >= length);
346
347 /* Allocate output buffer */
348 output = malloc(length * sizeof(float16_t));
349 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
350
351 /* Run test function */
352 for (index = 0; index < length; index++) {
353 output[index] =
354 arm_logsumexp_f16(input, in_logsumexp_dim[index + 1]);
355 input += in_logsumexp_dim[index + 1];
356 }
357
358 /* Validate output */
359 zassert_true(
360 test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
361 ASSERT_MSG_SNR_LIMIT_EXCEED);
362
363 zassert_true(
364 test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
365 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
366
367 /* Free output buffer */
368 free(output);
369 }
370
ZTEST(statistics_f16,test_arm_kullback_leibler_f16)371 ZTEST(statistics_f16, test_arm_kullback_leibler_f16)
372 {
373 size_t index;
374 size_t length = in_kl_dim[0];
375 const float16_t *ref = (float16_t *)ref_kl;
376 const float16_t *input1 = (float16_t *)in_kl1;
377 const float16_t *input2 = (float16_t *)in_kl2;
378 float16_t *output;
379
380 __ASSERT_NO_MSG(ARRAY_SIZE(in_kl_dim) > length);
381 __ASSERT_NO_MSG(ARRAY_SIZE(ref_kl) >= length);
382
383 /* Allocate output buffer */
384 output = malloc(length * sizeof(float16_t));
385 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
386
387 /* Run test function */
388 for (index = 0; index < length; index++) {
389 output[index] =
390 arm_kullback_leibler_f16(
391 input1, input2, in_kl_dim[index + 1]);
392
393 input1 += in_kl_dim[index + 1];
394 input2 += in_kl_dim[index + 1];
395 }
396
397 /* Validate output */
398 zassert_true(
399 test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH_KB),
400 ASSERT_MSG_SNR_LIMIT_EXCEED);
401
402 zassert_true(
403 test_close_error_f16(length, ref, output,
404 ABS_ERROR_THRESH_KB, REL_ERROR_THRESH_KB),
405 ASSERT_MSG_ERROR_LIMIT_EXCEED);
406
407 /* Free output buffer */
408 free(output);
409 }
410
ZTEST(statistics_f16,test_arm_logsumexp_dot_prod_f16)411 ZTEST(statistics_f16, test_arm_logsumexp_dot_prod_f16)
412 {
413 size_t index;
414 size_t length = in_logsumexp_dp_dim[0];
415 const float16_t *ref = (float16_t *)ref_logsumexp_dp;
416 const float16_t *input1 = (float16_t *)in_logsumexp_dp1;
417 const float16_t *input2 = (float16_t *)in_logsumexp_dp2;
418 float16_t *output;
419 float16_t *tmp;
420
421 __ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dp_dim) > length);
422 __ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp_dp) >= length);
423
424 /* Allocate buffers */
425 output = malloc(length * sizeof(float16_t));
426 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
427
428 tmp = malloc(12 * sizeof(float16_t));
429 zassert_not_null(tmp, ASSERT_MSG_BUFFER_ALLOC_FAILED);
430
431 /* Run test function */
432 for (index = 0; index < length; index++) {
433 output[index] =
434 arm_logsumexp_dot_prod_f16(
435 input1, input2,
436 in_logsumexp_dp_dim[index + 1], tmp);
437
438 input1 += in_logsumexp_dp_dim[index + 1];
439 input2 += in_logsumexp_dp_dim[index + 1];
440 }
441
442 /* Validate output */
443 zassert_true(
444 test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
445 ASSERT_MSG_SNR_LIMIT_EXCEED);
446
447 zassert_true(
448 test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
449 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
450
451 /* Free buffers */
452 free(output);
453 free(tmp);
454 }
455
456 ZTEST_SUITE(statistics_f16, NULL, NULL, NULL, NULL, NULL);
457