1 /*
2  * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"
13 
14 #include "q31.pat"
15 
16 #define SNR_ERROR_THRESH	((float32_t)100)
17 #define ABS_ERROR_THRESH_Q31	((q31_t)100)
18 #define ABS_ERROR_THRESH_Q63	((q63_t)(1 << 18))
19 
20 ZTEST_SUITE(complexmath_q31, NULL, NULL, NULL, NULL, NULL);
21 
test_arm_cmplx_conj_q31(const q31_t * input1,const q31_t * ref,size_t length)22 static void test_arm_cmplx_conj_q31(
23 	const q31_t *input1, const q31_t *ref, size_t length)
24 {
25 	size_t buf_length;
26 	q31_t *output;
27 
28 	/* Complex number buffer length is twice the data length */
29 	buf_length = 2 * length;
30 
31 	/* Allocate output buffer */
32 	output = malloc(buf_length * sizeof(q31_t));
33 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
34 
35 	/* Run test function */
36 	arm_cmplx_conj_q31(input1, output, length);
37 
38 	/* Validate output */
39 	zassert_true(
40 		test_snr_error_q31(buf_length, output, ref, SNR_ERROR_THRESH),
41 		ASSERT_MSG_SNR_LIMIT_EXCEED);
42 
43 	zassert_true(
44 		test_near_equal_q31(buf_length, output, ref,
45 			ABS_ERROR_THRESH_Q31),
46 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
47 
48 	/* Free output buffer */
49 	free(output);
50 }
51 
52 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_conj_q31, 3, in_com1, ref_conj, 3);
53 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_conj_q31, 8, in_com1, ref_conj, 8);
54 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_conj_q31, 11, in_com1, ref_conj, 11);
55 
test_arm_cmplx_dot_prod_q31(const q31_t * input1,const q31_t * input2,const q63_t * ref,size_t length)56 static void test_arm_cmplx_dot_prod_q31(
57 	const q31_t *input1, const q31_t *input2, const q63_t *ref,
58 	size_t length)
59 {
60 	q63_t *output;
61 
62 	/* Allocate output buffer */
63 	output = malloc(2 * sizeof(q63_t));
64 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
65 
66 	/* Run test function */
67 	arm_cmplx_dot_prod_q31(input1, input2, length, &output[0], &output[1]);
68 
69 	/* Validate output */
70 	zassert_true(
71 		test_snr_error_q63(2, output, ref, SNR_ERROR_THRESH),
72 		ASSERT_MSG_SNR_LIMIT_EXCEED);
73 
74 	zassert_true(
75 		test_near_equal_q63(2, output, ref, ABS_ERROR_THRESH_Q63),
76 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
77 
78 	/* Free output buffer */
79 	free(output);
80 }
81 
82 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_dot_prod_q31, 3, in_com1, in_com2, ref_dot_prod_3,
83 		     3);
84 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_dot_prod_q31, 8, in_com1, in_com2, ref_dot_prod_4n,
85 		     8);
86 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_dot_prod_q31, 11, in_com1, in_com2,
87 		     ref_dot_prod_4n1, 11);
88 
test_arm_cmplx_mag_q31(const q31_t * input1,const q31_t * ref,size_t length)89 static void test_arm_cmplx_mag_q31(
90 	const q31_t *input1, const q31_t *ref, size_t length)
91 {
92 	q31_t *output;
93 
94 	/* Allocate output buffer */
95 	output = malloc(length * sizeof(q31_t));
96 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
97 
98 	/* Run test function */
99 	arm_cmplx_mag_q31(input1, output, length);
100 
101 	/* Validate output */
102 	zassert_true(
103 		test_snr_error_q31(length, output, ref, SNR_ERROR_THRESH),
104 		ASSERT_MSG_SNR_LIMIT_EXCEED);
105 
106 	zassert_true(
107 		test_near_equal_q31(length, output, ref, ABS_ERROR_THRESH_Q31),
108 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
109 
110 	/* Free output buffer */
111 	free(output);
112 }
113 
114 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_q31, 3, in_com1, ref_mag, 3);
115 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_q31, 8, in_com1, ref_mag, 8);
116 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_q31, 11, in_com1, ref_mag, 11);
117 
test_arm_cmplx_mag_squared_q31(const q31_t * input1,const q31_t * ref,size_t length)118 static void test_arm_cmplx_mag_squared_q31(
119 	const q31_t *input1, const q31_t *ref, size_t length)
120 {
121 	q31_t *output;
122 
123 	/* Allocate output buffer */
124 	output = malloc(length * sizeof(q31_t));
125 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
126 
127 	/* Run test function */
128 	arm_cmplx_mag_squared_q31(input1, output, length);
129 
130 	/* Validate output */
131 	zassert_true(
132 		test_snr_error_q31(length, output, ref, SNR_ERROR_THRESH),
133 		ASSERT_MSG_SNR_LIMIT_EXCEED);
134 
135 	zassert_true(
136 		test_near_equal_q31(length, output, ref, ABS_ERROR_THRESH_Q31),
137 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
138 
139 	/* Free output buffer */
140 	free(output);
141 }
142 
143 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_squared_q31, 3, in_com1, ref_mag_squared, 3);
144 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_squared_q31, 8, in_com1, ref_mag_squared, 8);
145 DEFINE_TEST_VARIANT3(complexmath_q31, arm_cmplx_mag_squared_q31, 11, in_com1, ref_mag_squared, 11);
146 
test_arm_cmplx_mult_cmplx_q31(const q31_t * input1,const q31_t * input2,const q31_t * ref,size_t length)147 static void test_arm_cmplx_mult_cmplx_q31(
148 	const q31_t *input1, const q31_t *input2, const q31_t *ref,
149 	size_t length)
150 {
151 	size_t buf_length;
152 	q31_t *output;
153 
154 	/* Complex number buffer length is twice the data length */
155 	buf_length = 2 * length;
156 
157 	/* Allocate output buffer */
158 	output = malloc(buf_length * sizeof(q31_t));
159 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
160 
161 	/* Run test function */
162 	arm_cmplx_mult_cmplx_q31(input1, input2, output, length);
163 
164 	/* Validate output */
165 	zassert_true(
166 		test_snr_error_q31(buf_length, output, ref, SNR_ERROR_THRESH),
167 		ASSERT_MSG_SNR_LIMIT_EXCEED);
168 
169 	zassert_true(
170 		test_near_equal_q31(buf_length, output, ref,
171 			ABS_ERROR_THRESH_Q31),
172 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
173 
174 	/* Free output buffer */
175 	free(output);
176 }
177 
178 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_cmplx_q31, 3, in_com1, in_com2, ref_mult_cmplx,
179 		     3);
180 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_cmplx_q31, 8, in_com1, in_com2, ref_mult_cmplx,
181 		     8);
182 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_cmplx_q31, 11, in_com1, in_com2,
183 		     ref_mult_cmplx, 11);
184 
test_arm_cmplx_mult_real_q31(const q31_t * input1,const q31_t * input2,const q31_t * ref,size_t length)185 static void test_arm_cmplx_mult_real_q31(
186 	const q31_t *input1, const q31_t *input2, const q31_t *ref,
187 	size_t length)
188 {
189 	size_t buf_length;
190 	q31_t *output;
191 
192 	/* Complex number buffer length is twice the data length */
193 	buf_length = 2 * length;
194 
195 	/* Allocate output buffer */
196 	output = malloc(buf_length * sizeof(q31_t));
197 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
198 
199 	/* Run test function */
200 	arm_cmplx_mult_real_q31(input1, input2, output, length);
201 
202 	/* Validate output */
203 	zassert_true(
204 		test_snr_error_q31(buf_length, output, ref, SNR_ERROR_THRESH),
205 		ASSERT_MSG_SNR_LIMIT_EXCEED);
206 
207 	zassert_true(
208 		test_near_equal_q31(buf_length, output, ref,
209 			ABS_ERROR_THRESH_Q31),
210 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
211 
212 	/* Free output buffer */
213 	free(output);
214 }
215 
216 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_real_q31, 3, in_com1, in_com3, ref_mult_real,
217 		     3);
218 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_real_q31, 8, in_com1, in_com3, ref_mult_real,
219 		     8);
220 DEFINE_TEST_VARIANT4(complexmath_q31, arm_cmplx_mult_real_q31, 11, in_com1, in_com3, ref_mult_real,
221 		     11);
222