1 /*
2  * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"
13 
14 #include "f32.pat"
15 
/*
 * Pass/fail thresholds handed to test_snr_error_f32() and
 * test_rel_error_f32() by every test in this suite.
 *
 * Both are typed float32_t so that no implicit double -> float narrowing
 * happens at the call sites (the SNR value is presumably in dB — confirm
 * against test_common.h).
 */
#define SNR_ERROR_THRESH	((float32_t)120)
#define REL_ERROR_THRESH	((float32_t)7.0e-6)

/* Test suite registration; all optional suite hooks are left NULL. */
ZTEST_SUITE(complexmath_f32, NULL, NULL, NULL, NULL, NULL);
20 
test_arm_cmplx_conj_f32(const uint32_t * input1,const uint32_t * ref,size_t length)21 static void test_arm_cmplx_conj_f32(
22 	const uint32_t *input1, const uint32_t *ref, size_t length)
23 {
24 	size_t buf_length;
25 	float32_t *output;
26 
27 	/* Complex number buffer length is twice the data length */
28 	buf_length = 2 * length;
29 
30 	/* Allocate output buffer */
31 	output = malloc(buf_length * sizeof(float32_t));
32 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
33 
34 	/* Run test function */
35 	arm_cmplx_conj_f32((float32_t *)input1, output, length);
36 
37 	/* Validate output */
38 	zassert_true(
39 		test_snr_error_f32(buf_length, output, (float32_t *)ref,
40 			SNR_ERROR_THRESH),
41 		ASSERT_MSG_SNR_LIMIT_EXCEED);
42 
43 	zassert_true(
44 		test_rel_error_f32(buf_length, output, (float32_t *)ref,
45 			REL_ERROR_THRESH),
46 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
47 
48 	/* Free output buffer */
49 	free(output);
50 }
51 
52 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_conj_f32, 3, in_com1, ref_conj, 3);
53 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_conj_f32, 8, in_com1, ref_conj, 8);
54 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_conj_f32, 11, in_com1, ref_conj, 11);
55 
test_arm_cmplx_dot_prod_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)56 static void test_arm_cmplx_dot_prod_f32(
57 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
58 	size_t length)
59 {
60 	float32_t *output;
61 
62 	/* Allocate output buffer */
63 	output = malloc(2 * sizeof(float32_t));
64 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
65 
66 	/* Run test function */
67 	arm_cmplx_dot_prod_f32(
68 		(float32_t *)input1, (float32_t *)input2, length,
69 		&output[0], &output[1]);
70 
71 	/* Validate output */
72 	zassert_true(
73 		test_snr_error_f32(2, output, (float32_t *)ref,
74 			SNR_ERROR_THRESH),
75 		ASSERT_MSG_SNR_LIMIT_EXCEED);
76 
77 	zassert_true(
78 		test_rel_error_f32(2, output, (float32_t *)ref,
79 			REL_ERROR_THRESH),
80 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
81 
82 	/* Free output buffer */
83 	free(output);
84 }
85 
86 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_dot_prod_f32, 3, in_com1, in_com2, ref_dot_prod_3,
87 		     3);
88 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_dot_prod_f32, 8, in_com1, in_com2, ref_dot_prod_4n,
89 		     8);
90 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_dot_prod_f32, 11, in_com1, in_com2,
91 		     ref_dot_prod_4n1, 11);
92 
test_arm_cmplx_mag_f32(const uint32_t * input1,const uint32_t * ref,size_t length)93 static void test_arm_cmplx_mag_f32(
94 	const uint32_t *input1, const uint32_t *ref, size_t length)
95 {
96 	float32_t *output;
97 
98 	/* Allocate output buffer */
99 	output = malloc(length * sizeof(float32_t));
100 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
101 
102 	/* Run test function */
103 	arm_cmplx_mag_f32((float32_t *)input1, output, length);
104 
105 	/* Validate output */
106 	zassert_true(
107 		test_snr_error_f32(length, output, (float32_t *)ref,
108 			SNR_ERROR_THRESH),
109 		ASSERT_MSG_SNR_LIMIT_EXCEED);
110 
111 	zassert_true(
112 		test_rel_error_f32(length, output, (float32_t *)ref,
113 			REL_ERROR_THRESH),
114 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
115 
116 	/* Free output buffer */
117 	free(output);
118 }
119 
120 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_f32, 3, in_com1, ref_mag, 3);
121 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_f32, 8, in_com1, ref_mag, 8);
122 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_f32, 11, in_com1, ref_mag, 11);
123 
test_arm_cmplx_mag_squared_f32(const uint32_t * input1,const uint32_t * ref,size_t length)124 static void test_arm_cmplx_mag_squared_f32(
125 	const uint32_t *input1, const uint32_t *ref, size_t length)
126 {
127 	float32_t *output;
128 
129 	/* Allocate output buffer */
130 	output = malloc(length * sizeof(float32_t));
131 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
132 
133 	/* Run test function */
134 	arm_cmplx_mag_squared_f32((float32_t *)input1, output, length);
135 
136 	/* Validate output */
137 	zassert_true(
138 		test_snr_error_f32(length, output, (float32_t *)ref,
139 			SNR_ERROR_THRESH),
140 		ASSERT_MSG_SNR_LIMIT_EXCEED);
141 
142 	zassert_true(
143 		test_rel_error_f32(length, output, (float32_t *)ref,
144 			REL_ERROR_THRESH),
145 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
146 
147 	/* Free output buffer */
148 	free(output);
149 }
150 
151 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_squared_f32, 3, in_com1, ref_mag_squared, 3);
152 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_squared_f32, 8, in_com1, ref_mag_squared, 8);
153 DEFINE_TEST_VARIANT3(complexmath_f32, arm_cmplx_mag_squared_f32, 11, in_com1, ref_mag_squared, 11);
154 
test_arm_cmplx_mult_cmplx_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)155 static void test_arm_cmplx_mult_cmplx_f32(
156 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
157 	size_t length)
158 {
159 	size_t buf_length;
160 	float32_t *output;
161 
162 	/* Complex number buffer length is twice the data length */
163 	buf_length = 2 * length;
164 
165 	/* Allocate output buffer */
166 	output = malloc(buf_length * sizeof(float32_t));
167 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
168 
169 	/* Run test function */
170 	arm_cmplx_mult_cmplx_f32(
171 		(float32_t *)input1, (float32_t *)input2, output, length);
172 
173 	/* Validate output */
174 	zassert_true(
175 		test_snr_error_f32(buf_length, output, (float32_t *)ref,
176 			SNR_ERROR_THRESH),
177 		ASSERT_MSG_SNR_LIMIT_EXCEED);
178 
179 	zassert_true(
180 		test_rel_error_f32(buf_length, output, (float32_t *)ref,
181 			REL_ERROR_THRESH),
182 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
183 
184 	/* Free output buffer */
185 	free(output);
186 }
187 
188 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_cmplx_f32, 3, in_com1, in_com2, ref_mult_cmplx,
189 		     3);
190 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_cmplx_f32, 8, in_com1, in_com2, ref_mult_cmplx,
191 		     8);
192 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_cmplx_f32, 11, in_com1, in_com2,
193 		     ref_mult_cmplx, 11);
194 
test_arm_cmplx_mult_real_f32(const uint32_t * input1,const uint32_t * input2,const uint32_t * ref,size_t length)195 static void test_arm_cmplx_mult_real_f32(
196 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref,
197 	size_t length)
198 {
199 	size_t buf_length;
200 	float32_t *output;
201 
202 	/* Complex number buffer length is twice the data length */
203 	buf_length = 2 * length;
204 
205 	/* Allocate output buffer */
206 	output = malloc(buf_length * sizeof(float32_t));
207 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
208 
209 	/* Run test function */
210 	arm_cmplx_mult_real_f32(
211 		(float32_t *)input1, (float32_t *)input2, output, length);
212 
213 	/* Validate output */
214 	zassert_true(
215 		test_snr_error_f32(
216 			buf_length, output, (float32_t *)ref,
217 			SNR_ERROR_THRESH),
218 		ASSERT_MSG_SNR_LIMIT_EXCEED);
219 
220 	zassert_true(
221 		test_rel_error_f32(
222 			buf_length, output, (float32_t *)ref,
223 			REL_ERROR_THRESH),
224 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
225 
226 	/* Free output buffer */
227 	free(output);
228 }
229 
230 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_real_f32, 3, in_com1, in_com3, ref_mult_real,
231 		     3);
232 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_real_f32, 8, in_com1, in_com3, ref_mult_real,
233 		     8);
234 DEFINE_TEST_VARIANT4(complexmath_f32, arm_cmplx_mult_real_f32, 11, in_com1, in_com3, ref_mult_real,
235 		     11);
236