1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13
14 #include "f16.pat"
15
/*
 * Error thresholds handed to the `test_snr_error_f16` and
 * `test_close_error_f16` comparison helpers by the tests in this file.
 *
 * The `*_LOG_*` variants are used only by the `arm_vlog_f16` test; all other
 * tests use the default thresholds.
 */

/* SNR threshold passed to `test_snr_error_f16` (units not visible here; presumably dB -- confirm in test_common.h) */
#define SNR_ERROR_THRESH	((float32_t)60)
/* SNR threshold used for the `arm_vlog_f16` output */
#define SNR_LOG_ERROR_THRESH	((float32_t)40)
/* Relative error threshold passed to `test_close_error_f16` */
#define REL_ERROR_THRESH	(1.0e-3)
/* Relative error threshold used for the `arm_vlog_f16` output */
#define REL_LOG_ERROR_THRESH	(3.0e-2)
/* Absolute error threshold passed to `test_close_error_f16` */
#define ABS_ERROR_THRESH	(1.0e-3)
/* Absolute error threshold used for the `arm_vlog_f16` output */
#define ABS_LOG_ERROR_THRESH	(3.0e-2)

#ifdef CONFIG_ARMV8_1_M_MVEF
/*
 * NOTE: The MVE vector version of the `vinverse` function is slightly less
 *       accurate than the scalar version, so the default relative error
 *       threshold is raised when building with CONFIG_ARMV8_1_M_MVEF. Because
 *       the macro is redefined, this applies to every test in this file that
 *       uses REL_ERROR_THRESH, not only to the `vinverse` test.
 */
#undef REL_ERROR_THRESH
#define REL_ERROR_THRESH	(1.1e-3)
#endif
31
32 #if 0
33 /*
34 * NOTE: These tests must be enabled once the F16 sine and cosine function
35 * implementations are added.
36 */
37 static void test_arm_cos_f16(void)
38 {
39 size_t index;
40 size_t length = ARRAY_SIZE(in_angles);
41 float16_t *output;
42
43 /* Allocate output buffer */
44 output = malloc(length * sizeof(float16_t));
45 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
46
47 /* Run test function */
48 for (index = 0; index < length; index++) {
49 output[index] = arm_cos_f16(((float16_t *)in_angles)[index]);
50 }
51
52 /* Validate output */
53 zassert_true(
54 test_snr_error_f16(length, output, (float16_t *)ref_cos,
55 SNR_ERROR_THRESH),
56 ASSERT_MSG_SNR_LIMIT_EXCEED);
57
58 zassert_true(
59 test_close_error_f16(length, output, (float16_t *)ref_cos,
60 ABS_ERROR_THRESH, REL_ERROR_THRESH),
61 ASSERT_MSG_ERROR_LIMIT_EXCEED);
62
63 /* Free output buffer */
64 free(output);
65 }
66
67 static void test_arm_sin_f16(void)
68 {
69 size_t index;
70 size_t length = ARRAY_SIZE(in_angles);
71 float16_t *output;
72
73 /* Allocate output buffer */
74 output = malloc(length * sizeof(float16_t));
75 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
76
77 /* Run test function */
78 for (index = 0; index < length; index++) {
79 output[index] = arm_sin_f16(((float16_t *)in_angles)[index]);
80 }
81
82 /* Validate output */
83 zassert_true(
84 test_snr_error_f16(length, output, (float16_t *)ref_sin,
85 SNR_ERROR_THRESH),
86 ASSERT_MSG_SNR_LIMIT_EXCEED);
87
88 zassert_true(
89 test_close_error_f16(length, output, (float16_t *)ref_sin,
90 ABS_ERROR_THRESH, REL_ERROR_THRESH),
91 ASSERT_MSG_ERROR_LIMIT_EXCEED);
92
93 /* Free output buffer */
94 free(output);
95 }
96 #endif
97
/*
 * Declare the `fastmath_f16` test suite that the tests in this file belong
 * to. Every optional argument is NULL, i.e. no suite-level callbacks are
 * supplied here.
 */
ZTEST_SUITE(fastmath_f16, NULL, NULL, NULL, NULL, NULL);
99
ZTEST(fastmath_f16,test_arm_sqrt_f16)100 ZTEST(fastmath_f16, test_arm_sqrt_f16)
101 {
102 size_t index;
103 size_t length = ARRAY_SIZE(in_sqrt);
104 arm_status status;
105 float16_t *output;
106
107 /* Allocate output buffer */
108 output = malloc(length * sizeof(float16_t));
109 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
110
111 /* Run test function */
112 for (index = 0; index < length; index++) {
113 status = arm_sqrt_f16(
114 ((float16_t *)in_sqrt)[index], &output[index]);
115
116 /* Validate operation status */
117 if (((float16_t *)in_sqrt)[index] < 0.0f) {
118 zassert_equal(status, ARM_MATH_ARGUMENT_ERROR,
119 "square root did fail with an input value "
120 "of '0'");
121 } else {
122 zassert_equal(status, ARM_MATH_SUCCESS,
123 "square root operation did not succeed");
124 }
125 }
126
127 /* Validate output */
128 zassert_true(
129 test_snr_error_f16(length, output, (float16_t *)ref_sqrt,
130 SNR_ERROR_THRESH),
131 ASSERT_MSG_SNR_LIMIT_EXCEED);
132
133 zassert_true(
134 test_close_error_f16(length, output, (float16_t *)ref_sqrt,
135 ABS_ERROR_THRESH, REL_ERROR_THRESH),
136 ASSERT_MSG_ERROR_LIMIT_EXCEED);
137
138 /* Free output buffer */
139 free(output);
140 }
141
test_arm_vlog_f16(const uint16_t * input1,const uint16_t * ref,size_t length)142 static void test_arm_vlog_f16(
143 const uint16_t *input1, const uint16_t *ref, size_t length)
144 {
145 float16_t *output;
146
147 /* Allocate output buffer */
148 output = malloc(length * sizeof(float16_t));
149 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
150
151 /* Run test function */
152 arm_vlog_f16((float16_t *)input1, output, length);
153
154 /* Validate output */
155 zassert_true(
156 test_snr_error_f16(length, output, (float16_t *)ref,
157 SNR_LOG_ERROR_THRESH),
158 ASSERT_MSG_SNR_LIMIT_EXCEED);
159
160 zassert_true(
161 test_close_error_f16(length, output, (float16_t *)ref,
162 ABS_LOG_ERROR_THRESH, REL_LOG_ERROR_THRESH),
163 ASSERT_MSG_ERROR_LIMIT_EXCEED);
164
165 /* Free output buffer */
166 free(output);
167 }
168
/*
 * Instantiate the `arm_vlog_f16` test variants for the `fastmath_f16` suite
 * over the `in_log`/`ref_log` patterns: one variant labelled `all` with a
 * length of 25, and three shorter variants with lengths of 3, 8 and 11.
 *
 * NOTE(review): `DEFINE_TEST_VARIANT3` is defined outside this file; it is
 * assumed to forward its last three arguments to `test_arm_vlog_f16()` --
 * confirm against its definition.
 */
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, all, in_log, ref_log, 25);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 3, in_log, ref_log, 3);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 8, in_log, ref_log, 8);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vlog_f16, 11, in_log, ref_log, 11);
173
test_arm_vexp_f16(const uint16_t * input1,const uint16_t * ref,size_t length)174 static void test_arm_vexp_f16(
175 const uint16_t *input1, const uint16_t *ref, size_t length)
176 {
177 float16_t *output;
178
179 /* Allocate output buffer */
180 output = malloc(length * sizeof(float16_t));
181 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
182
183 /* Run test function */
184 arm_vexp_f16((float16_t *)input1, output, length);
185
186 /* Validate output */
187 zassert_true(
188 test_snr_error_f16(length, output, (float16_t *)ref,
189 SNR_ERROR_THRESH),
190 ASSERT_MSG_SNR_LIMIT_EXCEED);
191
192 zassert_true(
193 test_close_error_f16(length, output, (float16_t *)ref,
194 ABS_ERROR_THRESH, REL_ERROR_THRESH),
195 ASSERT_MSG_ERROR_LIMIT_EXCEED);
196
197 /* Free output buffer */
198 free(output);
199 }
200
/*
 * Instantiate the `arm_vexp_f16` test variants for the `fastmath_f16` suite
 * over the `in_exp`/`ref_exp` patterns: one variant labelled `all` with a
 * length of 52, and three shorter variants with lengths of 3, 8 and 11.
 *
 * NOTE(review): `DEFINE_TEST_VARIANT3` is defined outside this file; it is
 * assumed to forward its last three arguments to `test_arm_vexp_f16()` --
 * confirm against its definition.
 */
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, all, in_exp, ref_exp, 52);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 3, in_exp, ref_exp, 3);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 8, in_exp, ref_exp, 8);
DEFINE_TEST_VARIANT3(fastmath_f16, arm_vexp_f16, 11, in_exp, ref_exp, 11);
205
ZTEST(fastmath_f16,test_arm_vinverse_f16)206 ZTEST(fastmath_f16, test_arm_vinverse_f16)
207 {
208 size_t length = ARRAY_SIZE(ref_vinverse);
209 float16_t *output;
210
211 /* Allocate output buffer */
212 output = malloc(length * sizeof(float16_t));
213 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
214
215 /* Run test function */
216 arm_vinverse_f16((float16_t *)in_vinverse, output, length);
217
218 /* Validate output */
219 zassert_true(
220 test_snr_error_f16(length, output, (float16_t *)ref_vinverse,
221 SNR_ERROR_THRESH),
222 ASSERT_MSG_SNR_LIMIT_EXCEED);
223
224 zassert_true(
225 test_close_error_f16(length, output, (float16_t *)ref_vinverse,
226 ABS_ERROR_THRESH, REL_ERROR_THRESH),
227 ASSERT_MSG_ERROR_LIMIT_EXCEED);
228
229 /* Free output buffer */
230 free(output);
231 }
232
233 /* TODO: Add inverse test */
234