1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <ztest.h>
9 #include <zephyr.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "unary_q7.pat"
15 
/* SNR threshold passed to test_snr_error_q7 when validating test_op1 output */
#define SNR_ERROR_THRESH	((float32_t)20)
/* Lower SNR threshold passed to test_snr_error_q7 when validating test_op2v
 * output
 */
#define SNR_ERROR_THRESH_LOW	((float32_t)11)
/* Absolute-error threshold passed to test_near_equal_q7 by both tests */
#define ABS_ERROR_THRESH_Q7	((q7_t)2)

/* in_dims holds one (rows, columns) pair per test matrix, consumed two
 * entries at a time by the test loops
 */
#define NUM_MATRICES		(ARRAY_SIZE(in_dims) / 2)
/* Dimension bound used to size the scratch buffers; every test matrix is
 * assumed to fit within it (NOTE(review): confirm against unary_q7.pat)
 */
#define MAX_MATRIX_DIM		(47)

/* Operation selectors dispatched on by the switch in test_op1 / test_op2v */
#define OP1_TRANS		(1)
#define OP2V_VEC_MULT		(0)
25 
test_op1(int op,const q7_t * ref,size_t length,bool transpose)26 static void test_op1(int op, const q7_t *ref, size_t length, bool transpose)
27 {
28 	size_t index;
29 	uint16_t *dims = (uint16_t *)in_dims;
30 	q7_t *tmp1, *output;
31 	uint16_t rows, columns;
32 	arm_status status;
33 
34 	arm_matrix_instance_q7 mat_in1;
35 	arm_matrix_instance_q7 mat_out;
36 
37 	/* Allocate buffers */
38 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
39 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40 
41 	output = malloc(length * sizeof(q7_t));
42 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43 
44 	/* Initialise contexts */
45 	mat_in1.pData = tmp1;
46 	mat_out.pData = output;
47 
48 	/* Iterate matrices */
49 	for (index = 0; index < NUM_MATRICES; index++) {
50 		rows = *dims++;
51 		columns = *dims++;
52 
53 		/* Initialise matrix dimensions */
54 		mat_in1.numRows = rows;
55 		mat_in1.numCols = columns;
56 		mat_out.numRows = transpose ? columns : rows;
57 		mat_out.numCols = transpose ? rows : columns;
58 
59 		/* Load matrix data */
60 		memcpy(mat_in1.pData, in_com1, rows * columns * sizeof(q7_t));
61 
62 		/* Run test function */
63 		switch (op) {
64 		case OP1_TRANS:
65 			status = arm_mat_trans_q7(&mat_in1, &mat_out);
66 			break;
67 		default:
68 			zassert_unreachable("invalid operation");
69 		}
70 
71 		/* Validate status */
72 		zassert_equal(status, ARM_MATH_SUCCESS,
73 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
74 
75 		/* Increment output pointer */
76 		mat_out.pData += (rows * columns);
77 	}
78 
79 	/* Validate output */
80 	zassert_true(
81 		test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
82 		ASSERT_MSG_SNR_LIMIT_EXCEED);
83 
84 	zassert_true(
85 		test_near_equal_q7(length, output, ref, ABS_ERROR_THRESH_Q7),
86 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
87 
88 	/* Free buffers */
89 	free(tmp1);
90 	free(output);
91 }
92 
/* Defines the test_op1_arm_mat_trans_q7 case registered in
 * test_matrix_unary_q7. It presumably expands to a wrapper that calls
 * test_op1(OP1_TRANS, ref_trans, ARRAY_SIZE(ref_trans), true), i.e. with a
 * transposed output shape (NOTE(review): macro defined outside this file;
 * confirm its expansion).
 */
DEFINE_TEST_VARIANT4(op1, arm_mat_trans_q7, OP1_TRANS,
	ref_trans, ARRAY_SIZE(ref_trans), true);
95 
test_op2v(int op,const q7_t * ref,size_t length)96 static void test_op2v(int op, const q7_t *ref, size_t length)
97 {
98 	size_t index;
99 	const uint16_t *dims = in_dims;
100 	q7_t *tmp1, *vec, *output_buf, *output;
101 	uint16_t rows, internal;
102 
103 	arm_matrix_instance_q7 mat_in1;
104 
105 	/* Allocate buffers */
106 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
107 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
108 
109 	vec = malloc(2 * MAX_MATRIX_DIM * sizeof(q7_t));
110 	zassert_not_null(vec, ASSERT_MSG_BUFFER_ALLOC_FAILED);
111 
112 	output_buf = malloc(length * sizeof(q7_t));
113 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
114 
115 	/* Initialise contexts */
116 	mat_in1.pData = tmp1;
117 	output = output_buf;
118 
119 	/* Iterate matrices */
120 	for (index = 0; index < NUM_MATRICES; index++) {
121 		rows = *dims++;
122 		internal = *dims++;
123 
124 		/* Initialise matrix dimensions */
125 		mat_in1.numRows = rows;
126 		mat_in1.numCols = internal;
127 
128 		/* Load matrix data */
129 		memcpy(mat_in1.pData, in_com1,
130 		       2 * rows * internal * sizeof(q7_t));
131 		memcpy(vec, in_vec1, 2 * internal * sizeof(q7_t));
132 
133 		/* Run test function */
134 		switch (op) {
135 		case OP2V_VEC_MULT:
136 			arm_mat_vec_mult_q7(&mat_in1, vec, output);
137 			break;
138 		default:
139 			zassert_unreachable("invalid operation");
140 		}
141 
142 		/* Increment output pointer */
143 		output += rows;
144 	}
145 
146 	/* Validate output */
147 	zassert_true(
148 		test_snr_error_q7(length, output_buf, ref, SNR_ERROR_THRESH_LOW),
149 		ASSERT_MSG_SNR_LIMIT_EXCEED);
150 
151 	zassert_true(
152 		test_near_equal_q7(length, output_buf, ref, ABS_ERROR_THRESH_Q7),
153 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
154 
155 	/* Free buffers */
156 	free(tmp1);
157 	free(vec);
158 	free(output_buf);
159 }
160 
/* Defines the test_op2v_arm_mat_vec_mult_q7 case registered in
 * test_matrix_unary_q7. It presumably expands to a wrapper that calls
 * test_op2v(OP2V_VEC_MULT, ref_vec_mult, ARRAY_SIZE(ref_vec_mult))
 * (NOTE(review): macro defined outside this file; confirm its expansion).
 */
DEFINE_TEST_VARIANT3(op2v, arm_mat_vec_mult_q7, OP2V_VEC_MULT,
	ref_vec_mult, ARRAY_SIZE(ref_vec_mult));
163 
test_matrix_unary_q7(void)164 void test_matrix_unary_q7(void)
165 {
166 	ztest_test_suite(matrix_unary_q7,
167 		ztest_unit_test(test_op1_arm_mat_trans_q7),
168 		ztest_unit_test(test_op2v_arm_mat_vec_mult_q7)
169 		);
170 
171 	ztest_run_test_suite(matrix_unary_q7);
172 }
173