1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13
14 #include "f16.pat"
15
16 #define SNR_ERROR_THRESH ((float32_t)39)
17 #define REL_ERROR_THRESH (6.0e-2)
18
19 ZTEST_SUITE(complexmath_f16, NULL, NULL, NULL, NULL, NULL);
20
test_arm_cmplx_conj_f16(const uint16_t * input1,const uint16_t * ref,size_t length)21 static void test_arm_cmplx_conj_f16(
22 const uint16_t *input1, const uint16_t *ref, size_t length)
23 {
24 size_t buf_length;
25 float16_t *output;
26
27 /* Complex number buffer length is twice the data length */
28 buf_length = 2 * length;
29
30 /* Allocate output buffer */
31 output = malloc(buf_length * sizeof(float16_t));
32 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
33
34 /* Run test function */
35 arm_cmplx_conj_f16((float16_t *)input1, output, length);
36
37 /* Validate output */
38 zassert_true(
39 test_snr_error_f16(buf_length, output, (float16_t *)ref,
40 SNR_ERROR_THRESH),
41 ASSERT_MSG_SNR_LIMIT_EXCEED);
42
43 zassert_true(
44 test_rel_error_f16(buf_length, output, (float16_t *)ref,
45 REL_ERROR_THRESH),
46 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
47
48 /* Free output buffer */
49 free(output);
50 }
51
52 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 7, in_com1, ref_conj, 7);
53 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 16, in_com1, ref_conj, 16);
54 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 23, in_com1, ref_conj, 23);
55
test_arm_cmplx_dot_prod_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)56 static void test_arm_cmplx_dot_prod_f16(
57 const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
58 size_t length)
59 {
60 float16_t *output;
61
62 /* Allocate output buffer */
63 output = malloc(2 * sizeof(float16_t));
64 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
65
66 /* Run test function */
67 arm_cmplx_dot_prod_f16(
68 (float16_t *)input1, (float16_t *)input2, length,
69 &output[0], &output[1]);
70
71 /* Validate output */
72 zassert_true(
73 test_snr_error_f16(2, output, (float16_t *)ref,
74 SNR_ERROR_THRESH),
75 ASSERT_MSG_SNR_LIMIT_EXCEED);
76
77 zassert_true(
78 test_rel_error_f16(2, output, (float16_t *)ref,
79 REL_ERROR_THRESH),
80 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
81
82 /* Free output buffer */
83 free(output);
84 }
85
86 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 7, in_com1, in_com2, ref_dot_prod_3,
87 7);
88 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 16, in_com1, in_com2, ref_dot_prod_4n,
89 16);
90 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 23, in_com1, in_com2,
91 ref_dot_prod_4n1, 23);
92
test_arm_cmplx_mag_f16(const uint16_t * input1,const uint16_t * ref,size_t length)93 static void test_arm_cmplx_mag_f16(
94 const uint16_t *input1, const uint16_t *ref, size_t length)
95 {
96 float16_t *output;
97
98 /* Allocate output buffer */
99 output = malloc(length * sizeof(float16_t));
100 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
101
102 /* Run test function */
103 arm_cmplx_mag_f16((float16_t *)input1, output, length);
104
105 /* Validate output */
106 zassert_true(
107 test_snr_error_f16(length, output, (float16_t *)ref,
108 SNR_ERROR_THRESH),
109 ASSERT_MSG_SNR_LIMIT_EXCEED);
110
111 zassert_true(
112 test_rel_error_f16(length, output, (float16_t *)ref,
113 REL_ERROR_THRESH),
114 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
115
116 /* Free output buffer */
117 free(output);
118 }
119
120 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 7, in_com1, ref_mag, 7);
121 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 16, in_com1, ref_mag, 16);
122 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 23, in_com1, ref_mag, 23);
123
test_arm_cmplx_mag_squared_f16(const uint16_t * input1,const uint16_t * ref,size_t length)124 static void test_arm_cmplx_mag_squared_f16(
125 const uint16_t *input1, const uint16_t *ref, size_t length)
126 {
127 float16_t *output;
128
129 /* Allocate output buffer */
130 output = malloc(length * sizeof(float16_t));
131 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
132
133 /* Run test function */
134 arm_cmplx_mag_squared_f16((float16_t *)input1, output, length);
135
136 /* Validate output */
137 zassert_true(
138 test_snr_error_f16(length, output, (float16_t *)ref,
139 SNR_ERROR_THRESH),
140 ASSERT_MSG_SNR_LIMIT_EXCEED);
141
142 zassert_true(
143 test_rel_error_f16(length, output, (float16_t *)ref,
144 REL_ERROR_THRESH),
145 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
146
147 /* Free output buffer */
148 free(output);
149 }
150
151 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 7, in_com1, ref_mag_squared, 7);
152 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 16, in_com1, ref_mag_squared, 16);
153 DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 23, in_com1, ref_mag_squared, 23);
154
test_arm_cmplx_mult_cmplx_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)155 static void test_arm_cmplx_mult_cmplx_f16(
156 const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
157 size_t length)
158 {
159 size_t buf_length;
160 float16_t *output;
161
162 /* Complex number buffer length is twice the data length */
163 buf_length = 2 * length;
164
165 /* Allocate output buffer */
166 output = malloc(buf_length * sizeof(float16_t));
167 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
168
169 /* Run test function */
170 arm_cmplx_mult_cmplx_f16(
171 (float16_t *)input1, (float16_t *)input2, output, length);
172
173 /* Validate output */
174 zassert_true(
175 test_snr_error_f16(buf_length, output, (float16_t *)ref,
176 SNR_ERROR_THRESH),
177 ASSERT_MSG_SNR_LIMIT_EXCEED);
178
179 zassert_true(
180 test_rel_error_f16(buf_length, output, (float16_t *)ref,
181 REL_ERROR_THRESH),
182 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
183
184 /* Free output buffer */
185 free(output);
186 }
187
188 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 7, in_com1, in_com2, ref_mult_cmplx,
189 7);
190 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 16, in_com1, in_com2,
191 ref_mult_cmplx, 16);
192 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 23, in_com1, in_com2,
193 ref_mult_cmplx, 23);
194
test_arm_cmplx_mult_real_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)195 static void test_arm_cmplx_mult_real_f16(
196 const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
197 size_t length)
198 {
199 size_t buf_length;
200 float16_t *output;
201
202 /* Complex number buffer length is twice the data length */
203 buf_length = 2 * length;
204
205 /* Allocate output buffer */
206 output = malloc(buf_length * sizeof(float16_t));
207 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
208
209 /* Run test function */
210 arm_cmplx_mult_real_f16(
211 (float16_t *)input1, (float16_t *)input2, output, length);
212
213 /* Validate output */
214 zassert_true(
215 test_snr_error_f16(
216 buf_length, output, (float16_t *)ref,
217 SNR_ERROR_THRESH),
218 ASSERT_MSG_SNR_LIMIT_EXCEED);
219
220 zassert_true(
221 test_rel_error_f16(
222 buf_length, output, (float16_t *)ref,
223 REL_ERROR_THRESH),
224 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
225
226 /* Free output buffer */
227 free(output);
228 }
229
230 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 7, in_com1, in_com3, ref_mult_real,
231 7);
232 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 16, in_com1, in_com3, ref_mult_real,
233 16);
234 DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 23, in_com1, in_com3, ref_mult_real,
235 23);
236