1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13 
14 #include "f16.pat"
15 
/*
 * Pass/fail limits used by the tests in this file.
 *
 * The SNR limits are handed to test_snr_error_f16 (units as defined by that
 * helper, presumably dB -- confirm). The ABS/REL limits are handed to
 * test_close_error_f16. The looser *_LOG_* limits are used only by the
 * arm_vlog_f16 test in this file.
 */
#define SNR_ERROR_THRESH	((float32_t)60)
#define SNR_LOG_ERROR_THRESH	((float32_t)40)
#define ABS_ERROR_THRESH	(1.0e-3)
#define ABS_LOG_ERROR_THRESH	(3.0e-2)
#define REL_LOG_ERROR_THRESH	(3.0e-2)

#ifdef CONFIG_ARMV8_1_M_MVEF
/*
 * NOTE: The MVE vector version of the `vinverse` function is slightly less
 *       accurate than the scalar version, so the relative error limit is
 *       relaxed when MVE floating-point support is enabled. The relaxed
 *       value applies to every test that uses REL_ERROR_THRESH.
 */
#define REL_ERROR_THRESH	(1.1e-3)
#else
#define REL_ERROR_THRESH	(1.0e-3)
#endif
31 
#if 0
/*
 * NOTE: These tests must be enabled once the F16 sine and cosine function
 *       implementations are added.
 *
 * They are declared with ZTEST() so that, once this block is enabled, they
 * are registered in the fastmath_f16 suite like the other tests in this
 * file. Plain static functions would never be called.
 */

/**
 * Test arm_cos_f16 element by element on in_angles.
 *
 * The outputs are compared with ref_cos using the SNR limit and the
 * absolute/relative error limits.
 */
ZTEST(fastmath_f16, test_arm_cos_f16)
{
	size_t index;
	size_t length = ARRAY_SIZE(in_angles);
	float16_t *output;

	/* Allocate output buffer */
	output = malloc(length * sizeof(float16_t));
	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);

	/* Run test function on each angle (pattern data reinterpreted as f16) */
	for (index = 0; index < length; index++) {
		output[index] = arm_cos_f16(((float16_t *)in_angles)[index]);
	}

	/* Validate output */
	zassert_true(
		test_snr_error_f16(length, output, (float16_t *)ref_cos,
			SNR_ERROR_THRESH),
		ASSERT_MSG_SNR_LIMIT_EXCEED);

	zassert_true(
		test_close_error_f16(length, output, (float16_t *)ref_cos,
			ABS_ERROR_THRESH, REL_ERROR_THRESH),
		ASSERT_MSG_ERROR_LIMIT_EXCEED);

	/* Free output buffer */
	free(output);
}

/**
 * Test arm_sin_f16 element by element on in_angles.
 *
 * The outputs are compared with ref_sin using the SNR limit and the
 * absolute/relative error limits.
 */
ZTEST(fastmath_f16, test_arm_sin_f16)
{
	size_t index;
	size_t length = ARRAY_SIZE(in_angles);
	float16_t *output;

	/* Allocate output buffer */
	output = malloc(length * sizeof(float16_t));
	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);

	/* Run test function on each angle (pattern data reinterpreted as f16) */
	for (index = 0; index < length; index++) {
		output[index] = arm_sin_f16(((float16_t *)in_angles)[index]);
	}

	/* Validate output */
	zassert_true(
		test_snr_error_f16(length, output, (float16_t *)ref_sin,
			SNR_ERROR_THRESH),
		ASSERT_MSG_SNR_LIMIT_EXCEED);

	zassert_true(
		test_close_error_f16(length, output, (float16_t *)ref_sin,
			ABS_ERROR_THRESH, REL_ERROR_THRESH),
		ASSERT_MSG_ERROR_LIMIT_EXCEED);

	/* Free output buffer */
	free(output);
}
#endif
97 
/*
 * Declare the fastmath_f16 ztest suite that every ZTEST() in this file
 * belongs to. All hook arguments are NULL: per the ztest API these are the
 * predicate and the setup/before/after/teardown callbacks, none of which this
 * suite uses.
 */
ZTEST_SUITE(fastmath_f16, NULL, NULL, NULL, NULL, NULL);
99 
ZTEST(fastmath_f16,test_arm_sqrt_f16)100 ZTEST(fastmath_f16, test_arm_sqrt_f16)
101 {
102 	size_t index;
103 	size_t length = ARRAY_SIZE(in_sqrt);
104 	arm_status status;
105 	float16_t *output;
106 
107 	/* Allocate output buffer */
108 	output = malloc(length * sizeof(float16_t));
109 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
110 
111 	/* Run test function */
112 	for (index = 0; index < length; index++) {
113 		status = arm_sqrt_f16(
114 			((float16_t *)in_sqrt)[index], &output[index]);
115 
116 		/* Validate operation status */
117 		if (((float16_t *)in_sqrt)[index] < 0.0f) {
118 			zassert_equal(status, ARM_MATH_ARGUMENT_ERROR,
119 				"square root did fail with an input value "
120 				"of '0'");
121 		} else {
122 			zassert_equal(status, ARM_MATH_SUCCESS,
123 				"square root operation did not succeed");
124 		}
125 	}
126 
127 	/* Validate output */
128 	zassert_true(
129 		test_snr_error_f16(length, output, (float16_t *)ref_sqrt,
130 			SNR_ERROR_THRESH),
131 		ASSERT_MSG_SNR_LIMIT_EXCEED);
132 
133 	zassert_true(
134 		test_close_error_f16(length, output, (float16_t *)ref_sqrt,
135 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
136 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
137 
138 	/* Free output buffer */
139 	free(output);
140 }
141 
test_arm_vlog_f16(const uint16_t * input1,const uint16_t * ref,size_t length)142 static void test_arm_vlog_f16(
143 	const uint16_t *input1, const uint16_t *ref, size_t length)
144 {
145 	float16_t *output;
146 
147 	/* Allocate output buffer */
148 	output = malloc(length * sizeof(float16_t));
149 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
150 
151 	/* Run test function */
152 	arm_vlog_f16((float16_t *)input1, output, length);
153 
154 	/* Validate output */
155 	zassert_true(
156 		test_snr_error_f16(length, output, (float16_t *)ref,
157 			SNR_LOG_ERROR_THRESH),
158 		ASSERT_MSG_SNR_LIMIT_EXCEED);
159 
160 	zassert_true(
161 		test_close_error_f16(length, output, (float16_t *)ref,
162 			ABS_LOG_ERROR_THRESH, REL_LOG_ERROR_THRESH),
163 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
164 
165 	/* Free output buffer */
166 	free(output);
167 }
168 
169 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, all, in_log, ref_log, 25);
170 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 3, in_log, ref_log, 3);
171 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 8, in_log, ref_log, 8);
172 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 11, in_log, ref_log, 11);
173 
test_arm_vexp_f16(const uint16_t * input1,const uint16_t * ref,size_t length)174 static void test_arm_vexp_f16(
175 	const uint16_t *input1, const uint16_t *ref, size_t length)
176 {
177 	float16_t *output;
178 
179 	/* Allocate output buffer */
180 	output = malloc(length * sizeof(float16_t));
181 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
182 
183 	/* Run test function */
184 	arm_vexp_f16((float16_t *)input1, output, length);
185 
186 	/* Validate output */
187 	zassert_true(
188 		test_snr_error_f16(length, output, (float16_t *)ref,
189 			SNR_ERROR_THRESH),
190 		ASSERT_MSG_SNR_LIMIT_EXCEED);
191 
192 	zassert_true(
193 		test_close_error_f16(length, output, (float16_t *)ref,
194 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
195 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
196 
197 	/* Free output buffer */
198 	free(output);
199 }
200 
201 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, all, in_exp, ref_exp, 52);
202 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 3, in_exp, ref_exp, 3);
203 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 8, in_exp, ref_exp, 8);
204 DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 11, in_exp, ref_exp, 11);
205 
ZTEST(fastmath_f16,test_arm_vinverse_f16)206 ZTEST(fastmath_f16, test_arm_vinverse_f16)
207 {
208 	size_t length = ARRAY_SIZE(ref_vinverse);
209 	float16_t *output;
210 
211 	/* Allocate output buffer */
212 	output = malloc(length * sizeof(float16_t));
213 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
214 
215 	/* Run test function */
216 	arm_vinverse_f16((float16_t *)in_vinverse, output, length);
217 
218 	/* Validate output */
219 	zassert_true(
220 		test_snr_error_f16(length, output, (float16_t *)ref_vinverse,
221 			SNR_ERROR_THRESH),
222 		ASSERT_MSG_SNR_LIMIT_EXCEED);
223 
224 	zassert_true(
225 		test_close_error_f16(length, output, (float16_t *)ref_vinverse,
226 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
227 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
228 
229 	/* Free output buffer */
230 	free(output);
231 }
232 
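/*
 * TODO: Add inverse test.
 * NOTE(review): test_arm_vinverse_f16 already exists in this file; confirm
 * whether this TODO is stale or refers to additional inverse coverage (e.g.
 * short-length variants).
 */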
234