1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "binary_q31.pat"
15 
/* Minimum SNR required by test_snr_error_q31 (presumably dB — confirm in test_common.h) */
#define SNR_ERROR_THRESH	((float32_t)100)
/* Absolute error tolerance passed to test_near_equal_q31 (presumably raw q31_t units) */
#define ABS_ERROR_THRESH_Q31	((q31_t)5)
/* Q63 absolute error tolerance; not referenced by the tests in this file */
#define ABS_ERROR_THRESH_Q63	((q63_t)(1 << 16))

/* in_dims holds one (rows, internal, columns) triplet per matrix */
#define NUM_MATRICES		(ARRAY_SIZE(in_dims) / 3)
/* Largest allowed single matrix dimension; sizes the input scratch buffers */
#define MAX_MATRIX_DIM		(40)

/*
 * Operation selectors. Both are 0 because each is only interpreted by the
 * switch in its own test function (test_op2 / test_op2c).
 */
#define OP2_MULT		(0)
#define OP2C_CMPLX_MULT		(0)
25 
test_op2(int op,const q31_t * input1,const q31_t * input2,const q31_t * ref,size_t length)26 static void test_op2(int op, const q31_t *input1, const q31_t *input2,
27 	const q31_t *ref, size_t length)
28 {
29 	size_t index;
30 	uint16_t *dims = (uint16_t *)in_dims;
31 	q31_t *tmp1, *tmp2, *output;
32 	uint16_t rows, internal, columns;
33 	arm_status status;
34 
35 	arm_matrix_instance_q31 mat_in1;
36 	arm_matrix_instance_q31 mat_in2;
37 	arm_matrix_instance_q31 mat_out;
38 
39 	/* Allocate buffers */
40 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q31_t));
41 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
42 
43 	tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q31_t));
44 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
45 
46 	output = malloc(length * sizeof(q31_t));
47 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
48 
49 	/* Initialise contexts */
50 	mat_in1.pData = tmp1;
51 	mat_in2.pData = tmp2;
52 	mat_out.pData = output;
53 
54 	/* Iterate matrices */
55 	for (index = 0; index < NUM_MATRICES; index++) {
56 		rows = *dims++;
57 		internal = *dims++;
58 		columns = *dims++;
59 
60 		/* Initialise matrix dimensions */
61 		mat_in1.numRows = rows;
62 		mat_in1.numCols = internal;
63 
64 		mat_in2.numRows = internal;
65 		mat_in2.numCols = columns;
66 
67 		mat_out.numRows = rows;
68 		mat_out.numCols = columns;
69 
70 		/* Load matrix data */
71 		memcpy(mat_in1.pData, input1,
72 		       rows * internal * sizeof(q31_t));
73 
74 		memcpy(mat_in2.pData, input2,
75 		       internal * columns * sizeof(q31_t));
76 
77 		/* Run test function */
78 		switch (op) {
79 		case OP2_MULT:
80 			status = arm_mat_mult_q31(&mat_in1, &mat_in2,
81 						  &mat_out);
82 			break;
83 		default:
84 			zassert_unreachable("invalid operation");
85 		}
86 
87 		/* Validate status */
88 		zassert_equal(status, ARM_MATH_SUCCESS,
89 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
90 
91 		/* Increment output pointer */
92 		mat_out.pData += (rows * columns);
93 	}
94 
95 	/* Validate output */
96 	zassert_true(
97 		test_snr_error_q31(length, output, ref, SNR_ERROR_THRESH),
98 		ASSERT_MSG_SNR_LIMIT_EXCEED);
99 
100 	zassert_true(
101 		test_near_equal_q31(length, output, ref, ABS_ERROR_THRESH_Q31),
102 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
103 
104 	/* Free buffers */
105 	free(tmp1);
106 	free(tmp2);
107 	free(output);
108 }
109 
/*
 * Define a test variant in the matrix_binary_q31 suite via the
 * DEFINE_TEST_VARIANT5 helper (presumably from test_common.h; confirm its
 * expansion there). The trailing arguments match test_op2's parameters:
 * run OP2_MULT (arm_mat_mult_q31) on in_mult1/in_mult2 and compare against
 * ref_mult, whose element count is the total real output length.
 */
DEFINE_TEST_VARIANT5(matrix_binary_q31,
	op2, arm_mat_mult_q31, OP2_MULT,
	in_mult1, in_mult2, ref_mult,
	ARRAY_SIZE(ref_mult));
114 
test_op2c(int op,const q31_t * input1,const q31_t * input2,const q31_t * ref,size_t length)115 static void test_op2c(int op, const q31_t *input1, const q31_t *input2,
116 	const q31_t *ref, size_t length)
117 {
118 	size_t index;
119 	uint16_t *dims = (uint16_t *)in_dims;
120 	q31_t *tmp1, *tmp2, *output;
121 	uint16_t rows, internal, columns;
122 	arm_status status;
123 
124 	arm_matrix_instance_q31 mat_in1;
125 	arm_matrix_instance_q31 mat_in2;
126 	arm_matrix_instance_q31 mat_out;
127 
128 	/* Allocate buffers */
129 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q31_t));
130 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
131 
132 	tmp2 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q31_t));
133 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
134 
135 	output = malloc(2 * length * sizeof(q31_t));
136 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
137 
138 	/* Initialise contexts */
139 	mat_in1.pData = tmp1;
140 	mat_in2.pData = tmp2;
141 	mat_out.pData = output;
142 
143 	/* Iterate matrices */
144 	for (index = 0; index < NUM_MATRICES; index++) {
145 		rows = *dims++;
146 		internal = *dims++;
147 		columns = *dims++;
148 
149 		/* Initialise matrix dimensions */
150 		mat_in1.numRows = rows;
151 		mat_in1.numCols = internal;
152 
153 		mat_in2.numRows = internal;
154 		mat_in2.numCols = columns;
155 
156 		mat_out.numRows = rows;
157 		mat_out.numCols = columns;
158 
159 		/* Load matrix data */
160 		memcpy(mat_in1.pData, input1,
161 		       2 * rows * internal * sizeof(q31_t));
162 
163 		memcpy(mat_in2.pData, input2,
164 		       2 * internal * columns * sizeof(q31_t));
165 
166 		/* Run test function */
167 		switch (op) {
168 		case OP2C_CMPLX_MULT:
169 			status = arm_mat_cmplx_mult_q31(&mat_in1, &mat_in2,
170 							&mat_out);
171 			break;
172 		default:
173 			zassert_unreachable("invalid operation");
174 		}
175 
176 		/* Validate status */
177 		zassert_equal(status, ARM_MATH_SUCCESS,
178 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
179 
180 		/* Increment output pointer */
181 		mat_out.pData += (2 * rows * columns);
182 	}
183 
184 	/* Validate output */
185 	zassert_true(
186 		test_snr_error_q31(2 * length, output, ref, SNR_ERROR_THRESH),
187 		ASSERT_MSG_SNR_LIMIT_EXCEED);
188 
189 	zassert_true(
190 		test_near_equal_q31(2 * length, output, ref,
191 			ABS_ERROR_THRESH_Q31),
192 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
193 
194 	/* Free buffers */
195 	free(tmp1);
196 	free(tmp2);
197 	free(output);
198 }
199 
/*
 * Define a test variant in the matrix_binary_q31 suite via the
 * DEFINE_TEST_VARIANT5 helper (presumably from test_common.h; confirm its
 * expansion there). The trailing arguments match test_op2c's parameters:
 * run OP2C_CMPLX_MULT (arm_mat_cmplx_mult_q31) on in_cmplx_mult1 and
 * in_cmplx_mult2, and compare against ref_cmplx_mult. The length is halved
 * because test_op2c counts complex elements and doubles the length itself
 * when sizing and validating the q31_t output.
 */
DEFINE_TEST_VARIANT5(matrix_binary_q31,
	op2c, arm_mat_cmplx_mult_q31, OP2C_CMPLX_MULT,
	in_cmplx_mult1, in_cmplx_mult2, ref_cmplx_mult,
	ARRAY_SIZE(ref_cmplx_mult) / 2);
204 
/* Declare the suite; the predicate and all setup/before/after/teardown hooks are NULL. */
ZTEST_SUITE(matrix_binary_q31, NULL, NULL, NULL, NULL, NULL);
206