1 /*
2 * SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_vec_mat_mult_t_svdf_s8
22 * Description: s8 vector by matrix (transposed) multiplication with
23 * s16 output. Targeted at the SVDF operator.
24 *
25 * $Date: 28 March 2023
26 * $Revision: V.3.2.0
27 *
28 * Target : Arm(R) M-Profile Architecture
29 *
30 * -------------------------------------------------------------------- */
31
32 #include "arm_nnsupportfunctions.h"
33
34 /**
35 * @ingroup groupSupport
36 */
37
38 /**
39 * @addtogroup supportFC
40 * @{
41 */
42
43 /*
44 * s8 vector(lhs) by matrix (transposed) multiplication
45 *
46 * Refer to the header file for details.
47 *
48 */
arm_nn_vec_mat_mult_t_svdf_s8(const int8_t * lhs,const int8_t * rhs,int16_t * dst,const int32_t lhs_offset,const int32_t dst_offset,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)49 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
50 const int8_t *rhs,
51 int16_t *dst,
52 const int32_t lhs_offset,
53 const int32_t dst_offset,
54 const int32_t dst_multiplier,
55 const int32_t dst_shift,
56 const int32_t rhs_cols,
57 const int32_t rhs_rows,
58 const int32_t activation_min,
59 const int32_t activation_max)
60 {
61 if (rhs_cols < 0 || (NN_Q31_MAX - rhs_cols) < 16 || dst_offset < 0)
62 {
63 return ARM_CMSIS_NN_ARG_ERROR;
64 }
65
66 #if defined(ARM_MATH_MVEI)
67 int32_t row_loop_cnt = rhs_rows / 3;
68
69 for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
70 {
71 int32_t acc_0 = 0;
72 int32_t acc_1 = 0;
73 int32_t acc_2 = 0;
74
75 const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
76
77 const int8_t *lhs_vec = lhs;
78 const int8_t *rhs_0 = rhs;
79 const int8_t *rhs_1 = rhs + rhs_cols;
80 const int8_t *rhs_2 = rhs + 2 * rhs_cols;
81
82 int32_t rhs_sum_0 = 0;
83 int32_t rhs_sum_1 = 0;
84 int32_t rhs_sum_2 = 0;
85
86 uint32_t col_cnt = (uint32_t)rhs_cols;
87
88 for (int i = 0; i < col_loop_cnt; i++)
89 {
90 mve_pred16_t p = vctp8q(col_cnt);
91 col_cnt -= 16;
92
93 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
94
95 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
96 rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
97 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
98
99 const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
100 rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_1);
101 acc_1 = vmladavaq_s8(acc_1, ker_1, input);
102
103 const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
104 rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_2);
105 acc_2 = vmladavaq_s8(acc_2, ker_2, input);
106
107 lhs_vec += 16;
108 rhs_0 += 16;
109 rhs_1 += 16;
110 rhs_2 += 16;
111 }
112 rhs += 3 * rhs_cols;
113
114 int32x4_t acc = {acc_0, acc_1, acc_2, 0};
115 const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, 0};
116 acc += vdupq_n_s32(lhs_offset) * rhs_sum;
117
118 acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
119 acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
120 acc = vminq_s32(acc, vdupq_n_s32(activation_max));
121 *(dst) = (int16_t)acc[0];
122 *(dst + dst_offset) = (int16_t)acc[1];
123 *(dst + 2 * dst_offset) = (int16_t)acc[2];
124 dst += 3 * dst_offset;
125 }
126
127 const int loop_cnt = rhs_rows % 3;
128 for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
129 {
130 int32_t acc_0 = 0;
131 const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
132 const int8_t *lhs_vec = lhs;
133 const int8_t *rhs_0 = rhs;
134 int32_t rhs_sum_0 = 0;
135 uint32_t col_cnt = (uint32_t)rhs_cols;
136
137 for (int i = 0; i < col_loop_cnt; i++)
138 {
139 mve_pred16_t p = vctp8q(col_cnt);
140 col_cnt -= 16;
141 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
142
143 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
144 rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
145 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
146
147 lhs_vec += 16;
148 rhs_0 += 16;
149 }
150 rhs += rhs_cols;
151
152 const int32_t offsets = rhs_sum_0 * lhs_offset;
153 acc_0 = QADD(acc_0, offsets);
154 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
155
156 // Clamp the result
157 acc_0 = MAX(acc_0, activation_min);
158 *dst = (int16_t)MIN(acc_0, activation_max);
159 dst += dst_offset;
160 }
161
162 #elif defined(ARM_MATH_DSP)
163 int32_t row_loop_cnt = rhs_rows / 2;
164
165 const int16_t lhs_offset_s16 = lhs_offset;
166
167 const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
168 for (int32_t i = 0; i < row_loop_cnt; i++)
169 {
170 int32_t acc_0 = 0;
171 int32_t acc_1 = 0;
172
173 const int8_t *lhs_vec = lhs;
174 const int8_t *rhs_0 = rhs;
175 const int8_t *rhs_1 = rhs + rhs_cols;
176 rhs += 2 * rhs_cols;
177
178 int32_t rhs_cols_idx = 0;
179
180 int32_t vec_0, vec_1, ker_0, ker_1;
181
182 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
183 #pragma clang loop unroll(disable)
184 #endif
185 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
186 {
187 // 4 x MAC acc_0, acc1
188 vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
189 vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
190 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
191 ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
192 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
193 ker_0 = SXTB16(ker_0);
194 acc_0 = SMLAD(ker_1, vec_1, acc_0);
195 acc_0 = SMLAD(ker_0, vec_0, acc_0);
196 ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
197 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
198 ker_0 = SXTB16(ker_0);
199 acc_1 = SMLAD(ker_1, vec_1, acc_1);
200 acc_1 = SMLAD(ker_0, vec_0, acc_1);
201
202 // 4 x MAC acc_0, acc1
203 vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
204 vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
205 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
206 ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
207 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
208 ker_0 = SXTB16(ker_0);
209 acc_0 = SMLAD(ker_1, vec_1, acc_0);
210 acc_0 = SMLAD(ker_0, vec_0, acc_0);
211 ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
212 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
213 ker_0 = SXTB16(ker_0);
214 acc_1 = SMLAD(ker_1, vec_1, acc_1);
215 acc_1 = SMLAD(ker_0, vec_0, acc_1);
216
217 // 4 x MAC acc_0, acc1
218 vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
219 vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
220 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
221 ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
222 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
223 ker_0 = SXTB16(ker_0);
224 acc_0 = SMLAD(ker_1, vec_1, acc_0);
225 acc_0 = SMLAD(ker_0, vec_0, acc_0);
226 ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
227 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
228 ker_0 = SXTB16(ker_0);
229 acc_1 = SMLAD(ker_1, vec_1, acc_1);
230 acc_1 = SMLAD(ker_0, vec_0, acc_1);
231
232 // 4 x MAC acc_0, acc1
233 vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
234 vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
235 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
236 ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
237 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
238 ker_0 = SXTB16(ker_0);
239 acc_0 = SMLAD(ker_1, vec_1, acc_0);
240 acc_0 = SMLAD(ker_0, vec_0, acc_0);
241 ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
242 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
243 ker_0 = SXTB16(ker_0);
244 acc_1 = SMLAD(ker_1, vec_1, acc_1);
245 acc_1 = SMLAD(ker_0, vec_0, acc_1);
246 }
247
248 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
249 #pragma clang loop unroll(disable)
250 #endif
251 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
252 {
253 vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
254 vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
255
256 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
257
258 ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
259 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
260 ker_0 = SXTB16(ker_0);
261
262 acc_0 = SMLAD(ker_1, vec_1, acc_0);
263 acc_0 = SMLAD(ker_0, vec_0, acc_0);
264
265 ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
266 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
267 ker_0 = SXTB16(ker_0);
268
269 acc_1 = SMLAD(ker_1, vec_1, acc_1);
270 acc_1 = SMLAD(ker_0, vec_0, acc_1);
271 }
272
273 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
274 #pragma clang loop unroll(disable)
275 #endif
276 for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
277 {
278 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
279 lhs_vec++;
280 acc_0 += lhs_temp * (*rhs_0);
281 rhs_0++;
282 acc_1 += lhs_temp * (*rhs_1);
283 rhs_1++;
284 }
285
286 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
287 acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
288
289 // Clamp the result
290 acc_0 = MAX(acc_0, activation_min);
291 acc_0 = MIN(acc_0, activation_max);
292 acc_1 = MAX(acc_1, activation_min);
293 acc_1 = MIN(acc_1, activation_max);
294 *dst = (int16_t)acc_0;
295 *(dst + dst_offset) = (int16_t)acc_1;
296 dst += 2 * dst_offset;
297 }
298 if (rhs_rows & 0x1)
299 {
300 int32_t acc_0 = 0;
301 const int32_t col_loop_cnt = rhs_cols / 4;
302 const int8_t *lhs_vec = lhs;
303 const int8_t *rhs_0 = rhs;
304 for (int i = col_loop_cnt; i != 0; i--)
305 {
306 int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
307 int32_t vec_1 = SXTAB16(lhs_offset_s16x2, ROR((uint32_t)vec_0, 8));
308 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
309
310 int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
311 int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
312 ker_0 = SXTB16(ker_0);
313
314 acc_0 = SMLAD(ker_1, vec_1, acc_0);
315 acc_0 = SMLAD(ker_0, vec_0, acc_0);
316 }
317 for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
318 {
319 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
320 lhs_vec++;
321 acc_0 += lhs_temp * *rhs_0;
322 rhs_0++;
323 }
324 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
325
326 // Clamp the result
327 acc_0 = MAX(acc_0, activation_min);
328 acc_0 = MIN(acc_0, activation_max);
329 *dst = (int16_t)acc_0;
330 dst += dst_offset;
331 }
332
333 #else
334
335 int32_t row_loop_cnt = rhs_rows / 3;
336
337 for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
338 {
339 const int8_t *lhs_ptr = lhs;
340 const int8_t *rhs_ptr_0 = &rhs[0];
341 const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
342 const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
343
344 int32_t res00 = 0;
345 int32_t res01 = 0;
346 int32_t res02 = 0;
347 for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
348 {
349 const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
350 const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
351 const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
352 const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
353
354 res00 += lhs_value * rhs_value0;
355 res01 += lhs_value * rhs_value1;
356 res02 += lhs_value * rhs_value2;
357
358 ++rhs_ptr_0;
359 ++rhs_ptr_1;
360 ++rhs_ptr_2;
361 ++lhs_ptr;
362 }
363 // Quantize down
364 res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
365 res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
366 res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
367
368 // Clamp the result
369 res00 = MAX(res00, activation_min);
370 res00 = MIN(res00, activation_max);
371 res01 = MAX(res01, activation_min);
372 res01 = MIN(res01, activation_max);
373 res02 = MAX(res02, activation_min);
374 res02 = MIN(res02, activation_max);
375
376 *dst = (int16_t)res00;
377 *(dst + dst_offset) = (int16_t)res01;
378 *(dst + 2 * dst_offset) = (int16_t)res02;
379 dst += 3 * dst_offset;
380 rhs += 3 * rhs_cols;
381 }
382
383 const int loop_cnt = rhs_rows % 3;
384
385 for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
386 {
387 const int8_t *lhs_ptr = &lhs[0];
388 const int8_t *rhs_ptr = &rhs[0];
389
390 int32_t res00 = 0;
391
392 for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
393 {
394 int32_t rhs_value0 = (int8_t)rhs_ptr[0];
395 int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
396
397 res00 += lhs_value * rhs_value0;
398
399 ++rhs_ptr;
400 ++lhs_ptr;
401 }
402
403 // Quantize down
404 res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
405
406 // Clamp the result
407 res00 = MAX(res00, activation_min);
408 res00 = MIN(res00, activation_max);
409
410 *dst = (int16_t)res00;
411 dst += dst_offset;
412 rhs += rhs_cols;
413 }
414 #endif
415
416 return ARM_CMSIS_NN_SUCCESS;
417 }
418
419 /**
420 * @} end of Doxygen group
421 */
422