/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s16
 * Description:  s16 vector by matrix (transposed) multiplication
 *
 * $Date:        22 March 2024
 * $Revision:    V.2.3.0
 *
 * Target :      Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s16 vector (lhs) by matrix (transposed) multiplication
 *
 * Refer to the header file for details.
 *
 */
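/*
 * Illustrative usage sketch (all values below are hypothetical and shown for
 * demonstration only; see the header file for the exact parameter contract):
 *
 *   const int16_t lhs[4] = {100, -50, 25, 10};              // 1 x 4 input vector
 *   const int8_t rhs[2 * 4] = {1, 2, 3, 4, -1, -2, -3, -4}; // 2 x 4 weights, stored row-major (transposed)
 *   const int64_t bias[2] = {0, 0};
 *   int16_t dst[2];
 *
 *   arm_cmsis_nn_status status = arm_nn_vec_mat_mult_t_s16(lhs, rhs, bias, dst,
 *                                                          1073741824, -1,  // requantization multiplier/shift (example values)
 *                                                          4, 2,            // rhs_cols, rhs_rows
 *                                                          -32768, 32767);  // activation_min, activation_max
 */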
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
                                              const int8_t *rhs,
                                              const int64_t *bias,
                                              int16_t *dst,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max)
{
#if defined(ARM_MATH_DSP)

    int32_t rhs_cols_fast = rhs_cols;

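    /* Limit the number of columns processed with 32-bit accumulators; products of
       s16 activations and s8 weights could otherwise overflow int32 for long rows.
       Columns beyond MAX_COL_COUNT are handled later with 64-bit accumulation. */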
    if (rhs_cols > MAX_COL_COUNT)
    {
        rhs_cols_fast = MAX_COL_COUNT;
    }

#if defined(ARM_MATH_MVEI)
    int32_t row_loop_cnt = rhs_rows / 4;
    int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

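    /* Process four rhs rows per iteration: each predicated 8-wide MVE pass widens
       s8 weights to s16 on load and accumulates dot products into per-row int32
       accumulators with vmladavaq. */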
    for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
    {
        int32_t col_cnt = rhs_cols_fast;

        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = rhs;
        const int8_t *rhs_ptr_1 = rhs + rhs_cols;
        const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
        const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;

        int32_t result_0 = 0;
        int32_t result_1 = 0;
        int32_t result_2 = 0;
        int32_t result_3 = 0;

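        /* vctp16q(col_cnt) predicates the loads and MACs so the final iteration only
           touches the columns that remain; vldrbq_z_s16 widens the s8 weights to s16. */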
        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

            int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
            int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
            int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
            int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

            result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
            result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
            result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
            result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

            lhs_ptr += 8;

            rhs_ptr_0 += 8;
            rhs_ptr_1 += 8;
            rhs_ptr_2 += 8;
            rhs_ptr_3 += 8;
        }

        int64_t result_64_0 = result_0;
        int64_t result_64_1 = result_1;
        int64_t result_64_2 = result_2;
        int64_t result_64_3 = result_3;

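        /* Any columns beyond MAX_COL_COUNT are accumulated here in 64-bit scalar
           arithmetic so the totals cannot overflow. */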
        if (rhs_cols > MAX_COL_COUNT)
        {
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64_0 += *rhs_ptr_0++ * lhs_temp;
                result_64_1 += *rhs_ptr_1++ * lhs_temp;
                result_64_2 += *rhs_ptr_2++ * lhs_temp;
                result_64_3 += *rhs_ptr_3++ * lhs_temp;
            }
        }

        if (bias)
        {
            result_64_0 += *bias++;
            result_64_1 += *bias++;
            result_64_2 += *bias++;
            result_64_3 += *bias++;
        }

        int32_t tmp;
        tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        rhs += 4 * rhs_cols;
    }

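    /* Handle the remaining rhs_rows % 4 rows one at a time, using the same
       predicated MVE inner loop followed by a scalar tail for long rows. */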
    for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
    {
        int32_t result = 0;

        col_loop_cnt = (rhs_cols_fast + 7) / 8;

        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr = rhs;

        int32_t col_cnt = (int32_t)rhs_cols_fast;

        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
            int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);

            result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

            lhs_ptr += 8;
            rhs_ptr += 8;
        }

        int64_t result_64 = result;

        if (bias)
        {
            result_64 += *bias++;
        }

        if (rhs_cols > MAX_COL_COUNT)
        {
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64 += *rhs_ptr++ * lhs_temp;
            }
        }

        int32_t tmp = 0;
        tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        rhs += rhs_cols;
    }

#else // ARM_MATH_MVEI

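    /* DSP extension path: process two rhs rows per iteration, reading the lhs as
       packed 16-bit pairs and using SMLAD dual multiply-accumulate. */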
    const int32_t row_loop_cnt = rhs_rows / 2;

    for (int32_t i = 0; i < row_loop_cnt; i++)
    {

        int64_t acc_64_0 = 0;
        int64_t acc_64_1 = 0;
        int32_t acc_0 = 0;
        int32_t acc_1 = 0;

        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        const int8_t *rhs_1 = rhs + rhs_cols;
        rhs += 2 * rhs_cols;

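        /* read_and_pad sign-extends four s8 weights into two packed 16-bit pairs,
           which SMLAD multiplies against the corresponding packed lhs values. */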
        for (int j = col_loop_cnt; j != 0; j--)
        {
            int32_t ker_0, ker_1, vec_part_0, vec_part_1;

            vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
            vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
            acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

            rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

            acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
            acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
        }

        acc_64_0 += acc_0;
        acc_64_1 += acc_1;

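        /* Accumulate the leftover columns, including any beyond MAX_COL_COUNT, in
           64-bit to avoid overflow. */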
        for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
            acc_64_1 += lhs_temp * (*rhs_1);
            rhs_1++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
            acc_64_1 += *bias++;
        }
        int32_t tmp;

        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(acc_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

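    /* Handle the last row when rhs_rows is odd. */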
    if (rhs_rows & 0x1)
    {
        int64_t acc_64_0 = 0;
        int32_t acc_0 = 0;
        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;

        for (int i = col_loop_cnt; i != 0; i--)
        {
            int32_t ker_0, ker_1, vec;
            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_0, vec, acc_0);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_1, vec, acc_0);
        }

        acc_64_0 += acc_0;

        for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
        }
        int32_t tmp;
        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

#endif // ARM_MATH_MVEI
#else // ARM_MATH_DSP
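    /* Pure C reference path: compute one output at a time with 64-bit accumulation. */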
    for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
    {
        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = &rhs[0];

        int64_t result = 0;

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            const int64_t rhs_value0 = (int8_t)*rhs_ptr_0;
            const int64_t lhs_value = *lhs_ptr;

            result += lhs_value * rhs_value0;

            ++rhs_ptr_0;
            ++lhs_ptr;
        }

        if (bias)
        {
            result += *bias++;
        }
        // Quantize down
        result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);

        // Clamp the result
        result = ((result) > (activation_min) ? (result) : (activation_min));
        result = ((result) < (activation_max) ? (result) : (activation_max));

        *dst++ = (int16_t)result;
        rhs += rhs_cols;
    }
#endif // ARM_MATH_DSP

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */