1 /*
2  * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s16
22  * Description:  s16 vector by matrix (transposed) multiplication
23  *
24  * $Date:        5 January 2023
25  * $Revision:    V.2.2.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
/* Maximum number of columns accumulated in the int32 SIMD/DSP fast path.
 * 512 * INT16_MAX * 128 = 2147418112 just fits in an int32, so up to 512
 * s16*s8 products can be summed without overflow; columns beyond this
 * bound are accumulated in a 64-bit scalar tail loop instead. */
#define MAX_COL_COUNT (512)
34 
35 /**
36  * @ingroup groupSupport
37  */
38 
39 /**
40  * @addtogroup supportFC
41  * @{
42  */
43 
44 /*
45  * s16 vector(lhs) by matrix (transposed) multiplication
46  *
47  * Refer header file for details.
48  *
49  */
arm_nn_vec_mat_mult_t_s16(const int16_t * lhs,const int8_t * rhs,const int64_t * bias,int16_t * dst,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)50 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
51                                               const int8_t *rhs,
52                                               const int64_t *bias,
53                                               int16_t *dst,
54                                               const int32_t dst_multiplier,
55                                               const int32_t dst_shift,
56                                               const int32_t rhs_cols,
57                                               const int32_t rhs_rows,
58                                               const int32_t activation_min,
59                                               const int32_t activation_max)
60 {
61 #if defined(ARM_MATH_DSP)
62 
63     int32_t rhs_cols_fast = rhs_cols;
64 
65     if (rhs_cols > MAX_COL_COUNT)
66     {
67         rhs_cols_fast = MAX_COL_COUNT;
68     }
69 
70     #if defined(ARM_MATH_MVEI)
71     int32_t row_loop_cnt = rhs_rows / 4;
72     int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;
73 
74     for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
75     {
76         int32_t col_cnt = rhs_cols_fast;
77 
78         const int16_t *lhs_ptr = lhs;
79         const int8_t *rhs_ptr_0 = rhs;
80         const int8_t *rhs_ptr_1 = rhs + rhs_cols;
81         const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
82         const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;
83 
84         int32_t result_0 = 0;
85         int32_t result_1 = 0;
86         int32_t result_2 = 0;
87         int32_t result_3 = 0;
88 
89         for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
90         {
91             mve_pred16_t pred = vctp16q(col_cnt);
92             col_cnt -= 8;
93 
94             int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
95 
96             int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
97             int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
98             int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
99             int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);
100 
101             result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
102             result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
103             result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
104             result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);
105 
106             lhs_ptr += 8;
107 
108             rhs_ptr_0 += 8;
109             rhs_ptr_1 += 8;
110             rhs_ptr_2 += 8;
111             rhs_ptr_3 += 8;
112         }
113 
114         int64_t result_64_0 = result_0;
115         int64_t result_64_1 = result_1;
116         int64_t result_64_2 = result_2;
117         int64_t result_64_3 = result_3;
118 
119         if (rhs_cols > MAX_COL_COUNT)
120         {
121             for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
122             {
123                 const int16_t lhs_temp = *lhs_ptr++;
124 
125                 result_64_0 += *rhs_ptr_0++ * lhs_temp;
126                 result_64_1 += *rhs_ptr_1++ * lhs_temp;
127                 result_64_2 += *rhs_ptr_2++ * lhs_temp;
128                 result_64_3 += *rhs_ptr_3++ * lhs_temp;
129             }
130         }
131 
132         if (bias)
133         {
134             result_64_0 += *bias++;
135             result_64_1 += *bias++;
136             result_64_2 += *bias++;
137             result_64_3 += *bias++;
138         }
139 
140         int32_t tmp;
141         tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
142         tmp = MAX(tmp, activation_min);
143         tmp = MIN(tmp, activation_max);
144         *dst++ = (int16_t)tmp;
145 
146         tmp = 0;
147         tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
148         tmp = MAX(tmp, activation_min);
149         tmp = MIN(tmp, activation_max);
150         *dst++ = (int16_t)tmp;
151 
152         tmp = 0;
153         tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
154         tmp = MAX(tmp, activation_min);
155         tmp = MIN(tmp, activation_max);
156         *dst++ = (int16_t)tmp;
157 
158         tmp = 0;
159         tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
160         tmp = MAX(tmp, activation_min);
161         tmp = MIN(tmp, activation_max);
162         *dst++ = (int16_t)tmp;
163 
164         rhs += 4 * rhs_cols;
165     }
166 
167     for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
168     {
169         int32_t result = 0;
170 
171         col_loop_cnt = (rhs_cols_fast + 7) / 8;
172 
173         const int16_t *lhs_ptr = lhs;
174         const int8_t *rhs_ptr = rhs;
175 
176         int32_t col_cnt = (int32_t)rhs_cols_fast;
177 
178         for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
179         {
180             mve_pred16_t pred = vctp16q(col_cnt);
181             col_cnt -= 8;
182 
183             int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
184             int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);
185 
186             result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);
187 
188             lhs_ptr += 8;
189             rhs_ptr += 8;
190         }
191 
192         int64_t result_64 = result;
193 
194         if (bias)
195         {
196             result_64 += *bias++;
197         }
198 
199         if (rhs_cols > MAX_COL_COUNT)
200         {
201             for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
202             {
203                 const int16_t lhs_temp = *lhs_ptr++;
204 
205                 result_64 += *rhs_ptr++ * lhs_temp;
206             }
207         }
208 
209         int32_t tmp = 0;
210         tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
211         tmp = MAX(tmp, activation_min);
212         tmp = MIN(tmp, activation_max);
213         *dst++ = (int16_t)tmp;
214 
215         rhs += rhs_cols;
216     }
217 
218     #else // ARM_MATH_MVEI
219 
220     const int32_t row_loop_cnt = rhs_rows / 2;
221 
222     for (int32_t i = 0; i < row_loop_cnt; i++)
223     {
224 
225         int64_t acc_64_0 = 0;
226         int64_t acc_64_1 = 0;
227         int32_t acc_0 = 0;
228         int32_t acc_1 = 0;
229 
230         const int32_t col_loop_cnt = rhs_cols_fast / 4;
231 
232         const int16_t *lhs_vec = lhs;
233         const int8_t *rhs_0 = rhs;
234         const int8_t *rhs_1 = rhs + rhs_cols;
235         rhs += 2 * rhs_cols;
236 
237         for (int j = col_loop_cnt; j != 0; j--)
238         {
239             int32_t ker_0, ker_1, vec_part_0, vec_part_1;
240 
241             vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
242             vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);
243 
244             rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);
245 
246             acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
247             acc_0 = SMLAD(ker_1, vec_part_1, acc_0);
248 
249             rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);
250 
251             acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
252             acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
253         }
254 
255         acc_64_0 += acc_0;
256         acc_64_1 += acc_1;
257 
258         for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
259         {
260             const int32_t lhs_temp = (*lhs_vec);
261             lhs_vec++;
262             acc_64_0 += lhs_temp * (*rhs_0);
263             rhs_0++;
264             acc_64_1 += lhs_temp * (*rhs_1);
265             rhs_1++;
266         }
267 
268         if (bias)
269         {
270             acc_64_0 += *bias++;
271             acc_64_1 += *bias++;
272         }
273         int32_t tmp;
274 
275         tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
276         tmp = MAX(tmp, activation_min);
277         tmp = MIN(tmp, activation_max);
278         *dst++ = (int16_t)tmp;
279 
280         tmp = arm_nn_requantize_s64(acc_64_1, dst_multiplier, dst_shift);
281         tmp = MAX(tmp, activation_min);
282         tmp = MIN(tmp, activation_max);
283         *dst++ = (int16_t)tmp;
284     }
285 
286     if (rhs_rows & 0x1)
287     {
288         int64_t acc_64_0 = 0;
289         int32_t acc_0 = 0;
290         const int32_t col_loop_cnt = rhs_cols_fast / 4;
291 
292         const int16_t *lhs_vec = lhs;
293         const int8_t *rhs_0 = rhs;
294 
295         for (int i = col_loop_cnt; i != 0; i--)
296         {
297             int32_t ker_0, ker_1, vec;
298             rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);
299 
300             vec = arm_nn_read_q15x2_ia(&lhs_vec);
301             acc_0 = SMLAD(ker_0, vec, acc_0);
302 
303             vec = arm_nn_read_q15x2_ia(&lhs_vec);
304             acc_0 = SMLAD(ker_1, vec, acc_0);
305         }
306 
307         acc_64_0 += acc_0;
308 
309         for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
310         {
311             const int32_t lhs_temp = (*lhs_vec);
312             lhs_vec++;
313             acc_64_0 += lhs_temp * (*rhs_0);
314             rhs_0++;
315         }
316 
317         if (bias)
318         {
319             acc_64_0 += *bias++;
320         }
321         int32_t tmp;
322         tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
323         tmp = MAX(tmp, activation_min);
324         tmp = MIN(tmp, activation_max);
325         *dst++ = (int16_t)tmp;
326     }
327 
328     #endif // ARM_MATH_MVEI
329 #else      // ARM_MATH_DSP
330     for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
331     {
332         const int16_t *lhs_ptr = lhs;
333         const int8_t *rhs_ptr_0 = &rhs[0];
334 
335         int64_t result = 0;
336 
337         for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
338         {
339             const int64_t rhs_value0 = (int8_t)*rhs_ptr_0;
340             const int64_t lhs_value = *lhs_ptr;
341 
342             result += lhs_value * rhs_value0;
343 
344             ++rhs_ptr_0;
345             ++lhs_ptr;
346         }
347 
348         if (bias)
349         {
350             result += *bias++;
351         }
352         // Quantize down
353         result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
354 
355         // Clamp the result
356         result = ((result) > (activation_min) ? (result) : (activation_min));
357         result = ((result) < (activation_max) ? (result) : (activation_max));
358 
359         *dst++ = (int16_t)result;
360         rhs += rhs_cols;
361     }
362 #endif     // ARM_MATH_DSP
363 
364     return ARM_CMSIS_NN_SUCCESS;
365 }
366 
367 /**
368  * @} end of Doxygen group
369  */
370