1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13 
14 #include "f16.pat"
15 
/*
 * Pass/fail limits shared by every test in this file. Each output buffer is
 * checked against its reference with test_snr_error_f16() using
 * SNR_ERROR_THRESH and with test_rel_error_f16() using REL_ERROR_THRESH.
 * NOTE(review): units are defined by those helpers (not visible here);
 * presumably the SNR limit is in dB and the relative-error limit is a
 * fraction (0.06 == 6%) — confirm against test_common.h.
 */
#define SNR_ERROR_THRESH	((float32_t)39)
#define REL_ERROR_THRESH	(6.0e-2)

/* Register the complexmath_f16 suite; all optional callback arguments are NULL. */
ZTEST_SUITE(complexmath_f16, NULL, NULL, NULL, NULL, NULL);
20 
test_arm_cmplx_conj_f16(const uint16_t * input1,const uint16_t * ref,size_t length)21 static void test_arm_cmplx_conj_f16(
22 	const uint16_t *input1, const uint16_t *ref, size_t length)
23 {
24 	size_t buf_length;
25 	float16_t *output;
26 
27 	/* Complex number buffer length is twice the data length */
28 	buf_length = 2 * length;
29 
30 	/* Allocate output buffer */
31 	output = malloc(buf_length * sizeof(float16_t));
32 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
33 
34 	/* Run test function */
35 	arm_cmplx_conj_f16((float16_t *)input1, output, length);
36 
37 	/* Validate output */
38 	zassert_true(
39 		test_snr_error_f16(buf_length, output, (float16_t *)ref,
40 			SNR_ERROR_THRESH),
41 		ASSERT_MSG_SNR_LIMIT_EXCEED);
42 
43 	zassert_true(
44 		test_rel_error_f16(buf_length, output, (float16_t *)ref,
45 			REL_ERROR_THRESH),
46 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
47 
48 	/* Free output buffer */
49 	free(output);
50 }
51 
/*
 * arm_cmplx_conj_f16 cases for 7, 16 and 23 complex samples. All lengths
 * share in_com1/ref_conj: the check only examines the first 2 * length
 * elements of ref. NOTE(review): DEFINE_TEST_VARIANT3 comes from the common
 * test header (not visible here); presumably it emits a ztest case suffixed
 * with the third argument that calls test_arm_cmplx_conj_f16(in_com1,
 * ref_conj, <last argument>) — confirm.
 */
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 7, in_com1, ref_conj, 7);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 16, in_com1, ref_conj, 16);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_conj_f16, 23, in_com1, ref_conj, 23);
55 
test_arm_cmplx_dot_prod_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)56 static void test_arm_cmplx_dot_prod_f16(
57 	const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
58 	size_t length)
59 {
60 	float16_t *output;
61 
62 	/* Allocate output buffer */
63 	output = malloc(2 * sizeof(float16_t));
64 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
65 
66 	/* Run test function */
67 	arm_cmplx_dot_prod_f16(
68 		(float16_t *)input1, (float16_t *)input2, length,
69 		&output[0], &output[1]);
70 
71 	/* Validate output */
72 	zassert_true(
73 		test_snr_error_f16(2, output, (float16_t *)ref,
74 			SNR_ERROR_THRESH),
75 		ASSERT_MSG_SNR_LIMIT_EXCEED);
76 
77 	zassert_true(
78 		test_rel_error_f16(2, output, (float16_t *)ref,
79 			REL_ERROR_THRESH),
80 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
81 
82 	/* Free output buffer */
83 	free(output);
84 }
85 
/*
 * arm_cmplx_dot_prod_f16 cases for 7, 16 and 23 complex samples. Unlike the
 * element-wise kernels, the dot product reduces every sample into one complex
 * value, so each length needs its own reference (ref_dot_prod_3, _4n, _4n1).
 * NOTE(review): presumably DEFINE_TEST_VARIANT4 emits a ztest case that calls
 * test_arm_cmplx_dot_prod_f16(in_com1, in_com2, <ref>, <last argument>) —
 * confirm in the common test header.
 */
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 7, in_com1, in_com2, ref_dot_prod_3,
		     7);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 16, in_com1, in_com2, ref_dot_prod_4n,
		     16);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_dot_prod_f16, 23, in_com1, in_com2,
		     ref_dot_prod_4n1, 23);
92 
test_arm_cmplx_mag_f16(const uint16_t * input1,const uint16_t * ref,size_t length)93 static void test_arm_cmplx_mag_f16(
94 	const uint16_t *input1, const uint16_t *ref, size_t length)
95 {
96 	float16_t *output;
97 
98 	/* Allocate output buffer */
99 	output = malloc(length * sizeof(float16_t));
100 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
101 
102 	/* Run test function */
103 	arm_cmplx_mag_f16((float16_t *)input1, output, length);
104 
105 	/* Validate output */
106 	zassert_true(
107 		test_snr_error_f16(length, output, (float16_t *)ref,
108 			SNR_ERROR_THRESH),
109 		ASSERT_MSG_SNR_LIMIT_EXCEED);
110 
111 	zassert_true(
112 		test_rel_error_f16(length, output, (float16_t *)ref,
113 			REL_ERROR_THRESH),
114 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
115 
116 	/* Free output buffer */
117 	free(output);
118 }
119 
/*
 * arm_cmplx_mag_f16 cases for 7, 16 and 23 complex samples. All lengths
 * share in_com1/ref_mag: the check only examines the first `length`
 * elements of ref. NOTE(review): presumably each DEFINE_TEST_VARIANT3 line
 * emits a ztest case calling test_arm_cmplx_mag_f16 — confirm in the common
 * test header.
 */
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 7, in_com1, ref_mag, 7);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 16, in_com1, ref_mag, 16);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_f16, 23, in_com1, ref_mag, 23);
123 
test_arm_cmplx_mag_squared_f16(const uint16_t * input1,const uint16_t * ref,size_t length)124 static void test_arm_cmplx_mag_squared_f16(
125 	const uint16_t *input1, const uint16_t *ref, size_t length)
126 {
127 	float16_t *output;
128 
129 	/* Allocate output buffer */
130 	output = malloc(length * sizeof(float16_t));
131 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
132 
133 	/* Run test function */
134 	arm_cmplx_mag_squared_f16((float16_t *)input1, output, length);
135 
136 	/* Validate output */
137 	zassert_true(
138 		test_snr_error_f16(length, output, (float16_t *)ref,
139 			SNR_ERROR_THRESH),
140 		ASSERT_MSG_SNR_LIMIT_EXCEED);
141 
142 	zassert_true(
143 		test_rel_error_f16(length, output, (float16_t *)ref,
144 			REL_ERROR_THRESH),
145 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
146 
147 	/* Free output buffer */
148 	free(output);
149 }
150 
/*
 * arm_cmplx_mag_squared_f16 cases for 7, 16 and 23 complex samples. All
 * lengths share in_com1/ref_mag_squared: the check only examines the first
 * `length` elements of ref. NOTE(review): presumably each line emits a ztest
 * case calling test_arm_cmplx_mag_squared_f16 — confirm in the common test
 * header.
 */
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 7, in_com1, ref_mag_squared, 7);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 16, in_com1, ref_mag_squared, 16);
DEFINE_TEST_VARIANT3(complexmath_f16, arm_cmplx_mag_squared_f16, 23, in_com1, ref_mag_squared, 23);
154 
test_arm_cmplx_mult_cmplx_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)155 static void test_arm_cmplx_mult_cmplx_f16(
156 	const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
157 	size_t length)
158 {
159 	size_t buf_length;
160 	float16_t *output;
161 
162 	/* Complex number buffer length is twice the data length */
163 	buf_length = 2 * length;
164 
165 	/* Allocate output buffer */
166 	output = malloc(buf_length * sizeof(float16_t));
167 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
168 
169 	/* Run test function */
170 	arm_cmplx_mult_cmplx_f16(
171 		(float16_t *)input1, (float16_t *)input2, output, length);
172 
173 	/* Validate output */
174 	zassert_true(
175 		test_snr_error_f16(buf_length, output, (float16_t *)ref,
176 			SNR_ERROR_THRESH),
177 		ASSERT_MSG_SNR_LIMIT_EXCEED);
178 
179 	zassert_true(
180 		test_rel_error_f16(buf_length, output, (float16_t *)ref,
181 			REL_ERROR_THRESH),
182 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
183 
184 	/* Free output buffer */
185 	free(output);
186 }
187 
/*
 * arm_cmplx_mult_cmplx_f16 cases for 7, 16 and 23 complex samples. All
 * lengths share in_com1/in_com2/ref_mult_cmplx: the check only examines the
 * first 2 * length elements of ref. NOTE(review): presumably each line emits
 * a ztest case calling test_arm_cmplx_mult_cmplx_f16 — confirm in the common
 * test header.
 */
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 7, in_com1, in_com2, ref_mult_cmplx,
		     7);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 16, in_com1, in_com2,
		     ref_mult_cmplx, 16);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_cmplx_f16, 23, in_com1, in_com2,
		     ref_mult_cmplx, 23);
194 
test_arm_cmplx_mult_real_f16(const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)195 static void test_arm_cmplx_mult_real_f16(
196 	const uint16_t *input1, const uint16_t *input2, const uint16_t *ref,
197 	size_t length)
198 {
199 	size_t buf_length;
200 	float16_t *output;
201 
202 	/* Complex number buffer length is twice the data length */
203 	buf_length = 2 * length;
204 
205 	/* Allocate output buffer */
206 	output = malloc(buf_length * sizeof(float16_t));
207 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
208 
209 	/* Run test function */
210 	arm_cmplx_mult_real_f16(
211 		(float16_t *)input1, (float16_t *)input2, output, length);
212 
213 	/* Validate output */
214 	zassert_true(
215 		test_snr_error_f16(
216 			buf_length, output, (float16_t *)ref,
217 			SNR_ERROR_THRESH),
218 		ASSERT_MSG_SNR_LIMIT_EXCEED);
219 
220 	zassert_true(
221 		test_rel_error_f16(
222 			buf_length, output, (float16_t *)ref,
223 			REL_ERROR_THRESH),
224 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
225 
226 	/* Free output buffer */
227 	free(output);
228 }
229 
/*
 * arm_cmplx_mult_real_f16 cases for 7, 16 and 23 samples. in_com3 is the
 * second operand (one value per sample). All lengths share
 * in_com1/in_com3/ref_mult_real: the check only examines the first
 * 2 * length elements of ref. NOTE(review): presumably each line emits a
 * ztest case calling test_arm_cmplx_mult_real_f16 — confirm in the common
 * test header.
 */
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 7, in_com1, in_com3, ref_mult_real,
		     7);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 16, in_com1, in_com3, ref_mult_real,
		     16);
DEFINE_TEST_VARIANT4(complexmath_f16, arm_cmplx_mult_real_f16, 23, in_com1, in_com3, ref_mult_real,
		     23);
236