1 /*
2 * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
#include <ztest.h>
#include <zephyr.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"
13
14 #include "f32.pat"
15
16 #define SNR_ERROR_THRESH ((float32_t)120)
17 #define REL_ERROR_THRESH (7.0e-6)
18
test_arm_cmplx_conj_f32(const uint32_t * input1,const uint32_t * ref,size_t length)19 static void test_arm_cmplx_conj_f32(
20 const uint32_t *input1, const uint32_t *ref, size_t length)
21 {
22 size_t buf_length;
23 float32_t *output;
24
25 /* Complex number buffer length is twice the data length */
26 buf_length = 2 * length;
27
28 /* Allocate output buffer */
29 output = malloc(buf_length * sizeof(float32_t));
30 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
31
32 /* Run test function */
33 arm_cmplx_conj_f32((float32_t *)input1, output, length);
34
35 /* Validate output */
36 zassert_true(
37 test_snr_error_f32(buf_length, output, (float32_t *)ref,
38 SNR_ERROR_THRESH),
39 ASSERT_MSG_SNR_LIMIT_EXCEED);
40
41 zassert_true(
42 test_rel_error_f32(buf_length, output, (float32_t *)ref,
43 REL_ERROR_THRESH),
44 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
45
46 /* Free output buffer */
47 free(output);
48 }
49
50 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 3, in_com1, ref_conj, 3);
51 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 8, in_com1, ref_conj, 8);
52 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 11, in_com1, ref_conj, 11);
53
test_arm_cmplx_dot_prod_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)54 static void test_arm_cmplx_dot_prod_f32(
55 const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
56 size_t length)
57 {
58 float32_t *output;
59
60 /* Allocate output buffer */
61 output = malloc(2 * sizeof(float32_t));
62 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
63
64 /* Run test function */
65 arm_cmplx_dot_prod_f32(
66 (float32_t *)input1, (float32_t *)input2, length,
67 &output[0], &output[1]);
68
69 /* Validate output */
70 zassert_true(
71 test_snr_error_f32(2, output, (float32_t *)ref,
72 SNR_ERROR_THRESH),
73 ASSERT_MSG_SNR_LIMIT_EXCEED);
74
75 zassert_true(
76 test_rel_error_f32(2, output, (float32_t *)ref,
77 REL_ERROR_THRESH),
78 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
79
80 /* Free output buffer */
81 free(output);
82 }
83
84 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 3, in_com1, in_com2, ref_dot_prod_3, 3);
85 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 8, in_com1, in_com2, ref_dot_prod_4n, 8);
86 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 11, in_com1, in_com2, ref_dot_prod_4n1, 11);
87
test_arm_cmplx_mag_f32(const uint32_t * input1,const uint32_t * ref,size_t length)88 static void test_arm_cmplx_mag_f32(
89 const uint32_t *input1, const uint32_t *ref, size_t length)
90 {
91 float32_t *output;
92
93 /* Allocate output buffer */
94 output = malloc(length * sizeof(float32_t));
95 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
96
97 /* Run test function */
98 arm_cmplx_mag_f32((float32_t *)input1, output, length);
99
100 /* Validate output */
101 zassert_true(
102 test_snr_error_f32(length, output, (float32_t *)ref,
103 SNR_ERROR_THRESH),
104 ASSERT_MSG_SNR_LIMIT_EXCEED);
105
106 zassert_true(
107 test_rel_error_f32(length, output, (float32_t *)ref,
108 REL_ERROR_THRESH),
109 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
110
111 /* Free output buffer */
112 free(output);
113 }
114
115 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 3, in_com1, ref_mag, 3);
116 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 8, in_com1, ref_mag, 8);
117 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 11, in_com1, ref_mag, 11);
118
test_arm_cmplx_mag_squared_f32(const uint32_t * input1,const uint32_t * ref,size_t length)119 static void test_arm_cmplx_mag_squared_f32(
120 const uint32_t *input1, const uint32_t *ref, size_t length)
121 {
122 float32_t *output;
123
124 /* Allocate output buffer */
125 output = malloc(length * sizeof(float32_t));
126 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
127
128 /* Run test function */
129 arm_cmplx_mag_squared_f32((float32_t *)input1, output, length);
130
131 /* Validate output */
132 zassert_true(
133 test_snr_error_f32(length, output, (float32_t *)ref,
134 SNR_ERROR_THRESH),
135 ASSERT_MSG_SNR_LIMIT_EXCEED);
136
137 zassert_true(
138 test_rel_error_f32(length, output, (float32_t *)ref,
139 REL_ERROR_THRESH),
140 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
141
142 /* Free output buffer */
143 free(output);
144 }
145
146 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 3, in_com1, ref_mag_squared, 3);
147 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 8, in_com1, ref_mag_squared, 8);
148 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 11, in_com1, ref_mag_squared, 11);
149
test_arm_cmplx_mult_cmplx_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)150 static void test_arm_cmplx_mult_cmplx_f32(
151 const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
152 size_t length)
153 {
154 size_t buf_length;
155 float32_t *output;
156
157 /* Complex number buffer length is twice the data length */
158 buf_length = 2 * length;
159
160 /* Allocate output buffer */
161 output = malloc(buf_length * sizeof(float32_t));
162 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
163
164 /* Run test function */
165 arm_cmplx_mult_cmplx_f32(
166 (float32_t *)input1, (float32_t *)input2, output, length);
167
168 /* Validate output */
169 zassert_true(
170 test_snr_error_f32(buf_length, output, (float32_t *)ref,
171 SNR_ERROR_THRESH),
172 ASSERT_MSG_SNR_LIMIT_EXCEED);
173
174 zassert_true(
175 test_rel_error_f32(buf_length, output, (float32_t *)ref,
176 REL_ERROR_THRESH),
177 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
178
179 /* Free output buffer */
180 free(output);
181 }
182
183 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 3, in_com1, in_com2, ref_mult_cmplx, 3);
184 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 8, in_com1, in_com2, ref_mult_cmplx, 8);
185 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 11, in_com1, in_com2, ref_mult_cmplx, 11);
186
test_arm_cmplx_mult_real_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)187 static void test_arm_cmplx_mult_real_f32(
188 const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
189 size_t length)
190 {
191 size_t buf_length;
192 float32_t *output;
193
194 /* Complex number buffer length is twice the data length */
195 buf_length = 2 * length;
196
197 /* Allocate output buffer */
198 output = malloc(buf_length * sizeof(float32_t));
199 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
200
201 /* Run test function */
202 arm_cmplx_mult_real_f32(
203 (float32_t *)input1, (float32_t *)input2, output, length);
204
205 /* Validate output */
206 zassert_true(
207 test_snr_error_f32(
208 buf_length, output, (float32_t *)ref,
209 SNR_ERROR_THRESH),
210 ASSERT_MSG_SNR_LIMIT_EXCEED);
211
212 zassert_true(
213 test_rel_error_f32(
214 buf_length, output, (float32_t *)ref,
215 REL_ERROR_THRESH),
216 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
217
218 /* Free output buffer */
219 free(output);
220 }
221
222 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 3, in_com1, in_com3, ref_mult_real, 3);
223 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 8, in_com1, in_com3, ref_mult_real, 8);
224 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 11, in_com1, in_com3, ref_mult_real, 11);
225
test_complexmath_f32(void)226 void test_complexmath_f32(void)
227 {
228 ztest_test_suite(complexmath_f32,
229 ztest_unit_test(test_arm_cmplx_conj_f32_3),
230 ztest_unit_test(test_arm_cmplx_conj_f32_8),
231 ztest_unit_test(test_arm_cmplx_conj_f32_11),
232 ztest_unit_test(test_arm_cmplx_dot_prod_f32_3),
233 ztest_unit_test(test_arm_cmplx_dot_prod_f32_8),
234 ztest_unit_test(test_arm_cmplx_dot_prod_f32_11),
235 ztest_unit_test(test_arm_cmplx_mag_f32_3),
236 ztest_unit_test(test_arm_cmplx_mag_f32_8),
237 ztest_unit_test(test_arm_cmplx_mag_f32_11),
238 ztest_unit_test(test_arm_cmplx_mag_squared_f32_3),
239 ztest_unit_test(test_arm_cmplx_mag_squared_f32_8),
240 ztest_unit_test(test_arm_cmplx_mag_squared_f32_11),
241 ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_3),
242 ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_8),
243 ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_11),
244 ztest_unit_test(test_arm_cmplx_mult_real_f32_3),
245 ztest_unit_test(test_arm_cmplx_mult_real_f32_8),
246 ztest_unit_test(test_arm_cmplx_mult_real_f32_11)
247 );
248
249 ztest_run_test_suite(complexmath_f32);
250 }
251