1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13 
14 #include "binary_f16.pat"
15 
/* Pass/fail thresholds used when comparing f16 results with the reference */
#define SNR_ERROR_THRESH	((float32_t)60)
#define REL_ERROR_THRESH	(2.0e-3)
#define ABS_ERROR_THRESH	(2.0e-3)

/* in_dims stores one (rows, internal, columns) triplet per test matrix */
#define NUM_MATRICES		(ARRAY_SIZE(in_dims) / 3)
/* Largest row/column count the scratch input buffers are sized for */
#define MAX_MATRIX_DIM		(40)

/*
 * Operation selectors. They intentionally share the value 0 because each
 * one is only interpreted by its own test function (test_op2 / test_op2c).
 */
#define OP2_MULT		(0)
#define OP2C_CMPLX_MULT		(0)
25 
test_op2(int op,const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)26 static void test_op2(int op, const uint16_t *input1, const uint16_t *input2,
27 	const uint16_t *ref, size_t length)
28 {
29 	size_t index;
30 	uint16_t *dims = (uint16_t *)in_dims;
31 	float16_t *tmp1, *tmp2, *output;
32 	uint16_t rows, internal, columns;
33 	arm_status status;
34 
35 	arm_matrix_instance_f16 mat_in1;
36 	arm_matrix_instance_f16 mat_in2;
37 	arm_matrix_instance_f16 mat_out;
38 
39 	/* Allocate buffers */
40 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float16_t));
41 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
42 
43 	tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float16_t));
44 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
45 
46 	output = malloc(length * sizeof(float16_t));
47 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
48 
49 	/* Initialise contexts */
50 	mat_in1.pData = tmp1;
51 	mat_in2.pData = tmp2;
52 	mat_out.pData = output;
53 
54 	/* Iterate matrices */
55 	for (index = 0; index < NUM_MATRICES; index++) {
56 		rows = *dims++;
57 		internal = *dims++;
58 		columns = *dims++;
59 
60 		/* Initialise matrix dimensions */
61 		mat_in1.numRows = rows;
62 		mat_in1.numCols = internal;
63 
64 		mat_in2.numRows = internal;
65 		mat_in2.numCols = columns;
66 
67 		mat_out.numRows = rows;
68 		mat_out.numCols = columns;
69 
70 		/* Load matrix data */
71 		memcpy(mat_in1.pData, input1,
72 		       rows * internal * sizeof(float16_t));
73 
74 		memcpy(mat_in2.pData, input2,
75 		       internal * columns * sizeof(float16_t));
76 
77 		/* Run test function */
78 		switch (op) {
79 		case OP2_MULT:
80 			status = arm_mat_mult_f16(&mat_in1, &mat_in2,
81 						  &mat_out);
82 			break;
83 		default:
84 			zassert_unreachable("invalid operation");
85 		}
86 
87 		/* Validate status */
88 		zassert_equal(status, ARM_MATH_SUCCESS,
89 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
90 
91 		/* Increment output pointer */
92 		mat_out.pData += (rows * columns);
93 	}
94 
95 	/* Validate output */
96 	zassert_true(
97 		test_snr_error_f16(length, output, (float16_t *)ref,
98 			SNR_ERROR_THRESH),
99 		ASSERT_MSG_SNR_LIMIT_EXCEED);
100 
101 	zassert_true(
102 		test_close_error_f16(length, output, (float16_t *)ref,
103 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
104 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
105 
106 	/* Free buffers */
107 	free(tmp1);
108 	free(tmp2);
109 	free(output);
110 }
111 
112 DEFINE_TEST_VARIANT5(matrix_binary_f16,
113 	op2, arm_mat_mult_f16, OP2_MULT,
114 	in_mult1, in_mult2, ref_mult,
115 	ARRAY_SIZE(ref_mult));
116 
test_op2c(int op,const uint16_t * input1,const uint16_t * input2,const uint16_t * ref,size_t length)117 static void test_op2c(int op, const uint16_t *input1, const uint16_t *input2,
118 	const uint16_t *ref, size_t length)
119 {
120 	size_t index;
121 	uint16_t *dims = (uint16_t *)in_dims;
122 	float16_t *tmp1, *tmp2, *output;
123 	uint16_t rows, internal, columns;
124 	arm_status status;
125 
126 	arm_matrix_instance_f16 mat_in1;
127 	arm_matrix_instance_f16 mat_in2;
128 	arm_matrix_instance_f16 mat_out;
129 
130 	/* Allocate buffers */
131 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float16_t));
132 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
133 
134 	tmp2 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(float16_t));
135 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
136 
137 	output = malloc(2 * length * sizeof(float16_t));
138 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
139 
140 	/* Initialise contexts */
141 	mat_in1.pData = tmp1;
142 	mat_in2.pData = tmp2;
143 	mat_out.pData = output;
144 
145 	/* Iterate matrices */
146 	for (index = 0; index < NUM_MATRICES; index++) {
147 		rows = *dims++;
148 		internal = *dims++;
149 		columns = *dims++;
150 
151 		/* Initialise matrix dimensions */
152 		mat_in1.numRows = rows;
153 		mat_in1.numCols = internal;
154 
155 		mat_in2.numRows = internal;
156 		mat_in2.numCols = columns;
157 
158 		mat_out.numRows = rows;
159 		mat_out.numCols = columns;
160 
161 		/* Load matrix data */
162 		memcpy(mat_in1.pData, input1,
163 		       2 * rows * internal * sizeof(float16_t));
164 
165 		memcpy(mat_in2.pData, input2,
166 		       2 * internal * columns * sizeof(float16_t));
167 
168 		/* Run test function */
169 		switch (op) {
170 		case OP2C_CMPLX_MULT:
171 			status = arm_mat_cmplx_mult_f16(&mat_in1, &mat_in2,
172 							&mat_out);
173 			break;
174 		default:
175 			zassert_unreachable("invalid operation");
176 		}
177 
178 		/* Validate status */
179 		zassert_equal(status, ARM_MATH_SUCCESS,
180 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
181 
182 		/* Increment output pointer */
183 		mat_out.pData += (2 * rows * columns);
184 	}
185 
186 	/* Validate output */
187 	zassert_true(
188 		test_snr_error_f16(2 * length, output, (float16_t *)ref,
189 			SNR_ERROR_THRESH),
190 		ASSERT_MSG_SNR_LIMIT_EXCEED);
191 
192 	zassert_true(
193 		test_close_error_f16(length, output, (float16_t *)ref,
194 			ABS_ERROR_THRESH, REL_ERROR_THRESH),
195 		ASSERT_MSG_ERROR_LIMIT_EXCEED);
196 
197 	/* Free buffers */
198 	free(tmp1);
199 	free(tmp2);
200 	free(output);
201 }
202 
203 DEFINE_TEST_VARIANT5(matrix_binary_f16,
204 	op2c, arm_mat_cmplx_mult_f16, OP2C_CMPLX_MULT,
205 	in_cmplx_mult1, in_cmplx_mult2, ref_cmplx_mult,
206 	ARRAY_SIZE(ref_cmplx_mult) / 2);
207 
208 ZTEST_SUITE(matrix_binary_f16, NULL, NULL, NULL, NULL, NULL);
209