1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "binary_q15.pat"
15 
/* SNR limits handed to test_snr_error_q15(); the LOW variant is the looser one */
#define SNR_ERROR_THRESH ((float32_t)70)
#define SNR_LOW_ERROR_THRESH ((float32_t)30)

/* Per-sample limits handed to test_near_equal_q15(); the HIGH variant is the looser one */
#define ABS_ERROR_THRESH_Q15 ((q15_t)1000)
#define ABS_HIGH_ERROR_THRESH_Q15 ((q15_t)2000)

/* Q63 absolute error limit (not referenced by the tests visible in this file) */
#define ABS_ERROR_THRESH_Q63 ((q63_t)(1 << 16))

/* in_dims holds three entries per test matrix: rows, internal and columns */
#define NUM_MATRICES (ARRAY_SIZE(in_dims) / 3)

/* The temporary matrix buffers are sized from the square of this value */
#define MAX_MATRIX_DIM (40)

/* Operation selector understood by test_op2() */
#define OP2_MULT (0)

/* Operation selector understood by test_op2c(); a separate numbering from OP2_* */
#define OP2C_CMPLX_MULT (0)
27 
test_op2(int op,const q15_t * input1,const q15_t * input2,const q15_t * ref,size_t length)28 static void test_op2(int op, const q15_t *input1, const q15_t *input2,
29 	const q15_t *ref, size_t length)
30 {
31 	size_t index;
32 	uint16_t *dims = (uint16_t *)in_dims;
33 	q15_t *tmp1, *tmp2, *scratch, *output;
34 	uint16_t rows, internal, columns;
35 	arm_status status;
36 
37 	arm_matrix_instance_q15 mat_in1;
38 	arm_matrix_instance_q15 mat_in2;
39 	arm_matrix_instance_q15 mat_out;
40 
41 	/* Allocate buffers */
42 	output = malloc(length * sizeof(q15_t));
43 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
44 
45 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
46 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
47 
48 	tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
49 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
50 
51 	scratch = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
52 	zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
53 
54 	/* Initialise contexts */
55 	mat_in1.pData = tmp1;
56 	mat_in2.pData = tmp2;
57 	mat_out.pData = output;
58 
59 	/* Iterate matrices */
60 	for (index = 0; index < NUM_MATRICES; index++) {
61 		rows = *dims++;
62 		internal = *dims++;
63 		columns = *dims++;
64 
65 		/* Initialise matrix dimensions */
66 		mat_in1.numRows = rows;
67 		mat_in1.numCols = internal;
68 
69 		mat_in2.numRows = internal;
70 		mat_in2.numCols = columns;
71 
72 		mat_out.numRows = rows;
73 		mat_out.numCols = columns;
74 
75 		/* Load matrix data */
76 		memcpy(mat_in1.pData, input1,
77 		       rows * internal * sizeof(q15_t));
78 
79 		memcpy(mat_in2.pData, input2,
80 		       internal * columns * sizeof(q15_t));
81 
82 		/* Run test function */
83 		switch (op) {
84 		case OP2_MULT:
85 			status = arm_mat_mult_q15(
86 					&mat_in1, &mat_in2, &mat_out,
87 					scratch);
88 			break;
89 		default:
90 			zassert_unreachable("invalid operation");
91 		}
92 
93 		/* Validate status */
94 		zassert_equal(status, ARM_MATH_SUCCESS,
95 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
96 
97 		/* Increment output pointer */
98 		mat_out.pData += (rows * columns);
99 	}
100 
101 	/* Validate output */
102 	zassert_true(
103 		test_snr_error_q15(length, output, ref, SNR_LOW_ERROR_THRESH),
104 		ASSERT_MSG_SNR_LIMIT_EXCEED);
105 
106 	zassert_true(
107 		test_near_equal_q15(length, output, ref,
108 			ABS_HIGH_ERROR_THRESH_Q15),
109 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
110 
111 	/* Free buffers */
112 	free(tmp1);
113 	free(tmp2);
114 	free(scratch);
115 	free(output);
116 }
117 
118 DEFINE_TEST_VARIANT5(matrix_binary_q15,
119 	op2, arm_mat_mult_q15, OP2_MULT,
120 	in_mult1, in_mult2, ref_mult,
121 	ARRAY_SIZE(ref_mult));
122 
test_op2c(int op,const q15_t * input1,const q15_t * input2,const q15_t * ref,size_t length)123 static void test_op2c(int op, const q15_t *input1, const q15_t *input2,
124 	const q15_t *ref, size_t length)
125 {
126 	size_t index;
127 	uint16_t *dims = (uint16_t *)in_dims;
128 	q15_t *tmp1, *tmp2, *scratch, *output;
129 	uint16_t rows, internal, columns;
130 	arm_status status;
131 
132 	arm_matrix_instance_q15 mat_in1;
133 	arm_matrix_instance_q15 mat_in2;
134 	arm_matrix_instance_q15 mat_out;
135 
136 	/* Allocate buffers */
137 	output = malloc(2 * length * sizeof(q15_t));
138 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
139 
140 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
141 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
142 
143 	tmp2 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
144 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
145 
146 	scratch = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q15_t));
147 	zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
148 
149 	/* Initialise contexts */
150 	mat_in1.pData = tmp1;
151 	mat_in2.pData = tmp2;
152 	mat_out.pData = output;
153 
154 	/* Iterate matrices */
155 	for (index = 0; index < NUM_MATRICES; index++) {
156 		rows = *dims++;
157 		internal = *dims++;
158 		columns = *dims++;
159 
160 		/* Initialise matrix dimensions */
161 		mat_in1.numRows = rows;
162 		mat_in1.numCols = internal;
163 
164 		mat_in2.numRows = internal;
165 		mat_in2.numCols = columns;
166 
167 		mat_out.numRows = rows;
168 		mat_out.numCols = columns;
169 
170 		/* Load matrix data */
171 		memcpy(mat_in1.pData, input1,
172 		       2 * rows * internal * sizeof(q15_t));
173 
174 		memcpy(mat_in2.pData, input2,
175 		       2 * internal * columns * sizeof(q15_t));
176 
177 		/* Run test function */
178 		switch (op) {
179 		case OP2C_CMPLX_MULT:
180 			status = arm_mat_cmplx_mult_q15(
181 					&mat_in1, &mat_in2, &mat_out,
182 					scratch);
183 			break;
184 		default:
185 			zassert_unreachable("invalid operation");
186 		}
187 
188 		/* Validate status */
189 		zassert_equal(status, ARM_MATH_SUCCESS,
190 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
191 
192 		/* Increment output pointer */
193 		mat_out.pData += (2 * rows * columns);
194 	}
195 
196 	/* Validate output */
197 	zassert_true(
198 		test_snr_error_q15(2 * length, output, ref, SNR_ERROR_THRESH),
199 		ASSERT_MSG_SNR_LIMIT_EXCEED);
200 
201 	zassert_true(
202 		test_near_equal_q15(2 * length, output, ref,
203 			ABS_ERROR_THRESH_Q15),
204 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
205 
206 	/* Free buffers */
207 	free(tmp1);
208 	free(tmp2);
209 	free(scratch);
210 	free(output);
211 }
212 
213 DEFINE_TEST_VARIANT5(matrix_binary_q15,
214 	op2c, arm_mat_cmplx_mult_q15, OP2C_CMPLX_MULT,
215 	in_cmplx_mult1, in_cmplx_mult2, ref_cmplx_mult,
216 	ARRAY_SIZE(ref_cmplx_mult) / 2);
217 
/* Declare the matrix_binary_q15 suite; every argument after the name is NULL */
ZTEST_SUITE(matrix_binary_q15,
	    NULL, NULL, NULL, NULL, NULL);
219