1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <arm_math.h>
#include "../../common/test_common.h"

#include "binary_q15.pat"
15
/* SNR thresholds passed to test_snr_error_q15() when validating outputs */
#define SNR_ERROR_THRESH ((float32_t)70)
#define SNR_LOW_ERROR_THRESH ((float32_t)30)
/* Absolute-error thresholds passed to test_near_equal_q15() */
#define ABS_ERROR_THRESH_Q15 ((q15_t)1000)
#define ABS_HIGH_ERROR_THRESH_Q15 ((q15_t)2000)
/* NOTE(review): not referenced anywhere in this file */
#define ABS_ERROR_THRESH_Q63 ((q63_t)(1 << 16))

/* in_dims stores one (rows, internal, columns) triple per test matrix */
#define NUM_MATRICES (ARRAY_SIZE(in_dims) / 3)
/* Largest row/column count the scratch and operand buffers are sized for */
#define MAX_MATRIX_DIM (40)

/* Operation selectors; OP2_MULT is used by test_op2(), OP2C_CMPLX_MULT by
 * test_op2(c). They share a value because each selects within its own
 * function only.
 */
#define OP2_MULT (0)
#define OP2C_CMPLX_MULT (0)
27
test_op2(int op,const q15_t * input1,const q15_t * input2,const q15_t * ref,size_t length)28 static void test_op2(int op, const q15_t *input1, const q15_t *input2,
29 const q15_t *ref, size_t length)
30 {
31 size_t index;
32 uint16_t *dims = (uint16_t *)in_dims;
33 q15_t *tmp1, *tmp2, *scratch, *output;
34 uint16_t rows, internal, columns;
35 arm_status status;
36
37 arm_matrix_instance_q15 mat_in1;
38 arm_matrix_instance_q15 mat_in2;
39 arm_matrix_instance_q15 mat_out;
40
41 /* Allocate buffers */
42 output = malloc(length * sizeof(q15_t));
43 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
44
45 tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
46 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
47
48 tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
49 zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
50
51 scratch = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
52 zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
53
54 /* Initialise contexts */
55 mat_in1.pData = tmp1;
56 mat_in2.pData = tmp2;
57 mat_out.pData = output;
58
59 /* Iterate matrices */
60 for (index = 0; index < NUM_MATRICES; index++) {
61 rows = *dims++;
62 internal = *dims++;
63 columns = *dims++;
64
65 /* Initialise matrix dimensions */
66 mat_in1.numRows = rows;
67 mat_in1.numCols = internal;
68
69 mat_in2.numRows = internal;
70 mat_in2.numCols = columns;
71
72 mat_out.numRows = rows;
73 mat_out.numCols = columns;
74
75 /* Load matrix data */
76 memcpy(mat_in1.pData, input1,
77 rows * internal * sizeof(q15_t));
78
79 memcpy(mat_in2.pData, input2,
80 internal * columns * sizeof(q15_t));
81
82 /* Run test function */
83 switch (op) {
84 case OP2_MULT:
85 status = arm_mat_mult_q15(
86 &mat_in1, &mat_in2, &mat_out,
87 scratch);
88 break;
89 default:
90 zassert_unreachable("invalid operation");
91 }
92
93 /* Validate status */
94 zassert_equal(status, ARM_MATH_SUCCESS,
95 ASSERT_MSG_INCORRECT_COMP_RESULT);
96
97 /* Increment output pointer */
98 mat_out.pData += (rows * columns);
99 }
100
101 /* Validate output */
102 zassert_true(
103 test_snr_error_q15(length, output, ref, SNR_LOW_ERROR_THRESH),
104 ASSERT_MSG_SNR_LIMIT_EXCEED);
105
106 zassert_true(
107 test_near_equal_q15(length, output, ref,
108 ABS_HIGH_ERROR_THRESH_Q15),
109 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
110
111 /* Free buffers */
112 free(tmp1);
113 free(tmp2);
114 free(scratch);
115 free(output);
116 }
117
/* Register test_op2 for arm_mat_mult_q15 in the matrix_binary_q15 suite.
 * NOTE(review): DEFINE_TEST_VARIANT5 is defined in test_common.h. Presumably
 * it expands to a ztest case calling
 * test_op2(OP2_MULT, in_mult1, in_mult2, ref_mult, ARRAY_SIZE(ref_mult));
 * confirm against that header.
 */
DEFINE_TEST_VARIANT5(matrix_binary_q15,
	op2, arm_mat_mult_q15, OP2_MULT,
	in_mult1, in_mult2, ref_mult,
	ARRAY_SIZE(ref_mult));
122
test_op2c(int op,const q15_t * input1,const q15_t * input2,const q15_t * ref,size_t length)123 static void test_op2c(int op, const q15_t *input1, const q15_t *input2,
124 const q15_t *ref, size_t length)
125 {
126 size_t index;
127 uint16_t *dims = (uint16_t *)in_dims;
128 q15_t *tmp1, *tmp2, *scratch, *output;
129 uint16_t rows, internal, columns;
130 arm_status status;
131
132 arm_matrix_instance_q15 mat_in1;
133 arm_matrix_instance_q15 mat_in2;
134 arm_matrix_instance_q15 mat_out;
135
136 /* Allocate buffers */
137 output = malloc(2 * length * sizeof(q15_t));
138 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
139
140 tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
141 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
142
143 tmp2 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
144 zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
145
146 scratch = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
147 zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
148
149 /* Initialise contexts */
150 mat_in1.pData = tmp1;
151 mat_in2.pData = tmp2;
152 mat_out.pData = output;
153
154 /* Iterate matrices */
155 for (index = 0; index < NUM_MATRICES; index++) {
156 rows = *dims++;
157 internal = *dims++;
158 columns = *dims++;
159
160 /* Initialise matrix dimensions */
161 mat_in1.numRows = rows;
162 mat_in1.numCols = internal;
163
164 mat_in2.numRows = internal;
165 mat_in2.numCols = columns;
166
167 mat_out.numRows = rows;
168 mat_out.numCols = columns;
169
170 /* Load matrix data */
171 memcpy(mat_in1.pData, input1,
172 2 * rows * internal * sizeof(q15_t));
173
174 memcpy(mat_in2.pData, input2,
175 2 * internal * columns * sizeof(q15_t));
176
177 /* Run test function */
178 switch (op) {
179 case OP2C_CMPLX_MULT:
180 status = arm_mat_cmplx_mult_q15(
181 &mat_in1, &mat_in2, &mat_out,
182 scratch);
183 break;
184 default:
185 zassert_unreachable("invalid operation");
186 }
187
188 /* Validate status */
189 zassert_equal(status, ARM_MATH_SUCCESS,
190 ASSERT_MSG_INCORRECT_COMP_RESULT);
191
192 /* Increment output pointer */
193 mat_out.pData += (2 * rows * columns);
194 }
195
196 /* Validate output */
197 zassert_true(
198 test_snr_error_q15(2 * length, output, ref, SNR_ERROR_THRESH),
199 ASSERT_MSG_SNR_LIMIT_EXCEED);
200
201 zassert_true(
202 test_near_equal_q15(2 * length, output, ref,
203 ABS_ERROR_THRESH_Q15),
204 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
205
206 /* Free buffers */
207 free(tmp1);
208 free(tmp2);
209 free(scratch);
210 free(output);
211 }
212
/* Register test_op2c for arm_mat_cmplx_mult_q15 in the matrix_binary_q15
 * suite. The length argument counts complex elements, hence the division by
 * two of the q15 element count of ref_cmplx_mult.
 * NOTE(review): DEFINE_TEST_VARIANT5 is defined in test_common.h; confirm
 * its expansion there.
 */
DEFINE_TEST_VARIANT5(matrix_binary_q15,
	op2c, arm_mat_cmplx_mult_q15, OP2C_CMPLX_MULT,
	in_cmplx_mult1, in_cmplx_mult2, ref_cmplx_mult,
	ARRAY_SIZE(ref_cmplx_mult) / 2);
217
/* Declare the matrix_binary_q15 suite with no predicate and no setup,
 * before, after, or teardown callbacks.
 */
ZTEST_SUITE(matrix_binary_q15, NULL, NULL, NULL, NULL, NULL);
219