/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:     CMSIS NN Library
 * Title:       arm_nn_vec_mat_mul_result_acc_s16
 * Description: s16 vector by matrix (transposed) multiplication with result accumulation
 *
 * $Date:       26 March 2023
 * $Revision:   V.1.0.0
 *
 * Target :     Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s16 vector (lhs) by matrix (transposed) multiplication with result accumulation
 *
 * Refer header file for details.
 *
 */
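/*
 * Illustrative usage sketch (not part of the library sources; buffer names and
 * quantization parameters below are hypothetical). The result is accumulated
 * into dst, so the destination buffer must hold valid values on entry:
 *
 *   int16_t vec[COLS];               // s16 input vector (lhs)
 *   int8_t weights[ROWS * COLS];     // s8 weight matrix, stored transposed (rhs)
 *   int64_t bias[ROWS];              // effective bias, one entry per output row
 *   int16_t out[ROWS];               // pre-filled output, accumulated in place
 *
 *   arm_cmsis_nn_status status = arm_nn_vec_mat_mul_result_acc_s16(
 *       vec, weights, bias, out, dst_multiplier, dst_shift, COLS, ROWS, 1, 1);
 */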
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
                                                      const int8_t *rhs,
                                                      const int64_t *effective_bias,
                                                      int16_t *dst,
                                                      const int32_t dst_multiplier,
                                                      const int32_t dst_shift,
                                                      const int32_t rhs_cols,
                                                      const int32_t rhs_rows,
                                                      const int32_t batches,
                                                      const int32_t batch_offset)
{

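    /* Reduce the requantization multiplier so that it can be applied to the
     * 64-bit accumulators by arm_nn_requantize_s64 below. */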
    int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multiplier);

    for (int batch = 0; batch < batches; batch++)
    {

        const int8_t *rhs_ptr = &rhs[0];
        const int64_t *effective_bias_ptr = &effective_bias[0];

#if defined(ARM_MATH_DSP)

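        /* Columns up to MAX_COL_COUNT are accumulated with 32-bit MAC instructions;
         * any remaining columns are added in a 64-bit scalar tail loop to keep the
         * 32-bit accumulators from overflowing. */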
        int32_t rhs_cols_fast = rhs_cols;

        if (rhs_cols > MAX_COL_COUNT)
        {
            rhs_cols_fast = MAX_COL_COUNT;
        }

    #if defined(ARM_MATH_MVEI)
        int32_t row_loop_cnt = rhs_rows / 4;
        const int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

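        /* Process four output rows per iteration. Tail-predicated 8-lane loads widen
         * the s8 weights to s16 and vmladavaq accumulates the dot products. */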
        for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
        {
            int32_t col_cnt = rhs_cols_fast;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = rhs_ptr;
            const int8_t *rhs_ptr_1 = rhs_ptr + rhs_cols;
            const int8_t *rhs_ptr_2 = rhs_ptr + rhs_cols * 2;
            const int8_t *rhs_ptr_3 = rhs_ptr + rhs_cols * 3;

            int32_t result_0 = *effective_bias_ptr++;
            int32_t result_1 = *effective_bias_ptr++;
            int32_t result_2 = *effective_bias_ptr++;
            int32_t result_3 = *effective_bias_ptr++;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

                int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
                int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
                int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
                int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

                result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
                result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
                result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
                result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

                lhs_ptr += 8;

                rhs_ptr_0 += 8;
                rhs_ptr_1 += 8;
                rhs_ptr_2 += 8;
                rhs_ptr_3 += 8;
            }

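            /* Widen the accumulators to 64 bit; any columns beyond MAX_COL_COUNT are
             * added here in scalar code. */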
            int64_t result_64_0 = result_0;
            int64_t result_64_1 = result_1;
            int64_t result_64_2 = result_2;
            int64_t result_64_3 = result_3;

            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64_0 += *rhs_ptr_0++ * lhs_temp;
                    result_64_1 += *rhs_ptr_1++ * lhs_temp;
                    result_64_2 += *rhs_ptr_2++ * lhs_temp;
                    result_64_3 += *rhs_ptr_3++ * lhs_temp;
                }
            }

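            /* Requantize each result, add the value already stored in dst and
             * saturate to the q15 range. */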
            int32_t tmp;
            tmp = arm_nn_requantize_s64(result_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_2, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_3, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += 4 * rhs_cols;
        }

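        /* Handle up to three remaining rows, one at a time. */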
        for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
        {
            int32_t result = *effective_bias_ptr++;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr0 = rhs_ptr;

            int32_t col_cnt = (int32_t)rhs_cols_fast;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
                int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr0, pred);

                result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

                lhs_ptr += 8;
                rhs_ptr0 += 8;
            }

            int64_t result_64 = result;

            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64 += *rhs_ptr0++ * lhs_temp;
                }
            }

            int32_t tmp = 0;
            tmp = arm_nn_requantize_s64(result_64, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += rhs_cols;
        }

    #else // ARM_MATH_MVEI

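        /* DSP extension only: process two output rows per iteration using SMLAD
         * dual 16-bit multiply-accumulate instructions. */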
        const int32_t row_loop_cnt = rhs_rows / 2;

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {

            int64_t acc_64_0 = 0;
            int64_t acc_64_1 = 0;
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;

            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;
            rhs_ptr += rhs_cols;
            const int8_t *rhs_1 = rhs_ptr;
            rhs_ptr += rhs_cols;

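            /* Each iteration consumes four columns: two packed q15 pairs from lhs and
             * four s8 weights per row, sign-extended to packed s16 by read_and_pad. */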
            for (int j = col_loop_cnt; j != 0; j--)
            {
                int32_t ker_0, ker_1, vec_part_0, vec_part_1;

                vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
                vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
                acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

                rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

                acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
                acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
            }

            acc_64_0 += acc_0;
            acc_64_1 += acc_1;

            for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
                acc_64_1 += lhs_temp * (*rhs_1);
                rhs_1++;
            }

            acc_64_0 += *effective_bias_ptr++;
            acc_64_1 += *effective_bias_ptr++;
            int32_t tmp;

            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(acc_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }

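        /* Handle the final row when the row count is odd. */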
        if (rhs_rows & 0x1)
        {
            int64_t acc_64_0 = 0;
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;

            for (int i = col_loop_cnt; i != 0; i--)
            {
                int32_t ker_0, ker_1, vec;
                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_0, vec, acc_0);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_1, vec, acc_0);
            }

            acc_64_0 += acc_0;

            for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
            }

            acc_64_0 += *effective_bias_ptr++;

            int32_t tmp;
            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }

    #endif // ARM_MATH_MVEI
#else // ARM_MATH_DSP
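        /* Pure C fallback: one output row at a time, accumulating directly in 64 bits. */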
        for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
        {
            const int16_t *lhs_ptr = lhs;

            int64_t result = *effective_bias_ptr++;

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int64_t rhs_value0 = (int8_t)*rhs_ptr;
                const int64_t lhs_value = *lhs_ptr;

                result += lhs_value * rhs_value0;
                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            result = arm_nn_requantize_s64(result, reduced_multiplier, dst_shift);
            result += (int64_t)*dst;

            // Clamp the result
            result = ((result) > (NN_Q15_MIN) ? (result) : (NN_Q15_MIN));
            result = ((result) < (NN_Q15_MAX) ? (result) : (NN_Q15_MAX));

            *dst++ = (int16_t)result;
        }
#endif // ARM_MATH_DSP

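        /* Advance the input vector to the start of the next batch. */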
        lhs += rhs_cols * batch_offset;
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */
363