1 /*
2  * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s16_s16
22  * Description:  s16 vector by s16 matrix (transposed) multiplication
23  *
24  * $Date:        19 June 2024
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 /**
33  * @ingroup groupSupport
34  */
35 
36 /**
37  * @addtogroup supportFC
38  * @{
39  */
40 
41 /*
42  * s16 vector(lhs) by s16 matrix (transposed) multiplication
43  *
44  * Refer header file for details.
45  *
46  */
arm_nn_vec_mat_mult_t_s16_s16(const int16_t * lhs,const int16_t * rhs,const int64_t * bias,int16_t * dst,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)47 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16_s16(const int16_t *lhs,
48                                                   const int16_t *rhs,
49                                                   const int64_t *bias,
50                                                   int16_t *dst,
51                                                   const int32_t dst_multiplier,
52                                                   const int32_t dst_shift,
53                                                   const int32_t rhs_cols,
54                                                   const int32_t rhs_rows,
55                                                   const int32_t activation_min,
56                                                   const int32_t activation_max)
57 {
58 #if defined(ARM_MATH_DSP)
59 
60     #if defined(ARM_MATH_MVEI)
61     int32_t row_loop_cnt = rhs_rows / 4;
62     int32_t col_loop_cnt = (rhs_cols + 7) / 8;
63 
64     for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
65     {
66         int32_t col_cnt = rhs_cols;
67 
68         const int16_t *lhs_ptr = lhs;
69         const int16_t *rhs_ptr_0 = rhs;
70         const int16_t *rhs_ptr_1 = rhs + rhs_cols;
71         const int16_t *rhs_ptr_2 = rhs + rhs_cols * 2;
72         const int16_t *rhs_ptr_3 = rhs + rhs_cols * 3;
73 
74         int64_t result_0 = 0;
75         int64_t result_1 = 0;
76         int64_t result_2 = 0;
77         int64_t result_3 = 0;
78 
79         for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
80         {
81             mve_pred16_t pred = vctp16q(col_cnt);
82             col_cnt -= 8;
83 
84             int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
85 
86             int16x8_t rhs_input_0 = vldrhq_z_s16(rhs_ptr_0, pred);
87             int16x8_t rhs_input_1 = vldrhq_z_s16(rhs_ptr_1, pred);
88             int16x8_t rhs_input_2 = vldrhq_z_s16(rhs_ptr_2, pred);
89             int16x8_t rhs_input_3 = vldrhq_z_s16(rhs_ptr_3, pred);
90 
91             result_0 = vmlaldavaq_s16(result_0, lhs_input, rhs_input_0);
92             result_1 = vmlaldavaq_s16(result_1, lhs_input, rhs_input_1);
93             result_2 = vmlaldavaq_s16(result_2, lhs_input, rhs_input_2);
94             result_3 = vmlaldavaq_s16(result_3, lhs_input, rhs_input_3);
95 
96             lhs_ptr += 8;
97 
98             rhs_ptr_0 += 8;
99             rhs_ptr_1 += 8;
100             rhs_ptr_2 += 8;
101             rhs_ptr_3 += 8;
102         }
103 
104         if (bias)
105         {
106             result_0 += *bias++;
107             result_1 += *bias++;
108             result_2 += *bias++;
109             result_3 += *bias++;
110         }
111 
112         int32_t tmp;
113         tmp = arm_nn_requantize_s64(result_0, dst_multiplier, dst_shift);
114         tmp = MAX(tmp, activation_min);
115         tmp = MIN(tmp, activation_max);
116         *dst++ = (int16_t)tmp;
117 
118         tmp = 0;
119         tmp = arm_nn_requantize_s64(result_1, dst_multiplier, dst_shift);
120         tmp = MAX(tmp, activation_min);
121         tmp = MIN(tmp, activation_max);
122         *dst++ = (int16_t)tmp;
123 
124         tmp = 0;
125         tmp = arm_nn_requantize_s64(result_2, dst_multiplier, dst_shift);
126         tmp = MAX(tmp, activation_min);
127         tmp = MIN(tmp, activation_max);
128         *dst++ = (int16_t)tmp;
129 
130         tmp = 0;
131         tmp = arm_nn_requantize_s64(result_3, dst_multiplier, dst_shift);
132         tmp = MAX(tmp, activation_min);
133         tmp = MIN(tmp, activation_max);
134         *dst++ = (int16_t)tmp;
135 
136         rhs += 4 * rhs_cols;
137     }
138 
139     for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
140     {
141         int64_t result = 0;
142 
143         col_loop_cnt = (rhs_cols + 7) / 8;
144 
145         const int16_t *lhs_ptr = lhs;
146         const int16_t *rhs_ptr = rhs;
147 
148         int32_t col_cnt = rhs_cols;
149 
150         for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
151         {
152             mve_pred16_t pred = vctp16q(col_cnt);
153             col_cnt -= 8;
154 
155             int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
156             int16x8_t rhs_input = vldrhq_z_s16(rhs_ptr, pred);
157 
158             result = vmlaldavaq_p_s16(result, lhs_input, rhs_input, pred);
159 
160             lhs_ptr += 8;
161             rhs_ptr += 8;
162         }
163 
164         if (bias)
165         {
166             result += *bias++;
167         }
168 
169         int32_t tmp = 0;
170         tmp = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
171         tmp = MAX(tmp, activation_min);
172         tmp = MIN(tmp, activation_max);
173         *dst++ = (int16_t)tmp;
174 
175         rhs += rhs_cols;
176     }
177 
178     #else // ARM_MATH_MVEI
179 
180     const int32_t row_loop_cnt = rhs_rows / 2;
181 
182     for (int32_t i = 0; i < row_loop_cnt; i++)
183     {
184 
185         int64_t acc_0 = 0;
186         int64_t acc_1 = 0;
187 
188         const int32_t col_loop_cnt = rhs_cols / 4;
189 
190         const int16_t *lhs_vec = lhs;
191         const int16_t *rhs_0 = rhs;
192         const int16_t *rhs_1 = rhs + rhs_cols;
193         rhs += 2 * rhs_cols;
194 
195         for (int j = col_loop_cnt; j != 0; j--)
196         {
197             int32_t ker_0, ker_1, vec_part_0, vec_part_1;
198 
199             vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
200             vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);
201 
202             ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
203             ker_1 = arm_nn_read_q15x2_ia(&rhs_0);
204 
205             acc_0 = SMLALD(ker_0, vec_part_0, acc_0);
206             acc_0 = SMLALD(ker_1, vec_part_1, acc_0);
207 
208             ker_0 = arm_nn_read_q15x2_ia(&rhs_1);
209             ker_1 = arm_nn_read_q15x2_ia(&rhs_1);
210 
211             acc_1 = SMLALD(ker_0, vec_part_0, acc_1);
212             acc_1 = SMLALD(ker_1, vec_part_1, acc_1);
213         }
214 
215         for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
216         {
217             const int16_t lhs_temp = (*lhs_vec);
218             lhs_vec++;
219             acc_0 += lhs_temp * (*rhs_0);
220             rhs_0++;
221             acc_1 += lhs_temp * (*rhs_1);
222             rhs_1++;
223         }
224 
225         if (bias)
226         {
227             acc_0 += *bias++;
228             acc_1 += *bias++;
229         }
230         int32_t tmp;
231 
232         tmp = arm_nn_requantize_s64(acc_0, dst_multiplier, dst_shift);
233         tmp = MAX(tmp, activation_min);
234         tmp = MIN(tmp, activation_max);
235         *dst++ = (int16_t)tmp;
236 
237         tmp = arm_nn_requantize_s64(acc_1, dst_multiplier, dst_shift);
238         tmp = MAX(tmp, activation_min);
239         tmp = MIN(tmp, activation_max);
240         *dst++ = (int16_t)tmp;
241     }
242 
243     if (rhs_rows & 0x1)
244     {
245         int64_t acc_0 = 0;
246         const int32_t col_loop_cnt = rhs_cols / 4;
247 
248         const int16_t *lhs_vec = lhs;
249         const int16_t *rhs_0 = rhs;
250 
251         for (int i = col_loop_cnt; i != 0; i--)
252         {
253             int32_t ker_0, vec;
254 
255             ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
256             vec = arm_nn_read_q15x2_ia(&lhs_vec);
257             acc_0 = SMLALD(ker_0, vec, acc_0);
258 
259             ker_0 = arm_nn_read_q15x2_ia(&rhs_0);
260             vec = arm_nn_read_q15x2_ia(&lhs_vec);
261             acc_0 = SMLALD(ker_0, vec, acc_0);
262         }
263 
264         for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
265         {
266             const int16_t lhs_temp = (*lhs_vec);
267             lhs_vec++;
268             acc_0 += lhs_temp * (*rhs_0);
269             rhs_0++;
270         }
271 
272         if (bias)
273         {
274             acc_0 += *bias++;
275         }
276         int32_t tmp;
277         tmp = arm_nn_requantize_s64(acc_0, dst_multiplier, dst_shift);
278         tmp = MAX(tmp, activation_min);
279         tmp = MIN(tmp, activation_max);
280         *dst++ = (int16_t)tmp;
281     }
282 
283     #endif // ARM_MATH_MVEI
284 #else      // ARM_MATH_DSP
285     for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
286     {
287         const int16_t *lhs_ptr = lhs;
288         const int16_t *rhs_ptr_0 = &rhs[0];
289 
290         int64_t result = 0;
291 
292         for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
293         {
294             const int64_t rhs_value0 = *rhs_ptr_0;
295             const int64_t lhs_value = *lhs_ptr;
296 
297             result += lhs_value * rhs_value0;
298 
299             ++rhs_ptr_0;
300             ++lhs_ptr;
301         }
302 
303         if (bias)
304         {
305             result += *bias++;
306         }
307         // Quantize down
308         result = arm_nn_requantize_s64(result, dst_multiplier, dst_shift);
309 
310         // Clamp the result
311         result = MAX(result, activation_min);
312         result = MIN(result, activation_max);
313 
314         *dst++ = (int16_t)result;
315         rhs += rhs_cols;
316     }
317 #endif     // ARM_MATH_DSP
318 
319     return ARM_CMSIS_NN_SUCCESS;
320 }
321 
322 /**
323  * @} end of Doxygen group
324  */
325