1 /*
2 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_vec_mat_mult_t_s16_s16
22 * Description: s16 vector by s16 matrix (transposed) multiplication
23 *
24 * $Date: 19 June 2024
25 * $Revision: V.1.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnsupportfunctions.h"
32 /**
33 * @ingroup groupSupport
34 */
35
36 /**
37 * @addtogroup supportFC
38 * @{
39 */
40
41 /*
42 * s16 vector(lhs) by s16 matrix (transposed) multiplication
43 *
44 * Refer header file for details.
45 *
46 */
arm_nn_vec_mat_mult_t_s16_s16(const int16_t * lhs,const int16_t * rhs,const int64_t * bias,int16_t * dst,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)47 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16_s16(const int16_t *lhs,
48 const int16_t *rhs,
49 const int64_t *bias,
50 int16_t *dst,
51 const int32_t dst_multiplier,
52 const int32_t dst_shift,
53 const int32_t rhs_cols,
54 const int32_t rhs_rows,
55 const int32_t activation_min,
56 const int32_t activation_max)
57 {
58 #if defined(ARM_MATH_DSP)
59
60 #if defined(ARM_MATH_MVEI)
61 int32_t row_loop_cnt = rhs_rows / 4;
62 int32_t col_loop_cnt = (rhs_cols + 7) / 8;
63
64 for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
65 {
66 int32_t col_cnt = rhs_cols;
67
68 const int16_t *lhs_ptr = lhs;
69 const int16_t *rhs_ptr_0 = rhs;
70 const int16_t *rhs_ptr_1 = rhs + rhs_cols;
71 const int16_t *rhs_ptr_2 = rhs + rhs_cols * 2;
72 const int16_t *rhs_ptr_3 = rhs + rhs_cols * 3;
73
74 int64_t result_0 = 0;
75 int64_t result_1 = 0;
76 int64_t result_2 = 0;
77 int64_t result_3 = 0;
78
79 for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
80 {
81 mve_pred16_t pred = vctp16q(col_cnt);
82 col_cnt -= 8;
83
84 int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
85
86 int16x8_t rhs_input_0 = vldrhq_z_s16(rhs_ptr_0, pred);
87 int16x8_t rhs_input_1 = vldrhq_z_s16(rhs_ptr_1, pred);
88 int16x8_t rhs_input_2 = vldrhq_z_s16(rhs_ptr_2, pred);
89 int16x8_t rhs_input_3 = vldrhq_z_s16(rhs_ptr_3, pred);
90
91 result_0 = vmlaldavaq_s16(result_0, lhs_input, rhs_input_0);
92 result_1 = vmlaldavaq_s16(result_1, lhs_input, rhs_input_1);
93 result_2 = vmlaldavaq_s16(result_2, lhs_input, rhs_input_2);
94 result_3 = vmlaldavaq_s16(result_3, lhs_input, rhs_input_3);
95
96 lhs_ptr += 8;
97
98 rhs_ptr_0 += 8;
99 rhs_ptr_1 += 8;
100 rhs_ptr_2 += 8;
101 rhs_ptr_3 += 8;
102 }
103
104 if (bias)
105 {
106 result_0 += *bias++;
107 result_1 += *bias++;
108 result_2 += *bias++;
109 result_3 += *bias++;
110 }
111
112 int32_t tmp;
113 tmp = arm_nn_requantize_s64(result_0, dst_multiplier, dst_shift);
114 tmp = MAX(tmp, activation_min);
115 tmp = MIN(tmp, activation_max);
116 *dst++ = (int16_t)tmp;
117
118 tmp = 0;
119 tmp = arm_nn_requantize_s64(result_1, dst_multiplier, dst_shift);
120 tmp = MAX(tmp, activation_min);
121 tmp = MIN(tmp, activation_max);
122 *dst++ = (int16_t)tmp;
123
124 tmp = 0;
125 tmp = arm_nn_requantize_s64(result_2, dst_multiplier, dst_shift);
126 tmp = MAX(tmp, activation_min);
127 tmp = MIN(tmp, activation_max);
128 *dst++ = (int16_t)tmp;
129
130 tmp = 0;
131 tmp = arm_nn_requantize_s64(result_3, dst_multiplier, dst_shift);
132 tmp = MAX(tmp, activation_min);
133 tmp = MIN(tmp, activation_max);
134 *dst++ = (int16_t)tmp;
135
136 rhs += 4 * rhs_cols;
137 }
138
139 for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
140 {
141 int64_t result = 0;
142
143 col_loop_cnt = (rhs_cols + 7) / 8;
144
145 const int16_t *lhs_ptr = lhs;
146 const int16_t *rhs_ptr = rhs;
147
148 int32_t col_cnt = rhs_cols;
149
150 for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
151 {
152 mve_pred16_t pred = vctp16q(col_cnt);
153 col_cnt -= 8;
154
155 int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
156 int16x8_t rhs_input = vldrhq_z_s16(rhs_ptr, pred);
157
158 result = vmlaldavaq_p_s16(result, lhs_input, rhs_input, pred);
159
160 lhs_ptr += 8;
161 rhs_ptr += 8;
162 }
163
164 if (bias)
165 {
166 result += *bias++;
167 }
168
169 int32_t tmp = 0;
170 tmp = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
171 tmp = MAX(tmp, activation_min);
172 tmp = MIN(tmp, activation_max);
173 *dst++ = (int16_t)tmp;
174
175 rhs += rhs_cols;
176 }
177
178 #else // ARM_MATH_MVEI
179
180 const int32_t row_loop_cnt = rhs_rows / 2;
181
182 for (int32_t i = 0; i < row_loop_cnt; i++)
183 {
184
185 int64_t acc_0 = 0;
186 int64_t acc_1 = 0;
187
188 const int32_t col_loop_cnt = rhs_cols / 4;
189
190 const int16_t *lhs_vec = lhs;
191 const int16_t *rhs_0 = rhs;
192 const int16_t *rhs_1 = rhs + rhs_cols;
193 rhs += 2 * rhs_cols;
194
195 for (int j = col_loop_cnt; j != 0; j--)
196 {
197 int32_t ker_0, ker_1, vec_part_0, vec_part_1;
198
199 vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
200 vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);
201
202 ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
203 ker_1 = arm_nn_read_q15x2_ia(&rhs_0);
204
205 acc_0 = SMLALD(ker_0, vec_part_0, acc_0);
206 acc_0 = SMLALD(ker_1, vec_part_1, acc_0);
207
208 ker_0 = arm_nn_read_q15x2_ia(&rhs_1);
209 ker_1 = arm_nn_read_q15x2_ia(&rhs_1);
210
211 acc_1 = SMLALD(ker_0, vec_part_0, acc_1);
212 acc_1 = SMLALD(ker_1, vec_part_1, acc_1);
213 }
214
215 for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
216 {
217 const int16_t lhs_temp = (*lhs_vec);
218 lhs_vec++;
219 acc_0 += lhs_temp * (*rhs_0);
220 rhs_0++;
221 acc_1 += lhs_temp * (*rhs_1);
222 rhs_1++;
223 }
224
225 if (bias)
226 {
227 acc_0 += *bias++;
228 acc_1 += *bias++;
229 }
230 int32_t tmp;
231
232 tmp = arm_nn_requantize_s64(acc_0, dst_multiplier, dst_shift);
233 tmp = MAX(tmp, activation_min);
234 tmp = MIN(tmp, activation_max);
235 *dst++ = (int16_t)tmp;
236
237 tmp = arm_nn_requantize_s64(acc_1, dst_multiplier, dst_shift);
238 tmp = MAX(tmp, activation_min);
239 tmp = MIN(tmp, activation_max);
240 *dst++ = (int16_t)tmp;
241 }
242
243 if (rhs_rows & 0x1)
244 {
245 int64_t acc_0 = 0;
246 const int32_t col_loop_cnt = rhs_cols / 4;
247
248 const int16_t *lhs_vec = lhs;
249 const int16_t *rhs_0 = rhs;
250
251 for (int i = col_loop_cnt; i != 0; i--)
252 {
253 int32_t ker_0, vec;
254
255 ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
256 vec = arm_nn_read_q15x2_ia(&lhs_vec);
257 acc_0 = SMLALD(ker_0, vec, acc_0);
258
259 ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
260 vec = arm_nn_read_q15x2_ia(&lhs_vec);
261 acc_0 = SMLALD(ker_0, vec, acc_0);
262 }
263
264 for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
265 {
266 const int16_t lhs_temp = (*lhs_vec);
267 lhs_vec++;
268 acc_0 += lhs_temp * (*rhs_0);
269 rhs_0++;
270 }
271
272 if (bias)
273 {
274 acc_0 += *bias++;
275 }
276 int32_t tmp;
277 tmp = arm_nn_requantize_s64(acc_0, dst_multiplier, dst_shift);
278 tmp = MAX(tmp, activation_min);
279 tmp = MIN(tmp, activation_max);
280 *dst++ = (int16_t)tmp;
281 }
282
283 #endif // ARM_MATH_MVEI
284 #else // ARM_MATH_DSP
285 for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
286 {
287 const int16_t *lhs_ptr = lhs;
288 const int16_t *rhs_ptr_0 = &rhs[0];
289
290 int64_t result = 0;
291
292 for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
293 {
294 const int64_t rhs_value0 = *rhs_ptr_0;
295 const int64_t lhs_value = *lhs_ptr;
296
297 result += lhs_value * rhs_value0;
298
299 ++rhs_ptr_0;
300 ++lhs_ptr;
301 }
302
303 if (bias)
304 {
305 result += *bias++;
306 }
307 // Quantize down
308 result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
309
310 // Clamp the result
311 result = MAX(result, activation_min);
312 result = MIN(result, activation_max);
313
314 *dst++ = (int16_t)result;
315 rhs += rhs_cols;
316 }
317 #endif // ARM_MATH_DSP
318
319 return ARM_CMSIS_NN_SUCCESS;
320 }
321
322 /**
323 * @} end of Doxygen group
324 */
325