/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s16
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        11 April 2024
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s16 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
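/*
 * Operand layout (row-major):
 *   lhs : [lhs_rows x rhs_cols], s16
 *   rhs : [rhs_rows x rhs_cols], s8 (the transposed right-hand side, i.e. dst = lhs * rhs^T)
 *   dst : [lhs_rows x rhs_rows], s16
 * One bias value and one multiplier/shift pair are applied per rhs row (output channel).
 */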
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s16(const int16_t *lhs,
                                             const int8_t *rhs,
                                             const cmsis_nn_bias_data *bias_data,
                                             int16_t *dst,
                                             const int32_t *dst_multipliers,
                                             const int32_t *dst_shifts,
                                             const int32_t lhs_rows,
                                             const int32_t rhs_rows,
                                             const int32_t rhs_cols,
                                             const int32_t activation_min,
                                             const int32_t activation_max)
{
#if defined(ARM_MATH_MVEI)

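    /* Byte offsets used by the scatter store: the four result lanes for one rhs row go to
     * the same column of four consecutive dst rows, each rhs_rows elements apart. */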
    const uint32_t rhs_rows_offset = (uint32_t)rhs_rows * sizeof(int16_t);
    const uint32x4_t scatter_offset = {
        0, (uint32_t)rhs_rows_offset, (uint32_t)rhs_rows_offset * 2, (uint32_t)rhs_rows_offset * 3};

    const int64_t *bias_s64 = (const int64_t *)bias_data->data;
    const int32_t *bias_s32 = (const int32_t *)bias_data->data;
    const bool is_int32_bias = bias_data->is_int32_bias;

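    /* With an s64 bias, only the first MAX_COL_COUNT columns are accumulated in the 32-bit
     * vector loop; any remaining columns are added in a 64-bit scalar tail to avoid
     * accumulator overflow. */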
    const int32_t rhs_cols_fast = is_int32_bias ? rhs_cols : (rhs_cols > MAX_COL_COUNT ? MAX_COL_COUNT : rhs_cols);
    const int32_t rhs_cols_slow = rhs_cols - MAX_COL_COUNT;

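    /* Main loop: process the lhs four rows at a time; leftover rows are handled below. */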
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int16_t *ip_row_0 = lhs;
            const int16_t *ip_row_1 = lhs + rhs_cols;
            const int16_t *ip_row_2 = lhs + (2 * rhs_cols);
            const int16_t *ip_row_3 = lhs + (3 * rhs_cols);
            const int8_t *col_base = rhs + i * rhs_cols;

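            /* Dot product of one s8 rhs row against four s16 lhs rows, accumulated in 32 bits. */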
#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols_fast; j++)
            {
                int8_t col = col_base[j];
                acc_n0 += ip_row_0[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
#else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
            __ASM volatile(" .p2align 2                            \n"
                           "   wlstp.16      lr, %[cnt], 1f        \n"
                           "   mov           %[out0], 0            \n"
                           "   mov           %[out1], 0            \n"
                           "   mov           %[out2], 0            \n"
                           "   mov           %[out3], 0            \n"
                           "   vldrb.s16     q0, [%[col]], #8      \n"
                           "2:                                     \n"
                           "   vldrh.u16     q1, [%[row0]], #16    \n"
                           "   vmlava.s16    %[out0], q0, q1       \n"
                           "   vldrh.u16     q2, [%[row1]], #16    \n"
                           "   vmlava.s16    %[out1], q0, q2       \n"
                           "   vldrh.u16     q3, [%[row2]], #16    \n"
                           "   vmlava.s16    %[out2], q0, q3       \n"
                           "   vldrh.u16     q4, [%[row3]], #16    \n"
                           "   vmlava.s16    %[out3], q0, q4       \n"
                           "   vldrb.s16     q0, [%[col]], #8      \n"
                           "   letp          lr, 2b                \n"
                           "1:                                     \n"
                           : [col] "+l"(col_base),
                             [row0] "+l"(ip_row_0),
                             [row1] "+l"(ip_row_1),
                             [row2] "+l"(ip_row_2),
                             [row3] "+l"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols_fast)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
#endif

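            /* s32 bias: add the per-channel bias, requantize the four lanes with the
             * channel's multiplier and shift, clamp, then scatter-store to dst. */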
            if (is_int32_bias)
            {
                int32x4_t result;

                if (bias_s32)
                {
                    acc_n0 += bias_s32[i];
                    acc_n1 += bias_s32[i];
                    acc_n2 += bias_s32[i];
                    acc_n3 += bias_s32[i];
                }

                int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};

                result = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);

                result = vmaxq_s32(result, vdupq_n_s32(activation_min));
                result = vminq_s32(result, vdupq_n_s32(activation_max));

                vstrhq_scatter_offset_s32(dst, scatter_offset, result);
            }
            else
            {
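                /* s64 bias: widen the accumulators and, for columns beyond MAX_COL_COUNT,
                 * finish the dot product with 64-bit scalar accumulation before adding the
                 * bias and requantizing. */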
                int64_t acc_n0_s64 = acc_n0;
                int64_t acc_n1_s64 = acc_n1;
                int64_t acc_n2_s64 = acc_n2;
                int64_t acc_n3_s64 = acc_n3;

                if (rhs_cols > MAX_COL_COUNT)
                {
                    ip_row_0 = lhs + MAX_COL_COUNT;
                    ip_row_1 = lhs + rhs_cols + MAX_COL_COUNT;
                    ip_row_2 = lhs + (2 * rhs_cols) + MAX_COL_COUNT;
                    ip_row_3 = lhs + (3 * rhs_cols) + MAX_COL_COUNT;
                    col_base = rhs + i * rhs_cols + MAX_COL_COUNT;

                    for (int j = 0; j < rhs_cols_slow; j++)
                    {
                        int8_t col = col_base[j];
                        acc_n0_s64 += ip_row_0[j] * col;
                        acc_n1_s64 += ip_row_1[j] * col;
                        acc_n2_s64 += ip_row_2[j] * col;
                        acc_n3_s64 += ip_row_3[j] * col;
                    }
                }

                if (bias_s64)
                {
                    acc_n0_s64 += bias_s64[i];
                    acc_n1_s64 += bias_s64[i];
                    acc_n2_s64 += bias_s64[i];
                    acc_n3_s64 += bias_s64[i];
                }

                int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multipliers[i]);
                int32_t shift = dst_shifts[i];

                acc_n0 = arm_nn_requantize_s64(acc_n0_s64, reduced_multiplier, shift);
                acc_n1 = arm_nn_requantize_s64(acc_n1_s64, reduced_multiplier, shift);
                acc_n2 = arm_nn_requantize_s64(acc_n2_s64, reduced_multiplier, shift);
                acc_n3 = arm_nn_requantize_s64(acc_n3_s64, reduced_multiplier, shift);
                int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};

                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));

                vstrhq_scatter_offset_s32(dst, scatter_offset, res);
            }
            dst++;
        }

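        /* Move lhs to the next block of four rows. dst already advanced one element per
         * rhs row; skip the three output rows that were written by the scatter stores. */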
        lhs += 4 * rhs_cols;
        dst += (3 * rhs_rows);
    }

    if (is_int32_bias)
    {

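        /* Leftover lhs rows (s32 bias): process one lhs row at a time, batching four rhs
         * rows per vectorized requantize and store. */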
        for (; i_items < lhs_rows; i_items++)
        {
            int32_t acc[4];
            const int32_t *multipliers = dst_multipliers;
            const int32_t *shifts = dst_shifts;
            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                const int16_t *ip_row_0 = lhs;
                const int8_t *col_base = rhs + i * rhs_cols;

#if defined(ARM_MATH_AUTOVECTORIZE)
                for (int j = 0; j < rhs_cols; j++)
                {
                    int32_t col = col_base[j];
                    acc_n0 += ip_row_0[j] * col;
                }
#else
                __ASM volatile(" .p2align 2                            \n"
                               "   wlstp.32      lr, %[cnt], 1f        \n"
                               "   mov           %[out0], 0            \n"
                               "2:                                     \n"
                               "   vldrb.s32     q0, [%[col]], #4      \n"
                               "   vldrh.s32     q1, [%[row0]], #8     \n"
                               "   vmlava.s32    %[out0], q0, q1       \n"
                               "   letp          lr, 2b                \n"
                               "1:                                     \n"
                               : [col] "+l"(col_base), [row0] "+l"(ip_row_0), [out0] "=Te"(acc_n0)
                               : [cnt] "r"(rhs_cols)
                               : "q0", "q1", "memory", "r14");
#endif
                if (bias_s32)
                {
                    acc_n0 += bias_s32[i];
                }

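                /* Gather accumulators in groups of four so that requantization, clamping
                 * and the store can be done with vector operations. */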
                const int32_t index = i & 0x3;
                acc[index] = acc_n0;

                if (index == 3)
                {
                    int32x4_t res = vldrwq_s32(acc);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                    multipliers += 4;
                    shifts += 4;
                    res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                    res = vminq_s32(res, vdupq_n_s32(activation_max));
                    vstrhq_s32(dst, res);
                    dst += 4;
                }
            }
            lhs += rhs_cols;

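            /* rhs_rows not a multiple of four: requantize and store the remaining
             * accumulators one at a time. */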
            const int32_t tail_rows = rhs_rows & 0x3;
            for (int i = 0; i < tail_rows; i++)
            {
                int32_t acc_n0 = acc[i];
                acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int16_t)acc_n0;
            }
        }
    }
    else
    {

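        /* Leftover lhs rows (s64 bias): scalar path, mirroring the four-row loop above
         * including the 64-bit tail for columns beyond MAX_COL_COUNT. */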
        for (; i_items < lhs_rows; i_items++)
        {
            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                int64_t acc_n0_s64 = 0;
                const int16_t *ip_row_0 = lhs;
                const int8_t *col_base = rhs + i * rhs_cols;

#if defined(ARM_MATH_AUTOVECTORIZE)
                for (int j = 0; j < rhs_cols_fast; j++)
                {
                    int8_t col = col_base[j];
                    acc_n0 += ip_row_0[j] * col;
                }
#else
                __ASM volatile(" .p2align 2                            \n"
                               "   wlstp.32      lr, %[cnt], 1f        \n"
                               "   mov           %[out0], 0            \n"
                               "2:                                     \n"
                               "   vldrb.s32     q0, [%[col]], #4      \n"
                               "   vldrh.s32     q1, [%[row0]], #8     \n"
                               "   vmlava.s32    %[out0], q0, q1       \n"
                               "   letp          lr, 2b                \n"
                               "1:                                     \n"
                               : [col] "+l"(col_base), [row0] "+l"(ip_row_0), [out0] "=Te"(acc_n0)
                               : [cnt] "r"(rhs_cols_fast)
                               : "q0", "q1", "memory", "r14");
#endif

                acc_n0_s64 = acc_n0;

                if (rhs_cols > MAX_COL_COUNT)
                {
                    ip_row_0 = lhs + MAX_COL_COUNT;
                    col_base = rhs + i * rhs_cols + MAX_COL_COUNT;

                    for (int j = 0; j < rhs_cols_slow; j++)
                    {
                        int8_t col = col_base[j];
                        acc_n0_s64 += ip_row_0[j] * col;
                    }
                }

                if (bias_s64)
                {
                    acc_n0_s64 += bias_s64[i];
                }

                int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multipliers[i]);
                int32_t shift = dst_shifts[i];

                acc_n0 = arm_nn_requantize_s64(acc_n0_s64, reduced_multiplier, shift);
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int16_t)acc_n0;
            }
            lhs += rhs_cols;
        }
    }

#else
    (void)lhs;
    (void)rhs;
    (void)dst_multipliers;
    (void)dst_shifts;
    (void)dst;
    (void)activation_min;
    (void)activation_max;
    (void)bias_data;
    (void)lhs_rows;
    (void)rhs_rows;
    (void)rhs_cols;

    return ARM_CMSIS_NN_NO_IMPL_ERROR;
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */