1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"

#include "f32.pat"
15 
/*
 * Tolerances shared by the floating-point result checks in this file.
 * SNR_ERROR_THRESH is passed to test_snr_error_f32() (presumably a minimum
 * SNR in dB -- confirm against test_common.h); REL_ERROR_THRESH is passed
 * to test_rel_error_f32() and test_near_equal_f32().
 */
#define SNR_ERROR_THRESH	((float32_t)120)
#define REL_ERROR_THRESH	(1.0e-5)
18 
test_arm_max_f32(const uint32_t * input1,int ref_index,size_t length)19 static void test_arm_max_f32(
20 	const uint32_t *input1, int ref_index, size_t length)
21 {
22 	float32_t val;
23 	uint32_t index;
24 
25 	/* Run test function */
26 	arm_max_f32((float32_t *)input1, length, &val, &index);
27 
28 	/* Validate output */
29 	zassert_equal(val, ((float32_t *)ref_max_val)[ref_index],
30 		ASSERT_MSG_INCORRECT_COMP_RESULT);
31 
32 	zassert_equal(index, ref_max_idx[ref_index],
33 		ASSERT_MSG_INCORRECT_COMP_RESULT);
34 }
35 
36 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_f32, 3, in_com1, 0, 3);
37 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_f32, 8, in_com1, 1, 8);
38 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_f32, 11, in_com1, 2, 11);
39 
test_arm_max_no_idx_f32(const uint32_t * input1,int ref_index,size_t length)40 static void test_arm_max_no_idx_f32(
41 	const uint32_t *input1, int ref_index, size_t length)
42 {
43 	float32_t val;
44 
45 	/* Run test function */
46 	arm_max_no_idx_f32((float32_t *)input1, length, &val);
47 
48 	/* Validate output */
49 	zassert_equal(val, ((float32_t *)ref_max_val)[ref_index],
50 		ASSERT_MSG_INCORRECT_COMP_RESULT);
51 }
52 
53 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_no_idx_f32, 3, in_com1, 0, 3);
54 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_no_idx_f32, 8, in_com1, 1, 8);
55 DEFINE_TEST_VARIANT3(statistics_f32, arm_max_no_idx_f32, 11, in_com1, 2, 11);
56 
test_arm_min_f32(const uint32_t * input1,int ref_index,size_t length)57 static void test_arm_min_f32(
58 	const uint32_t *input1, int ref_index, size_t length)
59 {
60 	float32_t val;
61 	uint32_t index;
62 
63 	/* Run test function */
64 	arm_min_f32((float32_t *)input1, length, &val, &index);
65 
66 	/* Validate output */
67 	zassert_equal(val, ((float32_t *)ref_min_val)[ref_index],
68 		ASSERT_MSG_INCORRECT_COMP_RESULT);
69 
70 	zassert_equal(index, ref_min_idx[ref_index],
71 		ASSERT_MSG_INCORRECT_COMP_RESULT);
72 }
73 
74 DEFINE_TEST_VARIANT3(statistics_f32, arm_min_f32, 3, in_com1, 0, 3);
75 DEFINE_TEST_VARIANT3(statistics_f32, arm_min_f32, 8, in_com1, 1, 8);
76 DEFINE_TEST_VARIANT3(statistics_f32, arm_min_f32, 11, in_com1, 2, 11);
77 
test_arm_absmax_f32(const uint32_t * input1,int ref_index,size_t length)78 static void test_arm_absmax_f32(
79 	const uint32_t *input1, int ref_index, size_t length)
80 {
81 	float32_t val;
82 	uint32_t index;
83 
84 	/* Run test function */
85 	arm_absmax_f32((float32_t *)input1, length, &val, &index);
86 
87 	/* Validate output */
88 	zassert_equal(val, ((float32_t *)ref_absmax_val)[ref_index],
89 		ASSERT_MSG_INCORRECT_COMP_RESULT);
90 
91 	zassert_equal(index, ref_absmax_idx[ref_index],
92 		ASSERT_MSG_INCORRECT_COMP_RESULT);
93 }
94 
95 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmax_f32, 3, in_absminmax, 0, 3);
96 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmax_f32, 8, in_absminmax, 1, 8);
97 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmax_f32, 11, in_absminmax, 2, 11);
98 
test_arm_absmin_f32(const uint32_t * input1,int ref_index,size_t length)99 static void test_arm_absmin_f32(
100 	const uint32_t *input1, int ref_index, size_t length)
101 {
102 	float32_t val;
103 	uint32_t index;
104 
105 	/* Run test function */
106 	arm_absmin_f32((float32_t *)input1, length, &val, &index);
107 
108 	/* Validate output */
109 	zassert_equal(val, ((float32_t *)ref_absmin_val)[ref_index],
110 		ASSERT_MSG_INCORRECT_COMP_RESULT);
111 
112 	zassert_equal(index, ref_absmin_idx[ref_index],
113 		ASSERT_MSG_INCORRECT_COMP_RESULT);
114 }
115 
116 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmin_f32, 3, in_absminmax, 0, 3);
117 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmin_f32, 8, in_absminmax, 1, 8);
118 DEFINE_TEST_VARIANT3(statistics_f32, arm_absmin_f32, 11, in_absminmax, 2, 11);
119 
test_arm_mean_f32(const uint32_t * input1,int ref_index,size_t length)120 static void test_arm_mean_f32(
121 	const uint32_t *input1, int ref_index, size_t length)
122 {
123 	float32_t ref[1];
124 	float32_t *output;
125 
126 	/* Load reference */
127 	ref[0] = ((float32_t *)ref_mean)[ref_index];
128 
129 	/* Allocate output buffer */
130 	output = malloc(1 * sizeof(float32_t));
131 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
132 
133 	/* Run test function */
134 	arm_mean_f32((float32_t *)input1, length, &output[0]);
135 
136 	/* Validate output */
137 	zassert_true(
138 		test_snr_error_f32(1, output, ref, SNR_ERROR_THRESH),
139 		ASSERT_MSG_SNR_LIMIT_EXCEED);
140 
141 	zassert_true(
142 		test_rel_error_f32(1, output, ref, REL_ERROR_THRESH),
143 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
144 
145 	/* Free output buffer */
146 	free(output);
147 }
148 
149 DEFINE_TEST_VARIANT3(statistics_f32, arm_mean_f32, 3, in_com2, 0, 3);
150 DEFINE_TEST_VARIANT3(statistics_f32, arm_mean_f32, 8, in_com2, 1, 8);
151 DEFINE_TEST_VARIANT3(statistics_f32, arm_mean_f32, 11, in_com2, 2, 11);
152 
test_arm_power_f32(const uint32_t * input1,int ref_index,size_t length)153 static void test_arm_power_f32(
154 	const uint32_t *input1, int ref_index, size_t length)
155 {
156 	float32_t ref[1];
157 	float32_t *output;
158 
159 	/* Load reference */
160 	ref[0] = ((float32_t *)ref_power)[ref_index];
161 
162 	/* Allocate output buffer */
163 	output = malloc(1 * sizeof(float32_t));
164 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
165 
166 	/* Run test function */
167 	arm_power_f32((float32_t *)input1, length, &output[0]);
168 
169 	/* Validate output */
170 	zassert_true(
171 		test_snr_error_f32(1, output, ref, SNR_ERROR_THRESH),
172 		ASSERT_MSG_SNR_LIMIT_EXCEED);
173 
174 	zassert_true(
175 		test_rel_error_f32(1, output, ref, REL_ERROR_THRESH),
176 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
177 
178 	/* Free output buffer */
179 	free(output);
180 }
181 
182 DEFINE_TEST_VARIANT3(statistics_f32, arm_power_f32, 3, in_com1, 0, 3);
183 DEFINE_TEST_VARIANT3(statistics_f32, arm_power_f32, 8, in_com1, 1, 8);
184 DEFINE_TEST_VARIANT3(statistics_f32, arm_power_f32, 11, in_com1, 2, 11);
185 
test_arm_rms_f32(const uint32_t * input1,int ref_index,size_t length)186 static void test_arm_rms_f32(
187 	const uint32_t *input1, int ref_index, size_t length)
188 {
189 	float32_t ref[1];
190 	float32_t *output;
191 
192 	/* Load reference */
193 	ref[0] = ((float32_t *)ref_rms)[ref_index];
194 
195 	/* Allocate output buffer */
196 	output = malloc(1 * sizeof(float32_t));
197 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
198 
199 	/* Run test function */
200 	arm_rms_f32((float32_t *)input1, length, &output[0]);
201 
202 	/* Validate output */
203 	zassert_true(
204 		test_snr_error_f32(1, output, ref, SNR_ERROR_THRESH),
205 		ASSERT_MSG_SNR_LIMIT_EXCEED);
206 
207 	zassert_true(
208 		test_rel_error_f32(1, output, ref, REL_ERROR_THRESH),
209 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
210 
211 	/* Free output buffer */
212 	free(output);
213 }
214 
215 DEFINE_TEST_VARIANT3(statistics_f32, arm_rms_f32, 3, in_com1, 0, 3);
216 DEFINE_TEST_VARIANT3(statistics_f32, arm_rms_f32, 8, in_com1, 1, 8);
217 DEFINE_TEST_VARIANT3(statistics_f32, arm_rms_f32, 11, in_com1, 2, 11);
218 
test_arm_std_f32(const uint32_t * input1,int ref_index,size_t length)219 static void test_arm_std_f32(
220 	const uint32_t *input1, int ref_index, size_t length)
221 {
222 	float32_t ref[1];
223 	float32_t *output;
224 
225 	/* Load reference */
226 	ref[0] = ((float32_t *)ref_std)[ref_index];
227 
228 	/* Allocate output buffer */
229 	output = malloc(1 * sizeof(float32_t));
230 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
231 
232 	/* Run test function */
233 	arm_std_f32((float32_t *)input1, length, &output[0]);
234 
235 	/* Validate output */
236 	zassert_true(
237 		test_snr_error_f32(1, output, ref, SNR_ERROR_THRESH),
238 		ASSERT_MSG_SNR_LIMIT_EXCEED);
239 
240 	zassert_true(
241 		test_rel_error_f32(1, output, ref, REL_ERROR_THRESH),
242 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
243 
244 	/* Free output buffer */
245 	free(output);
246 }
247 
248 DEFINE_TEST_VARIANT3(statistics_f32, arm_std_f32, 3, in_com1, 0, 3);
249 DEFINE_TEST_VARIANT3(statistics_f32, arm_std_f32, 8, in_com1, 1, 8);
250 DEFINE_TEST_VARIANT3(statistics_f32, arm_std_f32, 11, in_com1, 2, 11);
251 
test_arm_var_f32(const uint32_t * input1,int ref_index,size_t length)252 static void test_arm_var_f32(
253 	const uint32_t *input1, int ref_index, size_t length)
254 {
255 	float32_t ref[1];
256 	float32_t *output;
257 
258 	/* Load reference */
259 	ref[0] = ((float32_t *)ref_var)[ref_index];
260 
261 	/* Allocate output buffer */
262 	output = malloc(1 * sizeof(float32_t));
263 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
264 
265 	/* Run test function */
266 	arm_var_f32((float32_t *)input1, length, &output[0]);
267 
268 	/* Validate output */
269 	zassert_true(
270 		test_snr_error_f32(1, output, ref, SNR_ERROR_THRESH),
271 		ASSERT_MSG_SNR_LIMIT_EXCEED);
272 
273 	zassert_true(
274 		test_rel_error_f32(1, output, ref, REL_ERROR_THRESH),
275 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
276 
277 	/* Free output buffer */
278 	free(output);
279 }
280 
281 DEFINE_TEST_VARIANT3(statistics_f32, arm_var_f32, 3, in_com1, 0, 3);
282 DEFINE_TEST_VARIANT3(statistics_f32, arm_var_f32, 8, in_com1, 1, 8);
283 DEFINE_TEST_VARIANT3(statistics_f32, arm_var_f32, 11, in_com1, 2, 11);
284 
ZTEST(statistics_f32,test_arm_entropy_f32)285 ZTEST(statistics_f32, test_arm_entropy_f32)
286 {
287 	size_t index;
288 	size_t length = in_entropy_dim[0];
289 	const float32_t *ref = (float32_t *)ref_entropy;
290 	const float32_t *input = (float32_t *)in_entropy;
291 	float32_t *output;
292 
293 	__ASSERT_NO_MSG(ARRAY_SIZE(in_entropy_dim) > length);
294 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_entropy) >= length);
295 
296 	/* Allocate output buffer */
297 	output = malloc(length * sizeof(float32_t));
298 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
299 
300 	/* Run test function */
301 	for (index = 0; index < length; index++) {
302 		output[index] =
303 			arm_entropy_f32(input, in_entropy_dim[index + 1]);
304 		input += in_entropy_dim[index + 1];
305 	}
306 
307 	/* Validate output */
308 	zassert_true(
309 		test_snr_error_f32(length, ref, output, SNR_ERROR_THRESH),
310 		ASSERT_MSG_SNR_LIMIT_EXCEED);
311 
312 	zassert_true(
313 		test_near_equal_f32(length, ref, output, REL_ERROR_THRESH),
314 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
315 
316 	/* Free output buffer */
317 	free(output);
318 }
319 
ZTEST(statistics_f32,test_arm_logsumexp_f32)320 ZTEST(statistics_f32, test_arm_logsumexp_f32)
321 {
322 	size_t index;
323 	size_t length = in_logsumexp_dim[0];
324 	const float32_t *ref = (float32_t *)ref_logsumexp;
325 	const float32_t *input = (float32_t *)in_logsumexp;
326 	float32_t *output;
327 
328 	__ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dim) > length);
329 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp) >= length);
330 
331 	/* Allocate output buffer */
332 	output = malloc(length * sizeof(float32_t));
333 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
334 
335 	/* Run test function */
336 	for (index = 0; index < length; index++) {
337 		output[index] =
338 			arm_logsumexp_f32(input, in_logsumexp_dim[index + 1]);
339 		input += in_logsumexp_dim[index + 1];
340 	}
341 
342 	/* Validate output */
343 	zassert_true(
344 		test_snr_error_f32(length, ref, output, SNR_ERROR_THRESH),
345 		ASSERT_MSG_SNR_LIMIT_EXCEED);
346 
347 	zassert_true(
348 		test_near_equal_f32(length, ref, output, REL_ERROR_THRESH),
349 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
350 
351 	/* Free output buffer */
352 	free(output);
353 }
354 
ZTEST(statistics_f32,test_arm_kullback_leibler_f32)355 ZTEST(statistics_f32, test_arm_kullback_leibler_f32)
356 {
357 	size_t index;
358 	size_t length = in_kl_dim[0];
359 	const float32_t *ref = (float32_t *)ref_kl;
360 	const float32_t *input1 = (float32_t *)in_kl1;
361 	const float32_t *input2 = (float32_t *)in_kl2;
362 	float32_t *output;
363 
364 	__ASSERT_NO_MSG(ARRAY_SIZE(in_kl_dim) > length);
365 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_kl) >= length);
366 
367 	/* Allocate output buffer */
368 	output = malloc(length * sizeof(float32_t));
369 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
370 
371 	/* Run test function */
372 	for (index = 0; index < length; index++) {
373 		output[index] =
374 			arm_kullback_leibler_f32(
375 				input1, input2, in_kl_dim[index + 1]);
376 
377 		input1 += in_kl_dim[index + 1];
378 		input2 += in_kl_dim[index + 1];
379 	}
380 
381 	/* Validate output */
382 	zassert_true(
383 		test_snr_error_f32(length, ref, output, SNR_ERROR_THRESH),
384 		ASSERT_MSG_SNR_LIMIT_EXCEED);
385 
386 	zassert_true(
387 		test_near_equal_f32(length, ref, output, REL_ERROR_THRESH),
388 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
389 
390 	/* Free output buffer */
391 	free(output);
392 }
393 
ZTEST(statistics_f32,test_arm_logsumexp_dot_prod_f32)394 ZTEST(statistics_f32, test_arm_logsumexp_dot_prod_f32)
395 {
396 	size_t index;
397 	size_t length = in_logsumexp_dp_dim[0];
398 	const float32_t *ref = (float32_t *)ref_logsumexp_dp;
399 	const float32_t *input1 = (float32_t *)in_logsumexp_dp1;
400 	const float32_t *input2 = (float32_t *)in_logsumexp_dp2;
401 	float32_t *output;
402 	float32_t *tmp;
403 
404 	__ASSERT_NO_MSG(ARRAY_SIZE(in_logsumexp_dp_dim) > length);
405 	__ASSERT_NO_MSG(ARRAY_SIZE(ref_logsumexp_dp) >= length);
406 
407 	/* Allocate buffers */
408 	output = malloc(length * sizeof(float32_t));
409 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
410 
411 	tmp = malloc(12 * sizeof(float32_t));
412 	zassert_not_null(tmp, ASSERT_MSG_BUFFER_ALLOC_FAILED);
413 
414 	/* Run test function */
415 	for (index = 0; index < length; index++) {
416 		output[index] =
417 			arm_logsumexp_dot_prod_f32(
418 				input1, input2,
419 				in_logsumexp_dp_dim[index + 1], tmp);
420 
421 		input1 += in_logsumexp_dp_dim[index + 1];
422 		input2 += in_logsumexp_dp_dim[index + 1];
423 	}
424 
425 	/* Validate output */
426 	zassert_true(
427 		test_snr_error_f32(length, ref, output, SNR_ERROR_THRESH),
428 		ASSERT_MSG_SNR_LIMIT_EXCEED);
429 
430 	zassert_true(
431 		test_near_equal_f32(length, ref, output, REL_ERROR_THRESH),
432 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
433 
434 	/* Free buffers */
435 	free(output);
436 	free(tmp);
437 }
438 
/* Register the statistics_f32 suite; all five optional suite callbacks are left NULL. */
ZTEST_SUITE(statistics_f32, NULL, NULL, NULL, NULL, NULL);
440