1 // SPDX-License-Identifier: BSD-3-Clause
2 //
3 // Copyright(c) 2022 Intel Corporation. All rights reserved.
4 //
5 // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6 
7 #include <errno.h>
8 #include <stdlib.h>
9 #include <stdio.h>
10 #include <stdint.h>
11 #include <stdarg.h>
12 #include <stddef.h>
13 #include <setjmp.h>
14 #include <string.h>
15 #include <cmocka.h>
16 #include <math.h>
17 #include <sof/math/matrix.h>
18 #include "ref_matrix_mult_16_test1.h"
19 #include "ref_matrix_mult_16_test2.h"
20 #include "ref_matrix_mult_16_test3.h"
21 #include "ref_matrix_mult_16_test4.h"
22 
23 #define MATRIX_MULT_16_MAX_ERROR_ABS  1.5
24 #define MATRIX_MULT_16_MAX_ERROR_RMS  0.5
25 
matrix_mult_16_test(const int16_t * a_ref,const int16_t * b_ref,const int16_t * c_ref,int elementwise,int a_rows,int a_columns,int b_rows,int b_columns,int c_rows,int c_columns,int a_frac,int b_frac,int c_frac)26 static void matrix_mult_16_test(const int16_t *a_ref, const int16_t *b_ref, const int16_t *c_ref,
27 				int elementwise, int a_rows, int a_columns,
28 				int b_rows, int b_columns, int c_rows, int c_columns,
29 				int a_frac, int b_frac, int c_frac)
30 {
31 	struct mat_matrix_16b *a_matrix;
32 	struct mat_matrix_16b *b_matrix;
33 	struct mat_matrix_16b *c_matrix;
34 	float delta;
35 	float sum_squares = 0;
36 	float error_rms;
37 	float delta_max = 0;
38 	int16_t x;
39 	int i, j, k;
40 
41 	a_matrix = mat_matrix_alloc_16b(a_rows, a_columns, a_frac);
42 	if (!a_matrix)
43 		exit(EXIT_FAILURE);
44 
45 	b_matrix = mat_matrix_alloc_16b(b_rows, b_columns, b_frac);
46 	if (!b_matrix) {
47 		free(a_matrix);
48 		exit(EXIT_FAILURE);
49 	}
50 
51 	c_matrix = mat_matrix_alloc_16b(c_rows, c_columns, c_frac);
52 	if (!c_matrix)  {
53 		free(a_matrix);
54 		free(b_matrix);
55 		exit(EXIT_FAILURE);
56 	}
57 
58 	/* Initialize matrices a and b from test vectors and do matrix multiply */
59 	mat_copy_from_linear_16b(a_matrix, a_ref);
60 	mat_copy_from_linear_16b(b_matrix, b_ref);
61 	if (elementwise)
62 		mat_multiply_elementwise(a_matrix, b_matrix, c_matrix);
63 	else
64 		mat_multiply(a_matrix, b_matrix, c_matrix);
65 
66 	/* Check */
67 	k = 0;
68 	for (i = 0; i < c_matrix->rows; i++) {
69 		for (j = 0; j < c_matrix->columns; j++) {
70 			x = mat_get_scalar_16b(c_matrix, i, j);
71 			delta = (float)x - (float)c_ref[k++];
72 			sum_squares += delta * delta;
73 			if (delta > delta_max)
74 				delta_max = delta;
75 			else if (-delta > delta_max)
76 				delta_max = -delta;
77 		}
78 	}
79 
80 	error_rms = sqrt(sum_squares / (float)(c_matrix->rows * c_matrix->columns));
81 	printf("Max absolute error = %5.2f (max %5.2f), error RMS = %5.2f (max %5.2f)\n",
82 	       delta_max, MATRIX_MULT_16_MAX_ERROR_ABS, error_rms, MATRIX_MULT_16_MAX_ERROR_RMS);
83 
84 	assert_true(error_rms < MATRIX_MULT_16_MAX_ERROR_RMS);
85 	assert_true(delta_max < MATRIX_MULT_16_MAX_ERROR_ABS);
86 }
87 
test_matrix_mult_16_test1(void ** state)88 static void test_matrix_mult_16_test1(void **state)
89 {
90 	(void)state;
91 
92 	matrix_mult_16_test(matrix_mult_16_test1_a,
93 			    matrix_mult_16_test1_b,
94 			    matrix_mult_16_test1_c,
95 			    MATRIX_MULT_16_TEST1_ELEMENTWISE,
96 			    MATRIX_MULT_16_TEST1_A_ROWS,
97 			    MATRIX_MULT_16_TEST1_A_COLUMNS,
98 			    MATRIX_MULT_16_TEST1_B_ROWS,
99 			    MATRIX_MULT_16_TEST1_B_COLUMNS,
100 			    MATRIX_MULT_16_TEST1_C_ROWS,
101 			    MATRIX_MULT_16_TEST1_C_COLUMNS,
102 			    MATRIX_MULT_16_TEST1_A_QXY_Y,
103 			    MATRIX_MULT_16_TEST1_B_QXY_Y,
104 			    MATRIX_MULT_16_TEST1_C_QXY_Y);
105 }
106 
test_matrix_mult_16_test2(void ** state)107 static void test_matrix_mult_16_test2(void **state)
108 {
109 	(void)state;
110 
111 	matrix_mult_16_test(matrix_mult_16_test2_a,
112 			    matrix_mult_16_test2_b,
113 			    matrix_mult_16_test2_c,
114 			    MATRIX_MULT_16_TEST2_ELEMENTWISE,
115 			    MATRIX_MULT_16_TEST2_A_ROWS,
116 			    MATRIX_MULT_16_TEST2_A_COLUMNS,
117 			    MATRIX_MULT_16_TEST2_B_ROWS,
118 			    MATRIX_MULT_16_TEST2_B_COLUMNS,
119 			    MATRIX_MULT_16_TEST2_C_ROWS,
120 			    MATRIX_MULT_16_TEST2_C_COLUMNS,
121 			    MATRIX_MULT_16_TEST2_A_QXY_Y,
122 			    MATRIX_MULT_16_TEST2_B_QXY_Y,
123 			    MATRIX_MULT_16_TEST2_C_QXY_Y);
124 }
125 
test_matrix_mult_16_test3(void ** state)126 static void test_matrix_mult_16_test3(void **state)
127 {
128 	(void)state;
129 
130 	matrix_mult_16_test(matrix_mult_16_test3_a,
131 			    matrix_mult_16_test3_b,
132 			    matrix_mult_16_test3_c,
133 			    MATRIX_MULT_16_TEST3_ELEMENTWISE,
134 			    MATRIX_MULT_16_TEST3_A_ROWS,
135 			    MATRIX_MULT_16_TEST3_A_COLUMNS,
136 			    MATRIX_MULT_16_TEST3_B_ROWS,
137 			    MATRIX_MULT_16_TEST3_B_COLUMNS,
138 			    MATRIX_MULT_16_TEST3_C_ROWS,
139 			    MATRIX_MULT_16_TEST3_C_COLUMNS,
140 			    MATRIX_MULT_16_TEST3_A_QXY_Y,
141 			    MATRIX_MULT_16_TEST3_B_QXY_Y,
142 			    MATRIX_MULT_16_TEST3_C_QXY_Y);
143 }
144 
test_matrix_mult_16_test4(void ** state)145 static void test_matrix_mult_16_test4(void **state)
146 {
147 	(void)state;
148 
149 	matrix_mult_16_test(matrix_mult_16_test4_a,
150 			    matrix_mult_16_test4_b,
151 			    matrix_mult_16_test4_c,
152 			    MATRIX_MULT_16_TEST4_ELEMENTWISE,
153 			    MATRIX_MULT_16_TEST4_A_ROWS,
154 			    MATRIX_MULT_16_TEST4_A_COLUMNS,
155 			    MATRIX_MULT_16_TEST4_B_ROWS,
156 			    MATRIX_MULT_16_TEST4_B_COLUMNS,
157 			    MATRIX_MULT_16_TEST4_C_ROWS,
158 			    MATRIX_MULT_16_TEST4_C_COLUMNS,
159 			    MATRIX_MULT_16_TEST4_A_QXY_Y,
160 			    MATRIX_MULT_16_TEST4_B_QXY_Y,
161 			    MATRIX_MULT_16_TEST4_C_QXY_Y);
162 }
163 
main(void)164 int main(void)
165 {
166 	const struct CMUnitTest tests[] = {
167 		cmocka_unit_test(test_matrix_mult_16_test1),
168 		cmocka_unit_test(test_matrix_mult_16_test2),
169 		cmocka_unit_test(test_matrix_mult_16_test3),
170 		cmocka_unit_test(test_matrix_mult_16_test4),
171 	};
172 
173 	cmocka_set_message_output(CM_OUTPUT_TAP);
174 
175 	return cmocka_run_group_tests(tests, NULL, NULL);
176 }
177