1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "unary_q7.pat"
15 
/* SNR threshold passed to test_snr_error_q7() by test_op1() */
#define SNR_ERROR_THRESH	((float32_t)20)
/* Lower SNR threshold passed to test_snr_error_q7() by test_op2v() */
#define SNR_ERROR_THRESH_LOW	((float32_t)11)
/* Absolute error threshold passed to test_near_equal_q7() by both tests */
#define ABS_ERROR_THRESH_Q7	((q7_t)2)

/* in_dims is consumed two entries at a time, one pair per test matrix */
#define NUM_MATRICES		(ARRAY_SIZE(in_dims) / 2)
/*
 * Used to size the scratch buffers in the tests below; the tests assume no
 * dimension read from in_dims exceeds this value.
 */
#define MAX_MATRIX_DIM		(47)

/* Operation selectors understood by test_op1() and test_op2v() respectively */
#define OP1_TRANS		(1)
#define OP2V_VEC_MULT		(0)
25 
test_op1(int op,const q7_t * ref,size_t length,bool transpose)26 static void test_op1(int op, const q7_t *ref, size_t length, bool transpose)
27 {
28 	size_t index;
29 	uint16_t *dims = (uint16_t *)in_dims;
30 	q7_t *tmp1, *output;
31 	uint16_t rows, columns;
32 	arm_status status;
33 
34 	arm_matrix_instance_q7 mat_in1;
35 	arm_matrix_instance_q7 mat_out;
36 
37 	/* Allocate buffers */
38 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
39 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40 
41 	output = malloc(length * sizeof(q7_t));
42 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43 
44 	/* Initialise contexts */
45 	mat_in1.pData = tmp1;
46 	mat_out.pData = output;
47 
48 	/* Iterate matrices */
49 	for (index = 0; index < NUM_MATRICES; index++) {
50 		rows = *dims++;
51 		columns = *dims++;
52 
53 		/* Initialise matrix dimensions */
54 		mat_in1.numRows = rows;
55 		mat_in1.numCols = columns;
56 		mat_out.numRows = transpose ? columns : rows;
57 		mat_out.numCols = transpose ? rows : columns;
58 
59 		/* Load matrix data */
60 		memcpy(mat_in1.pData, in_com1, rows * columns * sizeof(q7_t));
61 
62 		/* Run test function */
63 		switch (op) {
64 		case OP1_TRANS:
65 			status = arm_mat_trans_q7(&mat_in1, &mat_out);
66 			break;
67 		default:
68 			zassert_unreachable("invalid operation");
69 		}
70 
71 		/* Validate status */
72 		zassert_equal(status, ARM_MATH_SUCCESS,
73 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
74 
75 		/* Increment output pointer */
76 		mat_out.pData += (rows * columns);
77 	}
78 
79 	/* Validate output */
80 	zassert_true(
81 		test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
82 		ASSERT_MSG_SNR_LIMIT_EXCEED);
83 
84 	zassert_true(
85 		test_near_equal_q7(length, output, ref, ABS_ERROR_THRESH_Q7),
86 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
87 
88 	/* Free buffers */
89 	free(tmp1);
90 	free(output);
91 }
92 
/*
 * Transpose test case for the matrix_unary_q7 suite. The trailing arguments
 * line up with test_op1()'s (op, ref, length, transpose) parameters.
 * NOTE(review): DEFINE_TEST_VARIANT4 is not defined in this file (presumably
 * in test_common.h) - confirm how it expands.
 */
DEFINE_TEST_VARIANT4(matrix_unary_q7,
	op1, arm_mat_trans_q7, OP1_TRANS,
	ref_trans, ARRAY_SIZE(ref_trans), true);
96 
test_op2v(int op,const q7_t * ref,size_t length)97 static void test_op2v(int op, const q7_t *ref, size_t length)
98 {
99 	size_t index;
100 	const uint16_t *dims = in_dims;
101 	q7_t *tmp1, *vec, *output_buf, *output;
102 	uint16_t rows, internal;
103 
104 	arm_matrix_instance_q7 mat_in1;
105 
106 	/* Allocate buffers */
107 	tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
108 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
109 
110 	vec = malloc(2 * MAX_MATRIX_DIM * sizeof(q7_t));
111 	zassert_not_null(vec, ASSERT_MSG_BUFFER_ALLOC_FAILED);
112 
113 	output_buf = malloc(length * sizeof(q7_t));
114 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
115 
116 	/* Initialise contexts */
117 	mat_in1.pData = tmp1;
118 	output = output_buf;
119 
120 	/* Iterate matrices */
121 	for (index = 0; index < NUM_MATRICES; index++) {
122 		rows = *dims++;
123 		internal = *dims++;
124 
125 		/* Initialise matrix dimensions */
126 		mat_in1.numRows = rows;
127 		mat_in1.numCols = internal;
128 
129 		/* Load matrix data */
130 		memcpy(mat_in1.pData, in_com1,
131 		       2 * rows * internal * sizeof(q7_t));
132 		memcpy(vec, in_vec1, 2 * internal * sizeof(q7_t));
133 
134 		/* Run test function */
135 		switch (op) {
136 		case OP2V_VEC_MULT:
137 			arm_mat_vec_mult_q7(&mat_in1, vec, output);
138 			break;
139 		default:
140 			zassert_unreachable("invalid operation");
141 		}
142 
143 		/* Increment output pointer */
144 		output += rows;
145 	}
146 
147 	/* Validate output */
148 	zassert_true(
149 		test_snr_error_q7(length, output_buf, ref, SNR_ERROR_THRESH_LOW),
150 		ASSERT_MSG_SNR_LIMIT_EXCEED);
151 
152 	zassert_true(
153 		test_near_equal_q7(length, output_buf, ref, ABS_ERROR_THRESH_Q7),
154 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
155 
156 	/* Free buffers */
157 	free(tmp1);
158 	free(vec);
159 	free(output_buf);
160 }
161 
/*
 * Matrix-by-vector multiplication test case for the matrix_unary_q7 suite.
 * The trailing arguments line up with test_op2v()'s (op, ref, length)
 * parameters.
 * NOTE(review): DEFINE_TEST_VARIANT3 is not defined in this file (presumably
 * in test_common.h) - confirm how it expands.
 */
DEFINE_TEST_VARIANT3(matrix_unary_q7,
	op2v, arm_mat_vec_mult_q7, OP2V_VEC_MULT,
	ref_vec_mult, ARRAY_SIZE(ref_vec_mult));
165 
/*
 * Declare the matrix_unary_q7 test suite; every optional argument is NULL,
 * so no hooks are supplied.
 */
ZTEST_SUITE(matrix_unary_q7, NULL, NULL, NULL, NULL, NULL);
167