/*
 * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:     CMSIS NN Library
 * Title:       arm_nn_vec_mat_mult_t_s16
 * Description: s16 vector by matrix (transposed) multiplication
 *
 * $Date:       5 January 2023
 * $Revision:   V.2.2.0
 *
 * Target :     Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

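/* Number of columns accumulated with 32-bit arithmetic in the DSP/MVE paths below.
 * Columns beyond this count are accumulated in 64-bit, keeping the intermediate
 * sums of s16 x s8 products within 32-bit range. */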
#define MAX_COL_COUNT (512)

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s16 vector (lhs) by matrix (transposed) multiplication
 *
 * Refer to header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
                                              const int8_t *rhs,
                                              const int64_t *bias,
                                              int16_t *dst,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max)
{
#if defined(ARM_MATH_DSP)

    int32_t rhs_cols_fast = rhs_cols;

    if (rhs_cols > MAX_COL_COUNT)
    {
        rhs_cols_fast = MAX_COL_COUNT;
    }

#if defined(ARM_MATH_MVEI)
    int32_t row_loop_cnt = rhs_rows / 4;
    int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

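    /* Process four rows of the transposed rhs matrix per iteration, accumulating
     * each dot product in its own 32-bit register. */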
    for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
    {
        int32_t col_cnt = rhs_cols_fast;

        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = rhs;
        const int8_t *rhs_ptr_1 = rhs + rhs_cols;
        const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
        const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;

        int32_t result_0 = 0;
        int32_t result_1 = 0;
        int32_t result_2 = 0;
        int32_t result_3 = 0;

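        /* Column loop with tail predication: lanes beyond col_cnt are loaded as
         * zero, so they do not contribute to the multiply-accumulate. */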
        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

            int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
            int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
            int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
            int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

            result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
            result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
            result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
            result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

            lhs_ptr += 8;

            rhs_ptr_0 += 8;
            rhs_ptr_1 += 8;
            rhs_ptr_2 += 8;
            rhs_ptr_3 += 8;
        }

        int64_t result_64_0 = result_0;
        int64_t result_64_1 = result_1;
        int64_t result_64_2 = result_2;
        int64_t result_64_3 = result_3;

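        /* Columns beyond MAX_COL_COUNT are accumulated in 64-bit to avoid
         * overflowing the 32-bit accumulators. */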
        if (rhs_cols > MAX_COL_COUNT)
        {
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64_0 += *rhs_ptr_0++ * lhs_temp;
                result_64_1 += *rhs_ptr_1++ * lhs_temp;
                result_64_2 += *rhs_ptr_2++ * lhs_temp;
                result_64_3 += *rhs_ptr_3++ * lhs_temp;
            }
        }

        if (bias)
        {
            result_64_0 += *bias++;
            result_64_1 += *bias++;
            result_64_2 += *bias++;
            result_64_3 += *bias++;
        }

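        /* Requantize each accumulator to s16 and clamp to the activation range. */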
        int32_t tmp;
        tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        rhs += 4 * rhs_cols;
    }

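    /* Handle the one to three remaining rows, one at a time. */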
    for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
    {
        int32_t result = 0;

        col_loop_cnt = (rhs_cols_fast + 7) / 8;

        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr = rhs;

        int32_t col_cnt = (int32_t)rhs_cols_fast;

        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
            int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);

            result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

            lhs_ptr += 8;
            rhs_ptr += 8;
        }

        int64_t result_64 = result;

        if (bias)
        {
            result_64 += *bias++;
        }

        if (rhs_cols > MAX_COL_COUNT)
        {
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64 += *rhs_ptr++ * lhs_temp;
            }
        }

        int32_t tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        rhs += rhs_cols;
    }

#else // ARM_MATH_MVEI

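    /* DSP extension path: process two rows per iteration, using SMLAD on packed
     * pairs of sign-extended rhs values. */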
    const int32_t row_loop_cnt = rhs_rows / 2;

    for (int32_t i = 0; i < row_loop_cnt; i++)
    {
        int64_t acc_64_0 = 0;
        int64_t acc_64_1 = 0;
        int32_t acc_0 = 0;
        int32_t acc_1 = 0;

        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        const int8_t *rhs_1 = rhs + rhs_cols;
        rhs += 2 * rhs_cols;

        for (int j = col_loop_cnt; j != 0; j--)
        {
            int32_t ker_0, ker_1, vec_part_0, vec_part_1;

            vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
            vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
            acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

            rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

            acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
            acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
        }

        acc_64_0 += acc_0;
        acc_64_1 += acc_1;

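        /* Remaining columns (including any beyond MAX_COL_COUNT) are accumulated
         * in 64-bit. */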
        for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
            acc_64_1 += lhs_temp * (*rhs_1);
            rhs_1++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
            acc_64_1 += *bias++;
        }
        int32_t tmp;

        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(acc_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

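    /* Handle the last row when rhs_rows is odd. */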
    if (rhs_rows & 0x1)
    {
        int64_t acc_64_0 = 0;
        int32_t acc_0 = 0;
        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;

        for (int i = col_loop_cnt; i != 0; i--)
        {
            int32_t ker_0, ker_1, vec;
            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_0, vec, acc_0);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_1, vec, acc_0);
        }

        acc_64_0 += acc_0;

        for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
        }
        int32_t tmp;
        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

#endif // ARM_MATH_MVEI
#else // ARM_MATH_DSP
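    /* Scalar reference path: one row at a time with 64-bit accumulation. */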
    for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
    {
        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = &rhs[0];

        int64_t result = 0;

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            const int64_t rhs_value0 = (int8_t)*rhs_ptr_0;
            const int64_t lhs_value = *lhs_ptr;

            result += lhs_value * rhs_value0;

            ++rhs_ptr_0;
            ++lhs_ptr;
        }

        if (bias)
        {
            result += *bias++;
        }
        // Quantize down
        result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);

        // Clamp the result
        result = ((result) > (activation_min) ? (result) : (activation_min));
        result = ((result) < (activation_max) ? (result) : (activation_max));

        *dst++ = (int16_t)result;
        rhs += rhs_cols;
    }
#endif // ARM_MATH_DSP

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */