/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s16
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        11 April 2024
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s16 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
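/*
 * Illustrative call sketch (not part of the library). The buffer sizes and
 * quantization parameters below are hypothetical placeholders; dst must hold
 * lhs_rows * rhs_rows int16 values and the multiplier/shift arrays are
 * per output channel (per rhs row):
 *
 *   int16_t lhs[LHS_ROWS * RHS_COLS];   // activations
 *   int8_t rhs[RHS_ROWS * RHS_COLS];    // weights, stored transposed
 *   int64_t bias[RHS_ROWS];
 *   int32_t mult[RHS_ROWS];
 *   int32_t shift[RHS_ROWS];
 *   int16_t dst[LHS_ROWS * RHS_ROWS];
 *
 *   const cmsis_nn_bias_data bias_data = {.data = bias, .is_int32_bias = false};
 *   arm_cmsis_nn_status status = arm_nn_mat_mult_nt_t_s16(
 *       lhs, rhs, &bias_data, dst, mult, shift, LHS_ROWS, RHS_ROWS, RHS_COLS, INT16_MIN, INT16_MAX);
 */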
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s16(const int16_t *lhs,
                                             const int8_t *rhs,
                                             const cmsis_nn_bias_data *bias_data,
                                             int16_t *dst,
                                             const int32_t *dst_multipliers,
                                             const int32_t *dst_shifts,
                                             const int32_t lhs_rows,
                                             const int32_t rhs_rows,
                                             const int32_t rhs_cols,
                                             const int32_t activation_min,
                                             const int32_t activation_max)
{
#if defined(ARM_MATH_MVEI)

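    /* Byte offsets used to scatter the four per-row results for one output
     * channel into dst; consecutive output rows are rhs_rows int16 elements
     * apart. */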
    const uint32_t rhs_rows_offset = (uint32_t)rhs_rows * sizeof(int16_t);
    const uint32x4_t scatter_offset = {
        0, (uint32_t)rhs_rows_offset, (uint32_t)rhs_rows_offset * 2, (uint32_t)rhs_rows_offset * 3};

    const int64_t *bias_s64 = (const int64_t *)bias_data->data;
    const int32_t *bias_s32 = (const int32_t *)bias_data->data;
    const bool is_int32_bias = bias_data->is_int32_bias;

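    /* With an int64 bias, the 32-bit accumulators are only used for the first
     * MAX_COL_COUNT columns; any remaining columns are accumulated in 64-bit
     * scalar code to avoid overflowing int32. With an int32 bias the full
     * column count goes through the 32-bit path. */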
    const int32_t rhs_cols_fast = is_int32_bias ? rhs_cols : (rhs_cols > MAX_COL_COUNT ? MAX_COL_COUNT : rhs_cols);
    const int32_t rhs_cols_slow = rhs_cols - MAX_COL_COUNT;

    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int16_t *ip_row_0 = lhs;
            const int16_t *ip_row_1 = lhs + rhs_cols;
            const int16_t *ip_row_2 = lhs + (2 * rhs_cols);
            const int16_t *ip_row_3 = lhs + (3 * rhs_cols);
            const int8_t *col_base = rhs + i * rhs_cols;

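            /* Dot product of rhs row i with four consecutive lhs rows, using
             * either the auto-vectorized C loop or the MVE low-overhead-loop
             * assembly below. */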
    #if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols_fast; j++)
            {
                int8_t col = col_base[j];
                acc_n0 += ip_row_0[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
    #else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
            __ASM volatile(" .p2align 2                              \n"
                           "   wlstp.16        lr, %[cnt], 1f        \n"
                           "   mov             %[out0], 0            \n"
                           "   mov             %[out1], 0            \n"
                           "   mov             %[out2], 0            \n"
                           "   mov             %[out3], 0            \n"
                           "   vldrb.s16       q0, [%[col]], #8      \n"
                           "2:                                       \n"
                           "   vldrh.u16       q1, [%[row0]], #16     \n"
                           "   vmlava.s16      %[out0], q0, q1       \n"
                           "   vldrh.u16       q2, [%[row1]], #16     \n"
                           "   vmlava.s16      %[out1], q0, q2       \n"
                           "   vldrh.u16       q3, [%[row2]], #16     \n"
                           "   vmlava.s16      %[out2], q0, q3       \n"
                           "   vldrh.u16       q4, [%[row3]], #16     \n"
                           "   vmlava.s16      %[out3], q0, q4       \n"
                           "   vldrb.s16       q0, [%[col]], #8      \n"
                           "   letp            lr, 2b                \n"
                           "1:                                       \n"
                           : [col] "+l"(col_base),
                             [row0] "+l"(ip_row_0),
                             [row1] "+l"(ip_row_1),
                             [row2] "+l"(ip_row_2),
                             [row3] "+l"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols_fast)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
    #endif

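            /* int32 bias: add the bias and requantize directly from the 32-bit
             * accumulators. int64 bias: widen to 64-bit, finish any columns
             * beyond MAX_COL_COUNT, add the bias and requantize from 64-bit. */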
            if (is_int32_bias)
            {
                int32x4_t result;

                if (bias_s32)
                {
                    acc_n0 += bias_s32[i];
                    acc_n1 += bias_s32[i];
                    acc_n2 += bias_s32[i];
                    acc_n3 += bias_s32[i];
                }

                int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};

                result = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);

                result = vmaxq_s32(result, vdupq_n_s32(activation_min));
                result = vminq_s32(result, vdupq_n_s32(activation_max));

                vstrhq_scatter_offset_s32(dst, scatter_offset, result);
            }
            else
            {
                int64_t acc_n0_s64 = acc_n0;
                int64_t acc_n1_s64 = acc_n1;
                int64_t acc_n2_s64 = acc_n2;
                int64_t acc_n3_s64 = acc_n3;

                if (rhs_cols > MAX_COL_COUNT)
                {
                    ip_row_0 = lhs + MAX_COL_COUNT;
                    ip_row_1 = lhs + rhs_cols + MAX_COL_COUNT;
                    ip_row_2 = lhs + (2 * rhs_cols) + MAX_COL_COUNT;
                    ip_row_3 = lhs + (3 * rhs_cols) + MAX_COL_COUNT;
                    col_base = rhs + i * rhs_cols + MAX_COL_COUNT;

                    for (int j = 0; j < rhs_cols_slow; j++)
                    {
                        int8_t col = col_base[j];
                        acc_n0_s64 += ip_row_0[j] * col;
                        acc_n1_s64 += ip_row_1[j] * col;
                        acc_n2_s64 += ip_row_2[j] * col;
                        acc_n3_s64 += ip_row_3[j] * col;
                    }
                }

                if (bias_s64)
                {
                    acc_n0_s64 += bias_s64[i];
                    acc_n1_s64 += bias_s64[i];
                    acc_n2_s64 += bias_s64[i];
                    acc_n3_s64 += bias_s64[i];
                }

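                /* Requantize the 64-bit accumulators; the 32-bit multiplier is
                 * first reduced to the 16-bit precision expected by
                 * arm_nn_requantize_s64. */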
                int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multipliers[i]);
                int32_t shift = dst_shifts[i];

                acc_n0 = arm_nn_requantize_s64(acc_n0_s64, reduced_multiplier, shift);
                acc_n1 = arm_nn_requantize_s64(acc_n1_s64, reduced_multiplier, shift);
                acc_n2 = arm_nn_requantize_s64(acc_n2_s64, reduced_multiplier, shift);
                acc_n3 = arm_nn_requantize_s64(acc_n3_s64, reduced_multiplier, shift);
                int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};

                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));

                vstrhq_scatter_offset_s32(dst, scatter_offset, res);
            }
            dst++;
        }

        lhs += 4 * rhs_cols;
        dst += (3 * rhs_rows);
    }

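    /* Handle any remaining lhs rows (lhs_rows % 4) one row at a time. */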
    if (is_int32_bias)
    {

        for (; i_items < lhs_rows; i_items++)
        {
            int32_t acc[4];
            const int32_t *multipliers = dst_multipliers;
            const int32_t *shifts = dst_shifts;
            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                const int16_t *ip_row_0 = lhs;
                const int8_t *col_base = rhs + i * rhs_cols;

    #if defined(ARM_MATH_AUTOVECTORIZE)
                for (int j = 0; j < rhs_cols; j++)
                {
                    int32_t col = col_base[j];
                    acc_n0 += ip_row_0[j] * col;
                }
    #else
                __ASM volatile(" .p2align 2                              \n"
                               "   wlstp.32        lr, %[cnt], 1f        \n"
                               "   mov             %[out0], 0            \n"
                               "2:                                       \n"
                               "   vldrb.s32       q0, [%[col]], #4      \n"
                               "   vldrh.s32       q1, [%[row0]], #8     \n"
                               "   vmlava.s32      %[out0], q0, q1       \n"
                               "   letp            lr, 2b                \n"
                               "1:                                       \n"
                               : [col] "+l"(col_base), [row0] "+l"(ip_row_0), [out0] "=Te"(acc_n0)
                               : [cnt] "r"(rhs_cols)
                               : "q0", "q1", "memory", "r14");
    #endif
                if (bias_s32)
                {
                    acc_n0 += bias_s32[i];
                }

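                /* Collect four per-channel results so requantization and the
                 * activation clamp can run on a full vector; leftover channels
                 * (rhs_rows % 4) are handled by the scalar tail loop below. */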
                const int32_t index = i & 0x3;
                acc[index] = acc_n0;

                if (index == 3)
                {
                    int32x4_t res = vldrwq_s32(acc);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                    multipliers += 4;
                    shifts += 4;
                    res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                    res = vminq_s32(res, vdupq_n_s32(activation_max));
                    vstrhq_s32(dst, res);
                    dst += 4;
                }
            }
            lhs += rhs_cols;

            const int32_t tail_rows = rhs_rows & 0x3;
            for (int i = 0; i < tail_rows; i++)
            {
                int32_t acc_n0 = acc[i];
                acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int16_t)acc_n0;
            }
        }
    }
    else
    {

        for (; i_items < lhs_rows; i_items++)
        {
            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                int64_t acc_n0_s64 = 0;
                const int16_t *ip_row_0 = lhs;
                const int8_t *col_base = rhs + i * rhs_cols;

    #if defined(ARM_MATH_AUTOVECTORIZE)
                for (int j = 0; j < rhs_cols_fast; j++)
                {
                    int8_t col = col_base[j];
                    acc_n0 += ip_row_0[j] * col;
                }
    #else
                __ASM volatile(" .p2align 2                              \n"
                               "   wlstp.32        lr, %[cnt], 1f        \n"
                               "   mov             %[out0], 0            \n"
                               "2:                                       \n"
                               "   vldrb.s32       q0, [%[col]], #4      \n"
                               "   vldrh.s32       q1, [%[row0]], #8     \n"
                               "   vmlava.s32      %[out0], q0, q1       \n"
                               "   letp            lr, 2b                \n"
                               "1:                                       \n"
                               : [col] "+l"(col_base), [row0] "+l"(ip_row_0), [out0] "=Te"(acc_n0)
                               : [cnt] "r"(rhs_cols_fast)
                               : "q0", "q1", "memory", "r14");
    #endif

                acc_n0_s64 = acc_n0;

                if (rhs_cols > MAX_COL_COUNT)
                {
                    ip_row_0 = lhs + MAX_COL_COUNT;
                    col_base = rhs + i * rhs_cols + MAX_COL_COUNT;

                    for (int j = 0; j < rhs_cols_slow; j++)
                    {
                        int8_t col = col_base[j];
                        acc_n0_s64 += ip_row_0[j] * col;
                    }
                }

                if (bias_s64)
                {
                    acc_n0_s64 += bias_s64[i];
                }

                int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multipliers[i]);
                int32_t shift = dst_shifts[i];

                acc_n0 = arm_nn_requantize_s64(acc_n0_s64, reduced_multiplier, shift);
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int16_t)acc_n0;
            }
            lhs += rhs_cols;
        }
    }

#else
    (void)lhs;
    (void)rhs;
    (void)dst_multipliers;
    (void)dst_shifts;
    (void)dst;
    (void)activation_min;
    (void)activation_max;
    (void)bias_data;
    (void)lhs_rows;
    (void)rhs_rows;
    (void)rhs_cols;

    return ARM_CMSIS_NN_NO_IMPL_ERROR;
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */