1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "binary_f64.pat"
15 
/* Threshold handed to test_snr_error_f64() when validating results */
#define SNR_ERROR_THRESH	((float64_t)120)

/* Relative and absolute tolerances handed to test_close_error_f64() */
#define REL_ERROR_THRESH	(1.0e-6)
#define ABS_ERROR_THRESH	(1.0e-5)

/* in_dims holds three entries (rows, internal, columns) per test matrix */
#define NUM_MATRICES		(ARRAY_SIZE(in_dims) / 3)

/* Side length used to size the square scratch buffers for the operands */
#define MAX_MATRIX_DIM		(40)

/* Operation selectors understood by test_op2() and test_op2c() */
#define OP2_MULT		(0)
#define OP2C_CMPLX_MULT		(0)
25 
test_op2(int op,const uint64_t * input1,const uint64_t * input2,const uint64_t * ref,size_t length)26 static void test_op2(int op, const uint64_t *input1, const uint64_t *input2,
27 	const uint64_t *ref, size_t length)
28 {
29 	size_t index;
30 	uint16_t *dims = (uint16_t *)in_dims;
31 	float64_t *tmp1, *tmp2, *output;
32 	uint16_t rows, internal, columns;
33 	arm_status status;
34 
35 	arm_matrix_instance_f64 mat_in1;
36 	arm_matrix_instance_f64 mat_in2;
37 	arm_matrix_instance_f64 mat_out;
38 
39 	/* Allocate buffers */
40 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float64_t));
41 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
42 
43 	tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float64_t));
44 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
45 
46 	output = malloc(length * sizeof(float64_t));
47 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
48 
49 	/* Initialise contexts */
50 	mat_in1.pData = tmp1;
51 	mat_in2.pData = tmp2;
52 	mat_out.pData = output;
53 
54 	/* Iterate matrices */
55 	for (index = 0; index < NUM_MATRICES; index++) {
56 		rows = *dims++;
57 		internal = *dims++;
58 		columns = *dims++;
59 
60 		/* Initialise matrix dimensions */
61 		mat_in1.numRows = rows;
62 		mat_in1.numCols = internal;
63 
64 		mat_in2.numRows = internal;
65 		mat_in2.numCols = columns;
66 
67 		mat_out.numRows = rows;
68 		mat_out.numCols = columns;
69 
70 		/* Load matrix data */
71 		memcpy(mat_in1.pData, input1,
72 		       rows * internal * sizeof(float64_t));
73 
74 		memcpy(mat_in2.pData, input2,
75 		       internal * columns * sizeof(float64_t));
76 
77 		/* Run test function */
78 		switch (op) {
79 		case OP2_MULT:
80 			status = arm_mat_mult_f64(&mat_in1, &mat_in2,
81 						  &mat_out);
82 			break;
83 		default:
84 			zassert_unreachable("invalid operation");
85 		}
86 
87 		/* Validate status */
88 		zassert_equal(status, ARM_MATH_SUCCESS,
89 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
90 
91 		/* Increment output pointer */
92 		mat_out.pData += (rows * columns);
93 	}
94 
95 	/* Validate output */
96 	zassert_true(
97 		test_snr_error_f64(length, output, (float64_t *)ref,
98 			SNR_ERROR_THRESH),
99 		ASSERT_MSG_SNR_LIMIT_EXCEED);
100 
101 	zassert_true(
102 		test_close_error_f64(length, output, (float64_t *)ref,
103 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
104 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
105 
106 	/* Free buffers */
107 	free(tmp1);
108 	free(tmp2);
109 	free(output);
110 }
111 
112 DEFINE_TEST_VARIANT5(matrix_binary_f64,
113 	op2, arm_mat_mult_f64, OP2_MULT,
114 	in_mult1, in_mult2, ref_mult,
115 	ARRAY_SIZE(ref_mult));
116 
117 #if 0
118 /*
119  * NOTE: arm_mat_cmplx_mult_f64 is not implemented for now. This test must be
120  * enabled once this function is implemented.
121  */
122 static void test_op2c(int op, const uint64_t *input1, const uint64_t *input2,
123 	const uint64_t *ref, size_t length)
124 {
125 	size_t index;
126 	uint16_t *dims = (uint16_t *)in_dims;
127 	float64_t *tmp1, *tmp2, *output;
128 	uint16_t rows, internal, columns;
129 	arm_status status;
130 
131 	arm_matrix_instance_f64 mat_in1;
132 	arm_matrix_instance_f64 mat_in2;
133 	arm_matrix_instance_f64 mat_out;
134 
135 	/* Allocate buffers */
136 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float64_t));
137 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
138 
139 	tmp2 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float64_t));
140 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
141 
142 	output = malloc(2 * length * sizeof(float64_t));
143 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
144 
145 	/* Initialise contexts */
146 	mat_in1.pData = tmp1;
147 	mat_in2.pData = tmp2;
148 	mat_out.pData = output;
149 
150 	/* Iterate matrices */
151 	for (index = 0; index < NUM_MATRICES; index++) {
152 		rows = *dims++;
153 		internal = *dims++;
154 		columns = *dims++;
155 
156 		/* Initialise matrix dimensions */
157 		mat_in1.numRows = rows;
158 		mat_in1.numCols = internal;
159 
160 		mat_in2.numRows = internal;
161 		mat_in2.numCols = columns;
162 
163 		mat_out.numRows = rows;
164 		mat_out.numCols = columns;
165 
166 		/* Load matrix data */
167 		memcpy(mat_in1.pData, input1,
168 		       2 * rows * internal * sizeof(float64_t));
169 
170 		memcpy(mat_in2.pData, input2,
171 		       2 * internal * columns * sizeof(float64_t));
172 
173 		/* Run test function */
174 		switch (op) {
175 		case OP2C_CMPLX_MULT:
176 			status = arm_mat_cmplx_mult_f64(&mat_in1, &mat_in2,
177 							&mat_out);
178 			break;
179 		default:
180 			zassert_unreachable("invalid operation");
181 		}
182 
183 		/* Validate status */
184 		zassert_equal(status, ARM_MATH_SUCCESS,
185 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
186 
187 		/* Increment output pointer */
188 		mat_out.pData += (2 * rows * columns);
189 	}
190 
191 	/* Validate output */
192 	zassert_true(
193 		test_snr_error_f64(2 * length, output, (float64_t *)ref,
194 			SNR_ERROR_THRESH),
195 		ASSERT_MSG_SNR_LIMIT_EXCEED);
196 
197 	zassert_true(
198 		test_close_error_f64(length, output, (float64_t *)ref,
199 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
200 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
201 
202 	/* Free buffers */
203 	free(tmp1);
204 	free(tmp2);
205 	free(output);
206 }
207 
208 DEFINE_TEST_VARIANT5(matrix_binary_f64,
209 	op2c, arm_mat_cmplx_mult_f64, OP2C_CMPLX_MULT,
210 	in_cmplx_mult1, in_cmplx_mult2, ref_cmplx_mult,
211 	ARRAY_SIZE(ref_cmplx_mult) / 2);
212 #endif
213 
/* Declare the matrix_binary_f64 suite; all five optional arguments are NULL */
ZTEST_SUITE(matrix_binary_f64,
	    NULL, NULL, NULL, NULL, NULL);
215