1 /*
2  * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s16
22  * Description:  s16 vector by matrix (transposed) multiplication
23  *
24  * $Date:        22 March 2024
25  * $Revision:    V.2.3.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup groupSupport
35  */
36 
37 /**
38  * @addtogroup supportFC
39  * @{
40  */
41 
42 /*
43  * s16 vector(lhs) by matrix (transposed) multiplication
44  *
45  * Refer header file for details.
46  *
47  */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
                                              const int8_t *rhs,
                                              const int64_t *bias,
                                              int16_t *dst,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max)
{
#if defined(ARM_MATH_DSP)

    /* Number of columns that may be accumulated in 32-bit SIMD accumulators.
     * Beyond MAX_COL_COUNT the running sum of s16 * s8 products could overflow
     * an int32, so any remaining columns are accumulated in 64 bits in the
     * scalar tail loops below. */
    int32_t rhs_cols_fast = rhs_cols;

    if (rhs_cols > MAX_COL_COUNT)
    {
        rhs_cols_fast = MAX_COL_COUNT;
    }

    #if defined(ARM_MATH_MVEI)
    /* MVE path: 4 output rows per outer iteration, 8 columns per vector op. */
    int32_t row_loop_cnt = rhs_rows / 4;
    int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

    for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
    {
        int32_t col_cnt = rhs_cols_fast;

        /* One lhs (vector) pointer shared by four consecutive rhs rows. */
        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = rhs;
        const int8_t *rhs_ptr_1 = rhs + rhs_cols;
        const int8_t *rhs_ptr_2 = rhs + rhs_cols * 2;
        const int8_t *rhs_ptr_3 = rhs + rhs_cols * 3;

        int32_t result_0 = 0;
        int32_t result_1 = 0;
        int32_t result_2 = 0;
        int32_t result_3 = 0;

        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            /* Tail predication: lanes beyond col_cnt load as zero, so the
             * unpredicated multiply-accumulates below add nothing for them. */
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

            /* Widening loads: each s8 weight is sign-extended to 16 bits. */
            int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
            int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
            int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
            int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

            result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
            result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
            result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
            result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

            lhs_ptr += 8;

            rhs_ptr_0 += 8;
            rhs_ptr_1 += 8;
            rhs_ptr_2 += 8;
            rhs_ptr_3 += 8;
        }

        /* Widen to 64 bits before the overflow-prone tail and bias addition. */
        int64_t result_64_0 = result_0;
        int64_t result_64_1 = result_1;
        int64_t result_64_2 = result_2;
        int64_t result_64_3 = result_3;

        if (rhs_cols > MAX_COL_COUNT)
        {
            /* Scalar 64-bit accumulation for columns past the int32-safe cap.
             * lhs_ptr/rhs_ptr_* already point at column MAX_COL_COUNT because
             * the vector loop consumed exactly rhs_cols_fast columns. */
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64_0 += *rhs_ptr_0++ * lhs_temp;
                result_64_1 += *rhs_ptr_1++ * lhs_temp;
                result_64_2 += *rhs_ptr_2++ * lhs_temp;
                result_64_3 += *rhs_ptr_3++ * lhs_temp;
            }
        }

        /* Optional per-row 64-bit bias. */
        if (bias)
        {
            result_64_0 += *bias++;
            result_64_1 += *bias++;
            result_64_2 += *bias++;
            result_64_3 += *bias++;
        }

        /* Requantize each 64-bit accumulator back to s16 range, then clamp to
         * the activation window. */
        int32_t tmp;
        tmp = arm_nn_requantize_s64(result_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_2, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = 0;
        tmp = arm_nn_requantize_s64(result_64_3, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        /* Advance past the four rows consumed this iteration. */
        rhs += 4 * rhs_cols;
    }

    /* Remaining 1-3 rows, one at a time. */
    for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
    {
        int32_t result = 0;

        col_loop_cnt = (rhs_cols_fast + 7) / 8;

        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr = rhs;

        int32_t col_cnt = (int32_t)rhs_cols_fast;

        for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
        {
            mve_pred16_t pred = vctp16q(col_cnt);
            col_cnt -= 8;

            int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
            int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr, pred);

            result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

            lhs_ptr += 8;
            rhs_ptr += 8;
        }

        int64_t result_64 = result;

        if (bias)
        {
            result_64 += *bias++;
        }

        /* 64-bit scalar tail for columns beyond the int32-safe cap. */
        if (rhs_cols > MAX_COL_COUNT)
        {
            for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
            {
                const int16_t lhs_temp = *lhs_ptr++;

                result_64 += *rhs_ptr++ * lhs_temp;
            }
        }

        int32_t tmp = 0;
        tmp = arm_nn_requantize_s64(result_64, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        rhs += rhs_cols;
    }

    #else // ARM_MATH_MVEI

    /* DSP (non-MVE) path: 2 output rows per outer iteration, 4 columns per
     * inner iteration via packed SMLAD multiply-accumulates. */
    const int32_t row_loop_cnt = rhs_rows / 2;

    for (int32_t i = 0; i < row_loop_cnt; i++)
    {

        int64_t acc_64_0 = 0;
        int64_t acc_64_1 = 0;
        int32_t acc_0 = 0;
        int32_t acc_1 = 0;

        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        const int8_t *rhs_1 = rhs + rhs_cols;
        rhs += 2 * rhs_cols;

        for (int j = col_loop_cnt; j != 0; j--)
        {
            int32_t ker_0, ker_1, vec_part_0, vec_part_1;

            /* Two packed s16 pairs (4 lhs values). */
            vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
            vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

            /* read_and_pad expands 4 s8 weights into two words of packed
             * sign-extended s16 pairs, matching the lhs layout for SMLAD. */
            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
            acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

            rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

            acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
            acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
        }

        acc_64_0 += acc_0;
        acc_64_1 += acc_1;

        /* Scalar 64-bit tail: covers both the mod-4 remainder of the fast
         * region and all columns past MAX_COL_COUNT. */
        for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
            acc_64_1 += lhs_temp * (*rhs_1);
            rhs_1++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
            acc_64_1 += *bias++;
        }
        int32_t tmp;

        /* Requantize and clamp each of the two rows' results. */
        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;

        tmp = arm_nn_requantize_s64(acc_64_1, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

    /* Odd final row, if any. */
    if (rhs_rows & 0x1)
    {
        int64_t acc_64_0 = 0;
        int32_t acc_0 = 0;
        const int32_t col_loop_cnt = rhs_cols_fast / 4;

        const int16_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;

        for (int i = col_loop_cnt; i != 0; i--)
        {
            int32_t ker_0, ker_1, vec;
            rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_0, vec, acc_0);

            vec = arm_nn_read_q15x2_ia(&lhs_vec);
            acc_0 = SMLAD(ker_1, vec, acc_0);
        }

        acc_64_0 += acc_0;

        /* Scalar 64-bit tail (mod-4 remainder plus any capped columns). */
        for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
        {
            const int32_t lhs_temp = (*lhs_vec);
            lhs_vec++;
            acc_64_0 += lhs_temp * (*rhs_0);
            rhs_0++;
        }

        if (bias)
        {
            acc_64_0 += *bias++;
        }
        int32_t tmp;
        tmp = arm_nn_requantize_s64(acc_64_0, dst_multiplier, dst_shift);
        tmp = MAX(tmp, activation_min);
        tmp = MIN(tmp, activation_max);
        *dst++ = (int16_t)tmp;
    }

    #endif // ARM_MATH_MVEI
#else      // ARM_MATH_DSP
    /* Pure-C fallback: one row at a time, 64-bit accumulation throughout, so
     * no column-count cap is needed. */
    for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
    {
        const int16_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = &rhs[0];

        int64_t result = 0;

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            const int64_t rhs_value0 = (int8_t)*rhs_ptr_0;
            const int64_t lhs_value = *lhs_ptr;

            result += lhs_value * rhs_value0;

            ++rhs_ptr_0;
            ++lhs_ptr;
        }

        if (bias)
        {
            result += *bias++;
        }
        // Quantize down
        result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);

        // Clamp the result
        result = ((result) > (activation_min) ? (result) : (activation_min));
        result = ((result) < (activation_max) ? (result) : (activation_max));

        *dst++ = (int16_t)result;
        rhs += rhs_cols;
    }
#endif     // ARM_MATH_DSP

    return ARM_CMSIS_NN_SUCCESS;
}
364 
365 /**
366  * @} end of Doxygen group
367  */
368