/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mul_result_acc_s16
 * Description:  s16 vector by matrix (transposed) multiplication
 *
 * $Date:        26 March 2023
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */
/*
 * s16 vector (lhs) by matrix (transposed) multiplication with result accumulation
 *
 * Refer to the header file for details.
 *
 */
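/*
 * Minimal usage sketch (illustrative only; the sizes and quantization
 * parameters below are assumptions for the example, not values required by
 * this function):
 *
 *   int16_t lhs[8];               // s16 input vector, rhs_cols elements
 *   int8_t rhs[4 * 8];            // s8 weight matrix, rhs_rows x rhs_cols, row-major (transposed)
 *   int64_t effective_bias[4];    // one s64 bias value per output row
 *   int16_t dst[4];               // requantized results are added to the existing dst values
 *
 *   arm_cmsis_nn_status status = arm_nn_vec_mat_mul_result_acc_s16(
 *       lhs, rhs, effective_bias, dst, 1073741824, 0, 8, 4, 1, 1);
 */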
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
                                                      const int8_t *rhs,
                                                      const int64_t *effective_bias,
                                                      int16_t *dst,
                                                      const int32_t dst_multiplier,
                                                      const int32_t dst_shift,
                                                      const int32_t rhs_cols,
                                                      const int32_t rhs_rows,
                                                      const int32_t batches,
                                                      const int32_t batch_offset)
{

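    /* Reduce the 32-bit requantization multiplier for use with the 64-bit
     * requantization helper arm_nn_requantize_s64 below. */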
    int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multiplier);

    for (int batch = 0; batch < batches; batch++)
    {

        const int8_t *rhs_ptr = &rhs[0];
        const int64_t *effective_bias_ptr = &effective_bias[0];

#if defined(ARM_MATH_DSP)

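        /* Limit the column count handled by the 32-bit SIMD accumulators; any
         * remaining columns are accumulated in 64-bit scalar code further down
         * so the 32-bit accumulators cannot overflow. */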
        int32_t rhs_cols_fast = rhs_cols;

        if (rhs_cols > MAX_COL_COUNT)
        {
            rhs_cols_fast = MAX_COL_COUNT;
        }

    #if defined(ARM_MATH_MVEI)
        int32_t row_loop_cnt = rhs_rows / 4;
        const int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

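        /* MVE path: compute four output rows per iteration, eight columns at a
         * time, using predicated loads to handle a partial final column block. */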
        for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
        {
            int32_t col_cnt = rhs_cols_fast;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = rhs_ptr;
            const int8_t *rhs_ptr_1 = rhs_ptr + rhs_cols;
            const int8_t *rhs_ptr_2 = rhs_ptr + rhs_cols * 2;
            const int8_t *rhs_ptr_3 = rhs_ptr + rhs_cols * 3;

            int32_t result_0 = *effective_bias_ptr++;
            int32_t result_1 = *effective_bias_ptr++;
            int32_t result_2 = *effective_bias_ptr++;
            int32_t result_3 = *effective_bias_ptr++;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

                int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
                int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
                int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
                int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

                result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
                result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
                result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
                result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

                lhs_ptr += 8;

                rhs_ptr_0 += 8;
                rhs_ptr_1 += 8;
                rhs_ptr_2 += 8;
                rhs_ptr_3 += 8;
            }

            int64_t result_64_0 = result_0;
            int64_t result_64_1 = result_1;
            int64_t result_64_2 = result_2;
            int64_t result_64_3 = result_3;

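            /* Accumulate any columns beyond MAX_COL_COUNT in 64-bit scalar code
             * so that large column counts cannot overflow the accumulators. */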
            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64_0 += *rhs_ptr_0++ * lhs_temp;
                    result_64_1 += *rhs_ptr_1++ * lhs_temp;
                    result_64_2 += *rhs_ptr_2++ * lhs_temp;
                    result_64_3 += *rhs_ptr_3++ * lhs_temp;
                }
            }

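            /* Requantize each 64-bit sum, add the value already stored in dst
             * (result accumulation), then saturate to the Q15 range. */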
            int32_t tmp;
            tmp = arm_nn_requantize_s64(result_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_2, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = 0;
            tmp = arm_nn_requantize_s64(result_64_3, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += 4 * rhs_cols;
        }

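        /* Handle the one to three rows left over when rhs_rows is not a
         * multiple of four, using a predicated multiply-accumulate. */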
        for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
        {
            int32_t result = *effective_bias_ptr++;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr0 = rhs_ptr;

            int32_t col_cnt = (int32_t)rhs_cols_fast;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
                int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr0, pred);

                result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

                lhs_ptr += 8;
                rhs_ptr0 += 8;
            }

            int64_t result_64 = result;

            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64 += *rhs_ptr0++ * lhs_temp;
                }
            }

            int32_t tmp = 0;
            tmp = arm_nn_requantize_s64(result_64, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += rhs_cols;
        }

    #else // ARM_MATH_MVEI

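        /* DSP path (no MVE): compute two output rows per iteration, four
         * columns at a time, using the SMLAD dual multiply-accumulate. */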
        const int32_t row_loop_cnt = rhs_rows / 2;

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {

            int64_t acc_64_0 = 0;
            int64_t acc_64_1 = 0;
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;

            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;
            rhs_ptr += rhs_cols;
            const int8_t *rhs_1 = rhs_ptr;
            rhs_ptr += rhs_cols;

            for (int j = col_loop_cnt; j != 0; j--)
            {
                int32_t ker_0, ker_1, vec_part_0, vec_part_1;

                vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
                vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
                acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

                rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

                acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
                acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
            }

            acc_64_0 += acc_0;
            acc_64_1 += acc_1;

            for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
                acc_64_1 += lhs_temp * (*rhs_1);
                rhs_1++;
            }

            acc_64_0 += *effective_bias_ptr++;
            acc_64_1 += *effective_bias_ptr++;
            int32_t tmp;

            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(acc_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }

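        /* Process the final row when rhs_rows is odd. */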
        if (rhs_rows & 0x1)
        {
            int64_t acc_64_0 = 0;
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;

            for (int i = col_loop_cnt; i != 0; i--)
            {
                int32_t ker_0, ker_1, vec;
                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_0, vec, acc_0);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_1, vec, acc_0);
            }

            acc_64_0 += acc_0;

            for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
            }

            acc_64_0 += *effective_bias_ptr++;

            int32_t tmp;
            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }

    #endif // ARM_MATH_MVEI
#else      // ARM_MATH_DSP
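        /* Reference path (no DSP extension): plain per-row, per-column
         * multiply-accumulate in 64-bit arithmetic. */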
        for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
        {
            const int16_t *lhs_ptr = lhs;

            int64_t result = *effective_bias_ptr++;

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int64_t rhs_value0 = (int8_t)*rhs_ptr;
                const int64_t lhs_value = *lhs_ptr;

                result += lhs_value * rhs_value0;
                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            result = arm_nn_requantize_s64(result, reduced_multiplier, dst_shift);
            result += (int64_t)*dst;

            // Clamp the result
            result = ((result) > (NN_Q15_MIN) ? (result) : (NN_Q15_MIN));
            result = ((result) < (NN_Q15_MAX) ? (result) : (NN_Q15_MAX));

            *dst++ = (int16_t)result;
        }
#endif     // ARM_MATH_DSP

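        /* Step lhs forward by batch_offset input vectors of rhs_cols elements
         * each to reach the next batch. */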
        lhs += rhs_cols * batch_offset;
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */