1 /*
2  * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <ztest.h>
#include <zephyr.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"
13 
14 #include "f32.pat"
15 
/*
 * Pass/fail thresholds handed to test_snr_error_f32() and
 * test_rel_error_f32() for every test in this file. Both are typed as
 * float32_t so that no implicit double -> float conversion happens at the
 * call sites (units/semantics are defined by test_common.h).
 */
#define SNR_ERROR_THRESH	((float32_t)120)
#define REL_ERROR_THRESH	((float32_t)7.0e-6)
18 
test_arm_cmplx_conj_f32(const uint32_t * input1,const uint32_t * ref,size_t length)19 static void test_arm_cmplx_conj_f32(
20 	const uint32_t *input1, const uint32_t *ref, size_t length)
21 {
22 	size_t buf_length;
23 	float32_t *output;
24 
25 	/* Complex number buffer length is twice the data length */
26 	buf_length = 2 * length;
27 
28 	/* Allocate output buffer */
29 	output = malloc(buf_length * sizeof(float32_t));
30 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
31 
32 	/* Run test function */
33 	arm_cmplx_conj_f32((float32_t *)input1, output, length);
34 
35 	/* Validate output */
36 	zassert_true(
37 		test_snr_error_f32(buf_length, output, (float32_t *)ref,
38 			SNR_ERROR_THRESH),
39 		ASSERT_MSG_SNR_LIMIT_EXCEED);
40 
41 	zassert_true(
42 		test_rel_error_f32(buf_length, output, (float32_t *)ref,
43 			REL_ERROR_THRESH),
44 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
45 
46 	/* Free output buffer */
47 	free(output);
48 }
49 
50 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 3, in_com1, ref_conj, 3);
51 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 8, in_com1, ref_conj, 8);
52 DEFINE_TEST_VARIANT3(arm_cmplx_conj_f32, 11, in_com1, ref_conj, 11);
53 
test_arm_cmplx_dot_prod_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)54 static void test_arm_cmplx_dot_prod_f32(
55 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
56 	size_t length)
57 {
58 	float32_t *output;
59 
60 	/* Allocate output buffer */
61 	output = malloc(2 * sizeof(float32_t));
62 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
63 
64 	/* Run test function */
65 	arm_cmplx_dot_prod_f32(
66 		(float32_t *)input1, (float32_t *)input2, length,
67 		&output[0], &output[1]);
68 
69 	/* Validate output */
70 	zassert_true(
71 		test_snr_error_f32(2, output, (float32_t *)ref,
72 			SNR_ERROR_THRESH),
73 		ASSERT_MSG_SNR_LIMIT_EXCEED);
74 
75 	zassert_true(
76 		test_rel_error_f32(2, output, (float32_t *)ref,
77 			REL_ERROR_THRESH),
78 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
79 
80 	/* Free output buffer */
81 	free(output);
82 }
83 
84 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 3, in_com1, in_com2, ref_dot_prod_3, 3);
85 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 8, in_com1, in_com2, ref_dot_prod_4n, 8);
86 DEFINE_TEST_VARIANT4(arm_cmplx_dot_prod_f32, 11, in_com1, in_com2, ref_dot_prod_4n1, 11);
87 
test_arm_cmplx_mag_f32(const uint32_t * input1,const uint32_t * ref,size_t length)88 static void test_arm_cmplx_mag_f32(
89 	const uint32_t *input1, const uint32_t *ref, size_t length)
90 {
91 	float32_t *output;
92 
93 	/* Allocate output buffer */
94 	output = malloc(length * sizeof(float32_t));
95 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
96 
97 	/* Run test function */
98 	arm_cmplx_mag_f32((float32_t *)input1, output, length);
99 
100 	/* Validate output */
101 	zassert_true(
102 		test_snr_error_f32(length, output, (float32_t *)ref,
103 			SNR_ERROR_THRESH),
104 		ASSERT_MSG_SNR_LIMIT_EXCEED);
105 
106 	zassert_true(
107 		test_rel_error_f32(length, output, (float32_t *)ref,
108 			REL_ERROR_THRESH),
109 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
110 
111 	/* Free output buffer */
112 	free(output);
113 }
114 
115 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 3, in_com1, ref_mag, 3);
116 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 8, in_com1, ref_mag, 8);
117 DEFINE_TEST_VARIANT3(arm_cmplx_mag_f32, 11, in_com1, ref_mag, 11);
118 
test_arm_cmplx_mag_squared_f32(const uint32_t * input1,const uint32_t * ref,size_t length)119 static void test_arm_cmplx_mag_squared_f32(
120 	const uint32_t *input1, const uint32_t *ref, size_t length)
121 {
122 	float32_t *output;
123 
124 	/* Allocate output buffer */
125 	output = malloc(length * sizeof(float32_t));
126 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
127 
128 	/* Run test function */
129 	arm_cmplx_mag_squared_f32((float32_t *)input1, output, length);
130 
131 	/* Validate output */
132 	zassert_true(
133 		test_snr_error_f32(length, output, (float32_t *)ref,
134 			SNR_ERROR_THRESH),
135 		ASSERT_MSG_SNR_LIMIT_EXCEED);
136 
137 	zassert_true(
138 		test_rel_error_f32(length, output, (float32_t *)ref,
139 			REL_ERROR_THRESH),
140 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
141 
142 	/* Free output buffer */
143 	free(output);
144 }
145 
146 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 3, in_com1, ref_mag_squared, 3);
147 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 8, in_com1, ref_mag_squared, 8);
148 DEFINE_TEST_VARIANT3(arm_cmplx_mag_squared_f32, 11, in_com1, ref_mag_squared, 11);
149 
test_arm_cmplx_mult_cmplx_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)150 static void test_arm_cmplx_mult_cmplx_f32(
151 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
152 	size_t length)
153 {
154 	size_t buf_length;
155 	float32_t *output;
156 
157 	/* Complex number buffer length is twice the data length */
158 	buf_length = 2 * length;
159 
160 	/* Allocate output buffer */
161 	output = malloc(buf_length * sizeof(float32_t));
162 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
163 
164 	/* Run test function */
165 	arm_cmplx_mult_cmplx_f32(
166 		(float32_t *)input1, (float32_t *)input2, output, length);
167 
168 	/* Validate output */
169 	zassert_true(
170 		test_snr_error_f32(buf_length, output, (float32_t *)ref,
171 			SNR_ERROR_THRESH),
172 		ASSERT_MSG_SNR_LIMIT_EXCEED);
173 
174 	zassert_true(
175 		test_rel_error_f32(buf_length, output, (float32_t *)ref,
176 			REL_ERROR_THRESH),
177 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
178 
179 	/* Free output buffer */
180 	free(output);
181 }
182 
183 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 3, in_com1, in_com2, ref_mult_cmplx, 3);
184 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 8, in_com1, in_com2, ref_mult_cmplx, 8);
185 DEFINE_TEST_VARIANT4(arm_cmplx_mult_cmplx_f32, 11, in_com1, in_com2, ref_mult_cmplx, 11);
186 
test_arm_cmplx_mult_real_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)187 static void test_arm_cmplx_mult_real_f32(
188 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
189 	size_t length)
190 {
191 	size_t buf_length;
192 	float32_t *output;
193 
194 	/* Complex number buffer length is twice the data length */
195 	buf_length = 2 * length;
196 
197 	/* Allocate output buffer */
198 	output = malloc(buf_length * sizeof(float32_t));
199 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
200 
201 	/* Run test function */
202 	arm_cmplx_mult_real_f32(
203 		(float32_t *)input1, (float32_t *)input2, output, length);
204 
205 	/* Validate output */
206 	zassert_true(
207 		test_snr_error_f32(
208 			buf_length, output, (float32_t *)ref,
209 			SNR_ERROR_THRESH),
210 		ASSERT_MSG_SNR_LIMIT_EXCEED);
211 
212 	zassert_true(
213 		test_rel_error_f32(
214 			buf_length, output, (float32_t *)ref,
215 			REL_ERROR_THRESH),
216 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
217 
218 	/* Free output buffer */
219 	free(output);
220 }
221 
222 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 3, in_com1, in_com3, ref_mult_real, 3);
223 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 8, in_com1, in_com3, ref_mult_real, 8);
224 DEFINE_TEST_VARIANT4(arm_cmplx_mult_real_f32, 11, in_com1, in_com3, ref_mult_real, 11);
225 
test_complexmath_f32(void)226 void test_complexmath_f32(void)
227 {
228 	ztest_test_suite(complexmath_f32,
229 		ztest_unit_test(test_arm_cmplx_conj_f32_3),
230 		ztest_unit_test(test_arm_cmplx_conj_f32_8),
231 		ztest_unit_test(test_arm_cmplx_conj_f32_11),
232 		ztest_unit_test(test_arm_cmplx_dot_prod_f32_3),
233 		ztest_unit_test(test_arm_cmplx_dot_prod_f32_8),
234 		ztest_unit_test(test_arm_cmplx_dot_prod_f32_11),
235 		ztest_unit_test(test_arm_cmplx_mag_f32_3),
236 		ztest_unit_test(test_arm_cmplx_mag_f32_8),
237 		ztest_unit_test(test_arm_cmplx_mag_f32_11),
238 		ztest_unit_test(test_arm_cmplx_mag_squared_f32_3),
239 		ztest_unit_test(test_arm_cmplx_mag_squared_f32_8),
240 		ztest_unit_test(test_arm_cmplx_mag_squared_f32_11),
241 		ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_3),
242 		ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_8),
243 		ztest_unit_test(test_arm_cmplx_mult_cmplx_f32_11),
244 		ztest_unit_test(test_arm_cmplx_mult_real_f32_3),
245 		ztest_unit_test(test_arm_cmplx_mult_real_f32_8),
246 		ztest_unit_test(test_arm_cmplx_mult_real_f32_11)
247 		);
248 
249 	ztest_run_test_suite(complexmath_f32);
250 }
251