1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <ztest.h>
9 #include <zephyr.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13
14 #include "unary_q7.pat"
15
/*
 * Minimum SNR accepted by test_snr_error_q7() for the transpose test.
 * NOTE(review): units (presumably dB) are defined by test_common.h — confirm.
 */
#define SNR_ERROR_THRESH ((float32_t)20)
/* Relaxed SNR threshold, used only by the matrix-vector multiply test */
#define SNR_ERROR_THRESH_LOW ((float32_t)11)
/* Maximum per-element absolute difference accepted by test_near_equal_q7() */
#define ABS_ERROR_THRESH_Q7 ((q7_t)2)

/* in_dims stores one (rows, columns) pair per test matrix */
#define NUM_MATRICES (ARRAY_SIZE(in_dims) / 2)
/* Largest row/column count the scratch buffers are sized for */
#define MAX_MATRIX_DIM (47)

/* Operation selector for test_op1() */
#define OP1_TRANS (1)
/* Operation selector for test_op2v() */
#define OP2V_VEC_MULT (0)
25
test_op1(int op,const q7_t * ref,size_t length,bool transpose)26 static void test_op1(int op, const q7_t *ref, size_t length, bool transpose)
27 {
28 size_t index;
29 uint16_t *dims = (uint16_t *)in_dims;
30 q7_t *tmp1, *output;
31 uint16_t rows, columns;
32 arm_status status;
33
34 arm_matrix_instance_q7 mat_in1;
35 arm_matrix_instance_q7 mat_out;
36
37 /* Allocate buffers */
38 tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
39 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40
41 output = malloc(length * sizeof(q7_t));
42 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43
44 /* Initialise contexts */
45 mat_in1.pData = tmp1;
46 mat_out.pData = output;
47
48 /* Iterate matrices */
49 for (index = 0; index < NUM_MATRICES; index++) {
50 rows = *dims++;
51 columns = *dims++;
52
53 /* Initialise matrix dimensions */
54 mat_in1.numRows = rows;
55 mat_in1.numCols = columns;
56 mat_out.numRows = transpose ? columns : rows;
57 mat_out.numCols = transpose ? rows : columns;
58
59 /* Load matrix data */
60 memcpy(mat_in1.pData, in_com1, rows * columns * sizeof(q7_t));
61
62 /* Run test function */
63 switch (op) {
64 case OP1_TRANS:
65 status = arm_mat_trans_q7(&mat_in1, &mat_out);
66 break;
67 default:
68 zassert_unreachable("invalid operation");
69 }
70
71 /* Validate status */
72 zassert_equal(status, ARM_MATH_SUCCESS,
73 ASSERT_MSG_INCORRECT_COMP_RESULT);
74
75 /* Increment output pointer */
76 mat_out.pData += (rows * columns);
77 }
78
79 /* Validate output */
80 zassert_true(
81 test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
82 ASSERT_MSG_SNR_LIMIT_EXCEED);
83
84 zassert_true(
85 test_near_equal_q7(length, output, ref, ABS_ERROR_THRESH_Q7),
86 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
87
88 /* Free buffers */
89 free(tmp1);
90 free(output);
91 }
92
/*
 * Bind test_op1() to arm_mat_trans_q7 with the transposed output shape
 * (transpose = true), checked against ref_trans. This presumably expands to
 * test_op1_arm_mat_trans_q7(), the name listed in the matrix_unary_q7 suite;
 * confirm against the macro definition in test_common.h.
 */
DEFINE_TEST_VARIANT4(op1, arm_mat_trans_q7, OP1_TRANS,
		     ref_trans, ARRAY_SIZE(ref_trans), true);
95
test_op2v(int op,const q7_t * ref,size_t length)96 static void test_op2v(int op, const q7_t *ref, size_t length)
97 {
98 size_t index;
99 const uint16_t *dims = in_dims;
100 q7_t *tmp1, *vec, *output_buf, *output;
101 uint16_t rows, internal;
102
103 arm_matrix_instance_q7 mat_in1;
104
105 /* Allocate buffers */
106 tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
107 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
108
109 vec = malloc(2 * MAX_MATRIX_DIM * sizeof(q7_t));
110 zassert_not_null(vec, ASSERT_MSG_BUFFER_ALLOC_FAILED);
111
112 output_buf = malloc(length * sizeof(q7_t));
113 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
114
115 /* Initialise contexts */
116 mat_in1.pData = tmp1;
117 output = output_buf;
118
119 /* Iterate matrices */
120 for (index = 0; index < NUM_MATRICES; index++) {
121 rows = *dims++;
122 internal = *dims++;
123
124 /* Initialise matrix dimensions */
125 mat_in1.numRows = rows;
126 mat_in1.numCols = internal;
127
128 /* Load matrix data */
129 memcpy(mat_in1.pData, in_com1,
130 2 * rows * internal * sizeof(q7_t));
131 memcpy(vec, in_vec1, 2 * internal * sizeof(q7_t));
132
133 /* Run test function */
134 switch (op) {
135 case OP2V_VEC_MULT:
136 arm_mat_vec_mult_q7(&mat_in1, vec, output);
137 break;
138 default:
139 zassert_unreachable("invalid operation");
140 }
141
142 /* Increment output pointer */
143 output += rows;
144 }
145
146 /* Validate output */
147 zassert_true(
148 test_snr_error_q7(length, output_buf, ref, SNR_ERROR_THRESH_LOW),
149 ASSERT_MSG_SNR_LIMIT_EXCEED);
150
151 zassert_true(
152 test_near_equal_q7(length, output_buf, ref, ABS_ERROR_THRESH_Q7),
153 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
154
155 /* Free buffers */
156 free(tmp1);
157 free(vec);
158 free(output_buf);
159 }
160
/*
 * Bind test_op2v() to arm_mat_vec_mult_q7, checked against ref_vec_mult.
 * This presumably expands to test_op2v_arm_mat_vec_mult_q7(), the name listed
 * in the matrix_unary_q7 suite; confirm against the macro definition in
 * test_common.h.
 */
DEFINE_TEST_VARIANT3(op2v, arm_mat_vec_mult_q7, OP2V_VEC_MULT,
		     ref_vec_mult, ARRAY_SIZE(ref_vec_mult));
163
/**
 * Define and run the matrix_unary_q7 ztest suite.
 *
 * The suite covers the q7 transpose test and the q7 matrix-vector multiply
 * test, whose functions are generated by the DEFINE_TEST_VARIANT* invocations
 * in this file.
 */
void test_matrix_unary_q7(void)
{
	ztest_test_suite(matrix_unary_q7,
		ztest_unit_test(test_op1_arm_mat_trans_q7),
		ztest_unit_test(test_op2v_arm_mat_vec_mult_q7)
		);

	ztest_run_test_suite(matrix_unary_q7);
}
173