1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math_f16.h>
#include "../../common/test_common.h"

#include "f16.pat"
15 
/*
 * Error thresholds shared by the tests in this file.
 *
 * The *_KB thresholds are used only by the Kullback-Leibler test.
 *
 * NOTE: The MVE vector versions of the statistics functions are slightly
 *       less accurate than the scalar versions, so the relative-error and
 *       Kullback-Leibler SNR thresholds are relaxed when
 *       CONFIG_ARMV8_1_M_MVEF is set.
 */
#define SNR_ERROR_THRESH	((float32_t)48)

#ifdef CONFIG_ARMV8_1_M_MVEF
#define REL_ERROR_THRESH	(10.0e-3)
#define SNR_ERROR_THRESH_KB	((float32_t)39)
#else
#define REL_ERROR_THRESH	(6.0e-3)
#define SNR_ERROR_THRESH_KB	((float32_t)40)
#endif

#define REL_ERROR_THRESH_KB	(5.0e-3)
#define ABS_ERROR_THRESH_KB	(5.0e-3)
34 
test_arm_max_f16(const uint16_t * input1,int ref_index,size_t length)35 static void test_arm_max_f16(
36 	const uint16_t *input1, int ref_index, size_t length)
37 {
38 	float16_t val;
39 	uint32_t index;
40 
41 	/* Run test function */
42 	arm_max_f16((float16_t *)input1, length, &val, &index);
43 
44 	/* Validate output */
45 	zassert_equal(val, ((float16_t *)ref_max_val)[ref_index],
46 		ASSERT_MSG_INCORRECT_COMP_RESULT);
47 
48 	zassert_equal(index, ref_max_idx[ref_index],
49 		ASSERT_MSG_INCORRECT_COMP_RESULT);
50 }
51 
52 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 7, in_com1, 0, 7);
53 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 16, in_com1, 1, 16);
54 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_f16, 23, in_com1, 2, 23);
55 
test_arm_max_no_idx_f16(const uint16_t * input1,int ref_index,size_t length)56 static void test_arm_max_no_idx_f16(
57 	const uint16_t *input1, int ref_index, size_t length)
58 {
59 	float16_t val;
60 
61 	/* Run test function */
62 	arm_max_no_idx_f16((float16_t *)input1, length, &val);
63 
64 	/* Validate output */
65 	zassert_equal(val, ((float16_t *)ref_max_val)[ref_index],
66 		ASSERT_MSG_INCORRECT_COMP_RESULT);
67 }
68 
69 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 7, in_com1, 0, 7);
70 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 16, in_com1, 1, 16);
71 DEFINE_TEST_VARIANT3(statistics_f16, arm_max_no_idx_f16, 23, in_com1, 2, 23);
72 
test_arm_min_f16(const uint16_t * input1,int ref_index,size_t length)73 static void test_arm_min_f16(
74 	const uint16_t *input1, int ref_index, size_t length)
75 {
76 	float16_t val;
77 	uint32_t index;
78 
79 	/* Run test function */
80 	arm_min_f16((float16_t *)input1, length, &val, &index);
81 
82 	/* Validate output */
83 	zassert_equal(val, ((float16_t *)ref_min_val)[ref_index],
84 		ASSERT_MSG_INCORRECT_COMP_RESULT);
85 
86 	zassert_equal(index, ref_min_idx[ref_index],
87 		ASSERT_MSG_INCORRECT_COMP_RESULT);
88 }
89 
90 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 7, in_com1, 0, 7);
91 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 16, in_com1, 1, 16);
92 DEFINE_TEST_VARIANT3(statistics_f16, arm_min_f16, 23, in_com1, 2, 23);
93 
test_arm_absmax_f16(const uint16_t * input1,int ref_index,size_t length)94 static void test_arm_absmax_f16(
95 	const uint16_t *input1, int ref_index, size_t length)
96 {
97 	float16_t val;
98 	uint32_t index;
99 
100 	/* Run test function */
101 	arm_absmax_f16((float16_t *)input1, length, &val, &index);
102 
103 	/* Validate output */
104 	zassert_equal(val, ((float16_t *)ref_absmax_val)[ref_index],
105 		ASSERT_MSG_INCORRECT_COMP_RESULT);
106 
107 	zassert_equal(index, ref_absmax_idx[ref_index],
108 		ASSERT_MSG_INCORRECT_COMP_RESULT);
109 }
110 
111 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 7, in_absminmax, 0, 7);
112 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 16, in_absminmax, 1, 16);
113 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmax_f16, 23, in_absminmax, 2, 23);
114 
test_arm_absmin_f16(const uint16_t * input1,int ref_index,size_t length)115 static void test_arm_absmin_f16(
116 	const uint16_t *input1, int ref_index, size_t length)
117 {
118 	float16_t val;
119 	uint32_t index;
120 
121 	/* Run test function */
122 	arm_absmin_f16((float16_t *)input1, length, &val, &index);
123 
124 	/* Validate output */
125 	zassert_equal(val, ((float16_t *)ref_absmin_val)[ref_index],
126 		ASSERT_MSG_INCORRECT_COMP_RESULT);
127 
128 	zassert_equal(index, ref_absmin_idx[ref_index],
129 		ASSERT_MSG_INCORRECT_COMP_RESULT);
130 }
131 
132 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 7, in_absminmax, 0, 7);
133 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 16, in_absminmax, 1, 16);
134 DEFINE_TEST_VARIANT3(statistics_f16, arm_absmin_f16, 23, in_absminmax, 2, 23);
135 
test_arm_mean_f16(const uint16_t * input1,int ref_index,size_t length)136 static void test_arm_mean_f16(
137 	const uint16_t *input1, int ref_index, size_t length)
138 {
139 	float16_t ref[1];
140 	float16_t *output;
141 
142 	/* Load reference */
143 	ref[0] = ((float16_t *)ref_mean)[ref_index];
144 
145 	/* Allocate output buffer */
146 	output = malloc(1 * sizeof(float16_t));
147 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
148 
149 	/* Run test function */
150 	arm_mean_f16((float16_t *)input1, length, &output[0]);
151 
152 	/* Validate output */
153 	zassert_true(
154 		test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
155 		ASSERT_MSG_SNR_LIMIT_EXCEED);
156 
157 	zassert_true(
158 		test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
159 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
160 
161 	/* Free output buffer */
162 	free(output);
163 }
164 
165 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 7, in_com2, 0, 7);
166 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 16, in_com2, 1, 16);
167 DEFINE_TEST_VARIANT3(statistics_f16, arm_mean_f16, 23, in_com2, 2, 23);
168 
test_arm_power_f16(const uint16_t * input1,int ref_index,size_t length)169 static void test_arm_power_f16(
170 	const uint16_t *input1, int ref_index, size_t length)
171 {
172 	float16_t ref[1];
173 	float16_t *output;
174 
175 	/* Load reference */
176 	ref[0] = ((float16_t *)ref_power)[ref_index];
177 
178 	/* Allocate output buffer */
179 	output = malloc(1 * sizeof(float16_t));
180 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
181 
182 	/* Run test function */
183 	arm_power_f16((float16_t *)input1, length, &output[0]);
184 
185 	/* Validate output */
186 	zassert_true(
187 		test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
188 		ASSERT_MSG_SNR_LIMIT_EXCEED);
189 
190 	zassert_true(
191 		test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
192 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
193 
194 	/* Free output buffer */
195 	free(output);
196 }
197 
198 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 7, in_com1, 0, 7);
199 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 16, in_com1, 1, 16);
200 DEFINE_TEST_VARIANT3(statistics_f16, arm_power_f16, 23, in_com1, 2, 23);
201 
test_arm_rms_f16(const uint16_t * input1,int ref_index,size_t length)202 static void test_arm_rms_f16(
203 	const uint16_t *input1, int ref_index, size_t length)
204 {
205 	float16_t ref[1];
206 	float16_t *output;
207 
208 	/* Load reference */
209 	ref[0] = ((float16_t *)ref_rms)[ref_index];
210 
211 	/* Allocate output buffer */
212 	output = malloc(1 * sizeof(float16_t));
213 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
214 
215 	/* Run test function */
216 	arm_rms_f16((float16_t *)input1, length, &output[0]);
217 
218 	/* Validate output */
219 	zassert_true(
220 		test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
221 		ASSERT_MSG_SNR_LIMIT_EXCEED);
222 
223 	zassert_true(
224 		test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
225 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
226 
227 	/* Free output buffer */
228 	free(output);
229 }
230 
231 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 7, in_com1, 0, 7);
232 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 16, in_com1, 1, 16);
233 DEFINE_TEST_VARIANT3(statistics_f16, arm_rms_f16, 23, in_com1, 2, 23);
234 
test_arm_std_f16(const uint16_t * input1,int ref_index,size_t length)235 static void test_arm_std_f16(
236 	const uint16_t *input1, int ref_index, size_t length)
237 {
238 	float16_t ref[1];
239 	float16_t *output;
240 
241 	/* Load reference */
242 	ref[0] = ((float16_t *)ref_std)[ref_index];
243 
244 	/* Allocate output buffer */
245 	output = malloc(1 * sizeof(float16_t));
246 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
247 
248 	/* Run test function */
249 	arm_std_f16((float16_t *)input1, length, &output[0]);
250 
251 	/* Validate output */
252 	zassert_true(
253 		test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
254 		ASSERT_MSG_SNR_LIMIT_EXCEED);
255 
256 	zassert_true(
257 		test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
258 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
259 
260 	/* Free output buffer */
261 	free(output);
262 }
263 
264 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 7, in_com1, 0, 7);
265 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 16, in_com1, 1, 16);
266 DEFINE_TEST_VARIANT3(statistics_f16, arm_std_f16, 23, in_com1, 2, 23);
267 
test_arm_var_f16(const uint16_t * input1,int ref_index,size_t length)268 static void test_arm_var_f16(
269 	const uint16_t *input1, int ref_index, size_t length)
270 {
271 	float16_t ref[1];
272 	float16_t *output;
273 
274 	/* Load reference */
275 	ref[0] = ((float16_t *)ref_var)[ref_index];
276 
277 	/* Allocate output buffer */
278 	output = malloc(1 * sizeof(float16_t));
279 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
280 
281 	/* Run test function */
282 	arm_var_f16((float16_t *)input1, length, &output[0]);
283 
284 	/* Validate output */
285 	zassert_true(
286 		test_snr_error_f16(1, output, ref, SNR_ERROR_THRESH),
287 		ASSERT_MSG_SNR_LIMIT_EXCEED);
288 
289 	zassert_true(
290 		test_rel_error_f16(1, output, ref, REL_ERROR_THRESH),
291 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
292 
293 	/* Free output buffer */
294 	free(output);
295 }
296 
297 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 7, in_com1, 0, 7);
298 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 16, in_com1, 1, 16);
299 DEFINE_TEST_VARIANT3(statistics_f16, arm_var_f16, 23, in_com1, 2, 23);
300 
ZTEST(statistics_f16,test_arm_entropy_f16)301 ZTEST(statistics_f16, test_arm_entropy_f16)
302 {
303 	size_t index;
304 	size_t length = in_entropy_dim[0];
305 	const float16_t *ref = (float16_t *)ref_entropy;
306 	const float16_t *input = (float16_t *)in_entropy;
307 	float16_t *output;
308 
309 	__ASSERT_NO_MSG(ARRAY_SIZE(in_entropy_dim) > length);
310 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_entropy) >= length);
311 
312 	/* Allocate output buffer */
313 	output = malloc(length * sizeof(float16_t));
314 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
315 
316 	/* Run test function */
317 	for (index = 0; index < length; index++) {
318 		output[index] =
319 			arm_entropy_f16(input, in_entropy_dim[index + 1]);
320 		input += in_entropy_dim[index + 1];
321 	}
322 
323 	/* Validate output */
324 	zassert_true(
325 		test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
326 		ASSERT_MSG_SNR_LIMIT_EXCEED);
327 
328 	zassert_true(
329 		test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
330 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
331 
332 	/* Free output buffer */
333 	free(output);
334 }
335 
ZTEST(statistics_f16,test_arm_logsumexp_f16)336 ZTEST(statistics_f16, test_arm_logsumexp_f16)
337 {
338 	size_t index;
339 	size_t length = in_logsumexp_dim[0];
340 	const float16_t *ref = (float16_t *)ref_logsumexp;
341 	const float16_t *input = (float16_t *)in_logsumexp;
342 	float16_t *output;
343 
344 	__ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dim) > length);
345 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp) >= length);
346 
347 	/* Allocate output buffer */
348 	output = malloc(length * sizeof(float16_t));
349 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
350 
351 	/* Run test function */
352 	for (index = 0; index < length; index++) {
353 		output[index] =
354 			arm_logsumexp_f16(input, in_logsumexp_dim[index + 1]);
355 		input += in_logsumexp_dim[index + 1];
356 	}
357 
358 	/* Validate output */
359 	zassert_true(
360 		test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
361 		ASSERT_MSG_SNR_LIMIT_EXCEED);
362 
363 	zassert_true(
364 		test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
365 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
366 
367 	/* Free output buffer */
368 	free(output);
369 }
370 
ZTEST(statistics_f16,test_arm_kullback_leibler_f16)371 ZTEST(statistics_f16, test_arm_kullback_leibler_f16)
372 {
373 	size_t index;
374 	size_t length = in_kl_dim[0];
375 	const float16_t *ref = (float16_t *)ref_kl;
376 	const float16_t *input1 = (float16_t *)in_kl1;
377 	const float16_t *input2 = (float16_t *)in_kl2;
378 	float16_t *output;
379 
380 	__ASSERT_NO_MSG(ARRAY_SIZE(in_kl_dim) > length);
381 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_kl) >= length);
382 
383 	/* Allocate output buffer */
384 	output = malloc(length * sizeof(float16_t));
385 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
386 
387 	/* Run test function */
388 	for (index = 0; index < length; index++) {
389 		output[index] =
390 			arm_kullback_leibler_f16(
391 				input1, input2, in_kl_dim[index + 1]);
392 
393 		input1 += in_kl_dim[index + 1];
394 		input2 += in_kl_dim[index + 1];
395 	}
396 
397 	/* Validate output */
398 	zassert_true(
399 		test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH_KB),
400 		ASSERT_MSG_SNR_LIMIT_EXCEED);
401 
402 	zassert_true(
403 		test_close_error_f16(length, ref, output,
404 			ABS_ERROR_THRESH_KB, REL_ERROR_THRESH_KB),
405 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
406 
407 	/* Free output buffer */
408 	free(output);
409 }
410 
ZTEST(statistics_f16,test_arm_logsumexp_dot_prod_f16)411 ZTEST(statistics_f16, test_arm_logsumexp_dot_prod_f16)
412 {
413 	size_t index;
414 	size_t length = in_logsumexp_dp_dim[0];
415 	const float16_t *ref = (float16_t *)ref_logsumexp_dp;
416 	const float16_t *input1 = (float16_t *)in_logsumexp_dp1;
417 	const float16_t *input2 = (float16_t *)in_logsumexp_dp2;
418 	float16_t *output;
419 	float16_t *tmp;
420 
421 	__ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dp_dim) > length);
422 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp_dp) >= length);
423 
424 	/* Allocate buffers */
425 	output = malloc(length * sizeof(float16_t));
426 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
427 
428 	tmp = malloc(12 * sizeof(float16_t));
429 	zassert_not_null(tmp, ASSERT_MSG_BUFFER_ALLOC_FAILED);
430 
431 	/* Run test function */
432 	for (index = 0; index < length; index++) {
433 		output[index] =
434 			arm_logsumexp_dot_prod_f16(
435 				input1, input2,
436 				in_logsumexp_dp_dim[index + 1], tmp);
437 
438 		input1 += in_logsumexp_dp_dim[index + 1];
439 		input2 += in_logsumexp_dp_dim[index + 1];
440 	}
441 
442 	/* Validate output */
443 	zassert_true(
444 		test_snr_error_f16(length, ref, output, SNR_ERROR_THRESH),
445 		ASSERT_MSG_SNR_LIMIT_EXCEED);
446 
447 	zassert_true(
448 		test_near_equal_f16(length, ref, output, REL_ERROR_THRESH),
449 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
450 
451 	/* Free buffers */
452 	free(output);
453 	free(tmp);
454 }
455 
/*
 * Declare the statistics_f16 suite named by the ZTEST() and
 * DEFINE_TEST_VARIANT3() invocations above. All five optional arguments are
 * NULL (per the ztest documentation these are the predicate and the
 * setup/before/after/teardown hooks).
 */
ZTEST_SUITE(statistics_f16, NULL, NULL, NULL, NULL, NULL);
457