1 // SPDX-License-Identifier: BSD-3-Clause
2 //
3 // Copyright(c) 2022 Intel Corporation. All rights reserved.
4 //
5 // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
6
7 #include <sof/math/matrix.h>
8 #include <errno.h>
9 #include <stdint.h>
10
mat_multiply(struct mat_matrix_16b * a,struct mat_matrix_16b * b,struct mat_matrix_16b * c)11 int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
12 {
13 int64_t s;
14 int16_t *x;
15 int16_t *y;
16 int16_t *z = c->data;
17 int i, j, k;
18 int y_inc = b->columns;
19 const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
20
21 if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
22 return -EINVAL;
23
24 /* If all data is Q0 */
25 if (shift_minus_one == -1) {
26 for (i = 0; i < a->rows; i++) {
27 for (j = 0; j < b->columns; j++) {
28 s = 0;
29 x = a->data + a->columns * i;
30 y = b->data + j;
31 for (k = 0; k < b->rows; k++) {
32 s += (int32_t)(*x) * (*y);
33 x++;
34 y += y_inc;
35 }
36 *z = (int16_t)s; /* For Q16.0 */
37 z++;
38 }
39 }
40
41 return 0;
42 }
43
44 for (i = 0; i < a->rows; i++) {
45 for (j = 0; j < b->columns; j++) {
46 s = 0;
47 x = a->data + a->columns * i;
48 y = b->data + j;
49 for (k = 0; k < b->rows; k++) {
50 s += (int32_t)(*x) * (*y);
51 x++;
52 y += y_inc;
53 }
54 *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
55 z++;
56 }
57 }
58 return 0;
59 }
60
mat_multiply_elementwise(struct mat_matrix_16b * a,struct mat_matrix_16b * b,struct mat_matrix_16b * c)61 int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
62 struct mat_matrix_16b *c)
63 { int64_t p;
64 int16_t *x = a->data;
65 int16_t *y = b->data;
66 int16_t *z = c->data;
67 int i;
68 const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
69
70 if (a->columns != b->columns || b->columns != c->columns ||
71 a->rows != b->rows || b->rows != c->rows) {
72 return -EINVAL;
73 }
74
75 /* If all data is Q0 */
76 if (shift_minus_one == -1) {
77 for (i = 0; i < a->rows * a->columns; i++) {
78 *z = *x * *y;
79 x++;
80 y++;
81 z++;
82 }
83
84 return 0;
85 }
86
87 for (i = 0; i < a->rows * a->columns; i++) {
88 p = (int32_t)(*x) * *y;
89 *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
90 x++;
91 y++;
92 z++;
93 }
94
95 return 0;
96 }
97