1 // SPDX-License-Identifier: BSD-3-Clause
2 //
3 // Copyright(c) 2022 Intel Corporation. All rights reserved.
4 //
5 // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6
7 #include <errno.h>
8 #include <stdlib.h>
9 #include <stdio.h>
10 #include <stdint.h>
11 #include <stdarg.h>
12 #include <stddef.h>
13 #include <setjmp.h>
14 #include <string.h>
15 #include <cmocka.h>
16 #include <math.h>
17 #include <sof/math/matrix.h>
18 #include "ref_matrix_mult_16_test1.h"
19 #include "ref_matrix_mult_16_test2.h"
20 #include "ref_matrix_mult_16_test3.h"
21 #include "ref_matrix_mult_16_test4.h"
22
/*
 * Pass limits for the comparison against the reference output. Errors are
 * measured as differences of raw int16_t result values (result LSBs).
 */
#define MATRIX_MULT_16_MAX_ERROR_ABS	(1.5)	/* max |result - reference| */
#define MATRIX_MULT_16_MAX_ERROR_RMS	(0.5)	/* max RMS of result - reference */
25
matrix_mult_16_test(const int16_t * a_ref,const int16_t * b_ref,const int16_t * c_ref,int elementwise,int a_rows,int a_columns,int b_rows,int b_columns,int c_rows,int c_columns,int a_frac,int b_frac,int c_frac)26 static void matrix_mult_16_test(const int16_t *a_ref, const int16_t *b_ref, const int16_t *c_ref,
27 int elementwise, int a_rows, int a_columns,
28 int b_rows, int b_columns, int c_rows, int c_columns,
29 int a_frac, int b_frac, int c_frac)
30 {
31 struct mat_matrix_16b *a_matrix;
32 struct mat_matrix_16b *b_matrix;
33 struct mat_matrix_16b *c_matrix;
34 float delta;
35 float sum_squares = 0;
36 float error_rms;
37 float delta_max = 0;
38 int16_t x;
39 int i, j, k;
40
41 a_matrix = mat_matrix_alloc_16b(a_rows, a_columns, a_frac);
42 if (!a_matrix)
43 exit(EXIT_FAILURE);
44
45 b_matrix = mat_matrix_alloc_16b(b_rows, b_columns, b_frac);
46 if (!b_matrix) {
47 free(a_matrix);
48 exit(EXIT_FAILURE);
49 }
50
51 c_matrix = mat_matrix_alloc_16b(c_rows, c_columns, c_frac);
52 if (!c_matrix) {
53 free(a_matrix);
54 free(b_matrix);
55 exit(EXIT_FAILURE);
56 }
57
58 /* Initialize matrices a and b from test vectors and do matrix multiply */
59 mat_copy_from_linear_16b(a_matrix, a_ref);
60 mat_copy_from_linear_16b(b_matrix, b_ref);
61 if (elementwise)
62 mat_multiply_elementwise(a_matrix, b_matrix, c_matrix);
63 else
64 mat_multiply(a_matrix, b_matrix, c_matrix);
65
66 /* Check */
67 k = 0;
68 for (i = 0; i < c_matrix->rows; i++) {
69 for (j = 0; j < c_matrix->columns; j++) {
70 x = mat_get_scalar_16b(c_matrix, i, j);
71 delta = (float)x - (float)c_ref[k++];
72 sum_squares += delta * delta;
73 if (delta > delta_max)
74 delta_max = delta;
75 else if (-delta > delta_max)
76 delta_max = -delta;
77 }
78 }
79
80 error_rms = sqrt(sum_squares / (float)(c_matrix->rows * c_matrix->columns));
81 printf("Max absolute error = %5.2f (max %5.2f), error RMS = %5.2f (max %5.2f)\n",
82 delta_max, MATRIX_MULT_16_MAX_ERROR_ABS, error_rms, MATRIX_MULT_16_MAX_ERROR_RMS);
83
84 assert_true(error_rms < MATRIX_MULT_16_MAX_ERROR_RMS);
85 assert_true(delta_max < MATRIX_MULT_16_MAX_ERROR_ABS);
86 }
87
test_matrix_mult_16_test1(void ** state)88 static void test_matrix_mult_16_test1(void **state)
89 {
90 (void)state;
91
92 matrix_mult_16_test(matrix_mult_16_test1_a,
93 matrix_mult_16_test1_b,
94 matrix_mult_16_test1_c,
95 MATRIX_MULT_16_TEST1_ELEMENTWISE,
96 MATRIX_MULT_16_TEST1_A_ROWS,
97 MATRIX_MULT_16_TEST1_A_COLUMNS,
98 MATRIX_MULT_16_TEST1_B_ROWS,
99 MATRIX_MULT_16_TEST1_B_COLUMNS,
100 MATRIX_MULT_16_TEST1_C_ROWS,
101 MATRIX_MULT_16_TEST1_C_COLUMNS,
102 MATRIX_MULT_16_TEST1_A_QXY_Y,
103 MATRIX_MULT_16_TEST1_B_QXY_Y,
104 MATRIX_MULT_16_TEST1_C_QXY_Y);
105 }
106
test_matrix_mult_16_test2(void ** state)107 static void test_matrix_mult_16_test2(void **state)
108 {
109 (void)state;
110
111 matrix_mult_16_test(matrix_mult_16_test2_a,
112 matrix_mult_16_test2_b,
113 matrix_mult_16_test2_c,
114 MATRIX_MULT_16_TEST2_ELEMENTWISE,
115 MATRIX_MULT_16_TEST2_A_ROWS,
116 MATRIX_MULT_16_TEST2_A_COLUMNS,
117 MATRIX_MULT_16_TEST2_B_ROWS,
118 MATRIX_MULT_16_TEST2_B_COLUMNS,
119 MATRIX_MULT_16_TEST2_C_ROWS,
120 MATRIX_MULT_16_TEST2_C_COLUMNS,
121 MATRIX_MULT_16_TEST2_A_QXY_Y,
122 MATRIX_MULT_16_TEST2_B_QXY_Y,
123 MATRIX_MULT_16_TEST2_C_QXY_Y);
124 }
125
test_matrix_mult_16_test3(void ** state)126 static void test_matrix_mult_16_test3(void **state)
127 {
128 (void)state;
129
130 matrix_mult_16_test(matrix_mult_16_test3_a,
131 matrix_mult_16_test3_b,
132 matrix_mult_16_test3_c,
133 MATRIX_MULT_16_TEST3_ELEMENTWISE,
134 MATRIX_MULT_16_TEST3_A_ROWS,
135 MATRIX_MULT_16_TEST3_A_COLUMNS,
136 MATRIX_MULT_16_TEST3_B_ROWS,
137 MATRIX_MULT_16_TEST3_B_COLUMNS,
138 MATRIX_MULT_16_TEST3_C_ROWS,
139 MATRIX_MULT_16_TEST3_C_COLUMNS,
140 MATRIX_MULT_16_TEST3_A_QXY_Y,
141 MATRIX_MULT_16_TEST3_B_QXY_Y,
142 MATRIX_MULT_16_TEST3_C_QXY_Y);
143 }
144
test_matrix_mult_16_test4(void ** state)145 static void test_matrix_mult_16_test4(void **state)
146 {
147 (void)state;
148
149 matrix_mult_16_test(matrix_mult_16_test4_a,
150 matrix_mult_16_test4_b,
151 matrix_mult_16_test4_c,
152 MATRIX_MULT_16_TEST4_ELEMENTWISE,
153 MATRIX_MULT_16_TEST4_A_ROWS,
154 MATRIX_MULT_16_TEST4_A_COLUMNS,
155 MATRIX_MULT_16_TEST4_B_ROWS,
156 MATRIX_MULT_16_TEST4_B_COLUMNS,
157 MATRIX_MULT_16_TEST4_C_ROWS,
158 MATRIX_MULT_16_TEST4_C_COLUMNS,
159 MATRIX_MULT_16_TEST4_A_QXY_Y,
160 MATRIX_MULT_16_TEST4_B_QXY_Y,
161 MATRIX_MULT_16_TEST4_C_QXY_Y);
162 }
163
main(void)164 int main(void)
165 {
166 const struct CMUnitTest tests[] = {
167 cmocka_unit_test(test_matrix_mult_16_test1),
168 cmocka_unit_test(test_matrix_mult_16_test2),
169 cmocka_unit_test(test_matrix_mult_16_test3),
170 cmocka_unit_test(test_matrix_mult_16_test4),
171 };
172
173 cmocka_set_message_output(CM_OUTPUT_TAP);
174
175 return cmocka_run_group_tests(tests, NULL, NULL);
176 }
177