1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13
14 #include "unary_q7.pat"
15
/* Minimum SNR required of test_op1 results (units as defined by
 * test_snr_error_q7 - presumably dB, confirm in test_common.h).
 */
#define SNR_ERROR_THRESH ((float32_t)20)
/* Relaxed SNR threshold used by test_op2v (matrix-vector multiply). */
#define SNR_ERROR_THRESH_LOW ((float32_t)11)
/* Maximum per-element absolute difference allowed against the reference. */
#define ABS_ERROR_THRESH_Q7 ((q7_t)2)

/* in_dims stores one (rows, columns) pair per test matrix. */
#define NUM_MATRICES (ARRAY_SIZE(in_dims) / 2)
/* Largest row/column count the scratch input buffers are sized for. */
#define MAX_MATRIX_DIM (47)

/* Operation selectors dispatched by test_op1 and test_op2v. */
#define OP1_TRANS (1)
#define OP2V_VEC_MULT (0)
25
test_op1(int op,const q7_t * ref,size_t length,bool transpose)26 static void test_op1(int op, const q7_t *ref, size_t length, bool transpose)
27 {
28 size_t index;
29 uint16_t *dims = (uint16_t *)in_dims;
30 q7_t *tmp1, *output;
31 uint16_t rows, columns;
32 arm_status status;
33
34 arm_matrix_instance_q7 mat_in1;
35 arm_matrix_instance_q7 mat_out;
36
37 /* Allocate buffers */
38 tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
39 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40
41 output = malloc(length * sizeof(q7_t));
42 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43
44 /* Initialise contexts */
45 mat_in1.pData = tmp1;
46 mat_out.pData = output;
47
48 /* Iterate matrices */
49 for (index = 0; index < NUM_MATRICES; index++) {
50 rows = *dims++;
51 columns = *dims++;
52
53 /* Initialise matrix dimensions */
54 mat_in1.numRows = rows;
55 mat_in1.numCols = columns;
56 mat_out.numRows = transpose ? columns : rows;
57 mat_out.numCols = transpose ? rows : columns;
58
59 /* Load matrix data */
60 memcpy(mat_in1.pData, in_com1, rows * columns * sizeof(q7_t));
61
62 /* Run test function */
63 switch (op) {
64 case OP1_TRANS:
65 status = arm_mat_trans_q7(&mat_in1, &mat_out);
66 break;
67 default:
68 zassert_unreachable("invalid operation");
69 }
70
71 /* Validate status */
72 zassert_equal(status, ARM_MATH_SUCCESS,
73 ASSERT_MSG_INCORRECT_COMP_RESULT);
74
75 /* Increment output pointer */
76 mat_out.pData += (rows * columns);
77 }
78
79 /* Validate output */
80 zassert_true(
81 test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
82 ASSERT_MSG_SNR_LIMIT_EXCEED);
83
84 zassert_true(
85 test_near_equal_q7(length, output, ref, ABS_ERROR_THRESH_Q7),
86 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
87
88 /* Free buffers */
89 free(tmp1);
90 free(output);
91 }
92
93 DEFINE_TEST_VARIANT4(matrix_unary_q7,
94 op1, arm_mat_trans_q7, OP1_TRANS,
95 ref_trans, ARRAY_SIZE(ref_trans), true);
96
test_op2v(int op,const q7_t * ref,size_t length)97 static void test_op2v(int op, const q7_t *ref, size_t length)
98 {
99 size_t index;
100 const uint16_t *dims = in_dims;
101 q7_t *tmp1, *vec, *output_buf, *output;
102 uint16_t rows, internal;
103
104 arm_matrix_instance_q7 mat_in1;
105
106 /* Allocate buffers */
107 tmp1 = malloc(2 * MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
108 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
109
110 vec = malloc(2 * MAX_MATRIX_DIM * sizeof(q7_t));
111 zassert_not_null(vec, ASSERT_MSG_BUFFER_ALLOC_FAILED);
112
113 output_buf = malloc(length * sizeof(q7_t));
114 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
115
116 /* Initialise contexts */
117 mat_in1.pData = tmp1;
118 output = output_buf;
119
120 /* Iterate matrices */
121 for (index = 0; index < NUM_MATRICES; index++) {
122 rows = *dims++;
123 internal = *dims++;
124
125 /* Initialise matrix dimensions */
126 mat_in1.numRows = rows;
127 mat_in1.numCols = internal;
128
129 /* Load matrix data */
130 memcpy(mat_in1.pData, in_com1,
131 2 * rows * internal * sizeof(q7_t));
132 memcpy(vec, in_vec1, 2 * internal * sizeof(q7_t));
133
134 /* Run test function */
135 switch (op) {
136 case OP2V_VEC_MULT:
137 arm_mat_vec_mult_q7(&mat_in1, vec, output);
138 break;
139 default:
140 zassert_unreachable("invalid operation");
141 }
142
143 /* Increment output pointer */
144 output += rows;
145 }
146
147 /* Validate output */
148 zassert_true(
149 test_snr_error_q7(length, output_buf, ref, SNR_ERROR_THRESH_LOW),
150 ASSERT_MSG_SNR_LIMIT_EXCEED);
151
152 zassert_true(
153 test_near_equal_q7(length, output_buf, ref, ABS_ERROR_THRESH_Q7),
154 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
155
156 /* Free buffers */
157 free(tmp1);
158 free(vec);
159 free(output_buf);
160 }
161
162 DEFINE_TEST_VARIANT3(matrix_unary_q7,
163 op2v, arm_mat_vec_mult_q7, OP2V_VEC_MULT,
164 ref_vec_mult, ARRAY_SIZE(ref_vec_mult));
165
/* Register the matrix_unary_q7 test suite; every optional hook argument is
 * left NULL.
 */
ZTEST_SUITE(matrix_unary_q7, NULL, NULL, NULL, NULL, NULL);
167