1 // SPDX-License-Identifier: BSD-3-Clause
2 //
3 // Copyright(c) 2022 Intel Corporation. All rights reserved.
4 //
5 // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6 
7 #include <sof/math/matrix.h>
8 #include <errno.h>
9 #include <stdint.h>
10 
mat_multiply(struct mat_matrix_16b * a,struct mat_matrix_16b * b,struct mat_matrix_16b * c)11 int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
12 {
13 	int64_t s;
14 	int16_t *x;
15 	int16_t *y;
16 	int16_t *z = c->data;
17 	int i, j, k;
18 	int y_inc = b->columns;
19 	const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
20 
21 	if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
22 		return -EINVAL;
23 
24 	/* If all data is Q0 */
25 	if (shift_minus_one == -1) {
26 		for (i = 0; i < a->rows; i++) {
27 			for (j = 0; j < b->columns; j++) {
28 				s = 0;
29 				x = a->data + a->columns * i;
30 				y = b->data + j;
31 				for (k = 0; k < b->rows; k++) {
32 					s += (int32_t)(*x) * (*y);
33 					x++;
34 					y += y_inc;
35 				}
36 				*z = (int16_t)s; /* For Q16.0 */
37 				z++;
38 			}
39 		}
40 
41 		return 0;
42 	}
43 
44 	for (i = 0; i < a->rows; i++) {
45 		for (j = 0; j < b->columns; j++) {
46 			s = 0;
47 			x = a->data + a->columns * i;
48 			y = b->data + j;
49 			for (k = 0; k < b->rows; k++) {
50 				s += (int32_t)(*x) * (*y);
51 				x++;
52 				y += y_inc;
53 			}
54 			*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
55 			z++;
56 		}
57 	}
58 	return 0;
59 }
60 
mat_multiply_elementwise(struct mat_matrix_16b * a,struct mat_matrix_16b * b,struct mat_matrix_16b * c)61 int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
62 			     struct mat_matrix_16b *c)
63 {	int64_t p;
64 	int16_t *x = a->data;
65 	int16_t *y = b->data;
66 	int16_t *z = c->data;
67 	int i;
68 	const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
69 
70 	if (a->columns != b->columns || b->columns != c->columns ||
71 	    a->rows != b->rows || b->rows != c->rows) {
72 		return -EINVAL;
73 	}
74 
75 	/* If all data is Q0 */
76 	if (shift_minus_one == -1) {
77 		for (i = 0; i < a->rows * a->columns; i++) {
78 			*z = *x * *y;
79 			x++;
80 			y++;
81 			z++;
82 		}
83 
84 		return 0;
85 	}
86 
87 	for (i = 0; i < a->rows * a->columns; i++) {
88 		p = (int32_t)(*x) * *y;
89 		*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90 		x++;
91 		y++;
92 		z++;
93 	}
94 
95 	return 0;
96 }
97