/*
 * SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_svdf_s8
 * Description:  s8 vector by matrix (transposed) multiplication with
 *               s16 output. Targeted at the SVDF operator.
 *
 * $Date:        28 March 2023
 * $Revision:    V.3.2.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s8 vector(lhs) by matrix (transposed) multiplication
 *
 * Refer to the header file for details.
 *
 */
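/*
 * A minimal usage sketch (illustrative only: the sizes, offsets and
 * requantization parameters below are hypothetical placeholders, not values
 * taken from a specific model; parameter meanings are as inferred from the
 * implementation in this file):
 *
 *     #define COLS 32
 *     #define ROWS 8
 *     #define DST_STRIDE 2
 *
 *     int8_t state[COLS];                // lhs: s8 input vector
 *     int8_t weights[ROWS * COLS];       // rhs: transposed weights, row-major
 *     int16_t out[ROWS * DST_STRIDE];    // s16 outputs, DST_STRIDE apart
 *
 *     arm_cmsis_nn_status status =
 *         arm_nn_vec_mat_mult_t_svdf_s8(state, weights, out,
 *                                       128,         // lhs_offset
 *                                       DST_STRIDE,  // dst_offset: output stride
 *                                       1073741824,  // dst_multiplier
 *                                       -4,          // dst_shift
 *                                       COLS, ROWS,
 *                                       NN_Q15_MIN, NN_Q15_MAX);
 */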
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  int16_t *dst,
                                                  const int32_t lhs_offset,
                                                  const int32_t dst_offset,
                                                  const int32_t dst_multiplier,
                                                  const int32_t dst_shift,
                                                  const int32_t rhs_cols,
                                                  const int32_t rhs_rows,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max)
{
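    // Basic argument checks: rhs_cols must be non-negative and small enough
    // that the 16-wide column loop count cannot overflow, and dst_offset
    // (the output stride) must be non-negative.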
    if (rhs_cols < 0 || (NN_Q31_MAX - rhs_cols) < 16 || dst_offset < 0)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

#if defined(ARM_MATH_MVEI)
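    // MVE path: process three rhs rows per iteration so that each 16-byte
    // chunk of the lhs vector is loaded once and reused against three rows.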
    int32_t row_loop_cnt = rhs_rows / 3;

    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
    {
        int32_t acc_0 = 0;
        int32_t acc_1 = 0;
        int32_t acc_2 = 0;

        const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

        const int8_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        const int8_t *rhs_1 = rhs + rhs_cols;
        const int8_t *rhs_2 = rhs + 2 * rhs_cols;

        int32_t rhs_sum_0 = 0;
        int32_t rhs_sum_1 = 0;
        int32_t rhs_sum_2 = 0;

        uint32_t col_cnt = (uint32_t)rhs_cols;

        for (int i = 0; i < col_loop_cnt; i++)
        {
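            // Tail predication: vctp8q() enables only the lanes still in
            // range, and the _z loads read inactive lanes as zero, so the
            // final partial chunk needs no separate scalar tail loop.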
            mve_pred16_t p = vctp8q(col_cnt);
            col_cnt -= 16;

            const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

            const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc_0 = vmladavaq_s8(acc_0, ker_0, input);

            const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_1);
            acc_1 = vmladavaq_s8(acc_1, ker_1, input);

            const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_2);
            acc_2 = vmladavaq_s8(acc_2, ker_2, input);

            lhs_vec += 16;
            rhs_0 += 16;
            rhs_1 += 16;
            rhs_2 += 16;
        }
        rhs += 3 * rhs_cols;

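        // Fold in the lhs offset after the dot products, using the identity
        // sum((lhs + offset) * rhs) == sum(lhs * rhs) + offset * sum(rhs),
        // then requantize and clamp all three accumulators as one vector.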
        int32x4_t acc = {acc_0, acc_1, acc_2, 0};
        const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, 0};
        acc += vdupq_n_s32(lhs_offset) * rhs_sum;

        acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
        acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
        acc = vminq_s32(acc, vdupq_n_s32(activation_max));
        *(dst) = (int16_t)acc[0];
        *(dst + dst_offset) = (int16_t)acc[1];
        *(dst + 2 * dst_offset) = (int16_t)acc[2];
        dst += 3 * dst_offset;
    }

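    // Handle the remaining (rhs_rows % 3) rows one at a time with the same
    // predicated inner loop.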
    const int loop_cnt = rhs_rows % 3;
    for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
    {
        int32_t acc_0 = 0;
        const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
        const int8_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        int32_t rhs_sum_0 = 0;
        uint32_t col_cnt = (uint32_t)rhs_cols;

        for (int i = 0; i < col_loop_cnt; i++)
        {
            mve_pred16_t p = vctp8q(col_cnt);
            col_cnt -= 16;
            const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

            const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc_0 = vmladavaq_s8(acc_0, ker_0, input);

            lhs_vec += 16;
            rhs_0 += 16;
        }
        rhs += rhs_cols;

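        // Same offset-folding identity as above, with QADD providing a
        // saturating add of the offset contribution.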
        const int32_t offsets = rhs_sum_0 * lhs_offset;
        acc_0 = QADD(acc_0, offsets);
        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

        // Clamp the result
        acc_0 = MAX(acc_0, activation_min);
        *dst = (int16_t)MIN(acc_0, activation_max);
        dst += dst_offset;
    }

#elif defined(ARM_MATH_DSP)
    int32_t row_loop_cnt = rhs_rows / 2;

    const int16_t lhs_offset_s16 = lhs_offset;

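    // Duplicate the lhs offset into both 16-bit halves of a word so SXTAB16
    // can add it to two unpacked lhs values at once.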
    const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
    for (int32_t i = 0; i < row_loop_cnt; i++)
    {
        int32_t acc_0 = 0;
        int32_t acc_1 = 0;

        const int8_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        const int8_t *rhs_1 = rhs + rhs_cols;
        rhs += 2 * rhs_cols;

        int32_t rhs_cols_idx = 0;

        int32_t vec_0, vec_1, ker_0, ker_1;

    #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
        #pragma clang loop unroll(disable)
    #endif
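        // Manually unrolled by 16 columns. Each 32-bit read holds four s8
        // values: SXTAB16 sign-extends bytes 0 and 2 into 16-bit lanes while
        // adding the packed offset, the _RORn variant handles bytes 1 and 3,
        // and each SMLAD performs a dual 16x16 multiply-accumulate.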
        for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
        {
            // 4 x MAC acc_0, acc_1
            vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_1 = SMLAD(ker_1, vec_1, acc_1);
            acc_1 = SMLAD(ker_0, vec_0, acc_1);

            // 4 x MAC acc_0, acc_1
            vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_1 = SMLAD(ker_1, vec_1, acc_1);
            acc_1 = SMLAD(ker_0, vec_0, acc_1);

            // 4 x MAC acc_0, acc_1
            vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_1 = SMLAD(ker_1, vec_1, acc_1);
            acc_1 = SMLAD(ker_0, vec_0, acc_1);

            // 4 x MAC acc_0, acc_1
            vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);
            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);
            acc_1 = SMLAD(ker_1, vec_1, acc_1);
            acc_1 = SMLAD(ker_0, vec_0, acc_1);
        }

    #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
        #pragma clang loop unroll(disable)
    #endif
        for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
        {
            vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

            ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);

            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);

            ker_0 = arm_nn_read_s8x4_ia(&rhs_1);
            ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);

            acc_1 = SMLAD(ker_1, vec_1, acc_1);
            acc_1 = SMLAD(ker_0, vec_0, acc_1);
        }

    #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
        #pragma clang loop unroll(disable)
    #endif
        for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            const int32_t lhs_temp = (*lhs_vec + lhs_offset);
            lhs_vec++;
            acc_0 += lhs_temp * (*rhs_0);
            rhs_0++;
            acc_1 += lhs_temp * (*rhs_1);
            rhs_1++;
        }

        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
        acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

        // Clamp the result
        acc_0 = MAX(acc_0, activation_min);
        acc_0 = MIN(acc_0, activation_max);
        acc_1 = MAX(acc_1, activation_min);
        acc_1 = MIN(acc_1, activation_max);
        *dst = (int16_t)acc_0;
        *(dst + dst_offset) = (int16_t)acc_1;
        dst += 2 * dst_offset;
    }
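    // When rhs_rows is odd, process the final row on its own.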
    if (rhs_rows & 0x1)
    {
        int32_t acc_0 = 0;
        const int32_t col_loop_cnt = rhs_cols / 4;
        const int8_t *lhs_vec = lhs;
        const int8_t *rhs_0 = rhs;
        for (int i = col_loop_cnt; i != 0; i--)
        {
            int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
            int32_t vec_1 = SXTAB16(lhs_offset_s16x2, ROR((uint32_t)vec_0, 8));
            vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

            int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0);
            int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
            ker_0 = SXTB16(ker_0);

            acc_0 = SMLAD(ker_1, vec_1, acc_0);
            acc_0 = SMLAD(ker_0, vec_0, acc_0);
        }
        for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
        {
            const int32_t lhs_temp = (*lhs_vec + lhs_offset);
            lhs_vec++;
            acc_0 += lhs_temp * *rhs_0;
            rhs_0++;
        }
        acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

        // Clamp the result
        acc_0 = MAX(acc_0, activation_min);
        acc_0 = MIN(acc_0, activation_max);
        *dst = (int16_t)acc_0;
        dst += dst_offset;
    }

#else

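    // Pure C fallback: the same computation without DSP or MVE intrinsics,
    // still walking three rows per pass so each lhs value is loaded once.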
    int32_t row_loop_cnt = rhs_rows / 3;

    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
    {
        const int8_t *lhs_ptr = lhs;
        const int8_t *rhs_ptr_0 = &rhs[0];
        const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
        const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

        int32_t res00 = 0;
        int32_t res01 = 0;
        int32_t res02 = 0;
        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
            const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
            const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
            const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

            res00 += lhs_value * rhs_value0;
            res01 += lhs_value * rhs_value1;
            res02 += lhs_value * rhs_value2;

            ++rhs_ptr_0;
            ++rhs_ptr_1;
            ++rhs_ptr_2;
            ++lhs_ptr;
        }
        // Quantize down
        res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
        res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
        res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

        // Clamp the result
        res00 = MAX(res00, activation_min);
        res00 = MIN(res00, activation_max);
        res01 = MAX(res01, activation_min);
        res01 = MIN(res01, activation_max);
        res02 = MAX(res02, activation_min);
        res02 = MIN(res02, activation_max);

        *dst = (int16_t)res00;
        *(dst + dst_offset) = (int16_t)res01;
        *(dst + 2 * dst_offset) = (int16_t)res02;
        dst += 3 * dst_offset;
        rhs += 3 * rhs_cols;
    }

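    // Handle the remaining (rhs_rows % 3) rows one at a time.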
    const int loop_cnt = rhs_rows % 3;

    for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
    {
        const int8_t *lhs_ptr = &lhs[0];
        const int8_t *rhs_ptr = &rhs[0];

        int32_t res00 = 0;

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
        {
            int32_t rhs_value0 = (int8_t)rhs_ptr[0];
            int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

            res00 += lhs_value * rhs_value0;

            ++rhs_ptr;
            ++lhs_ptr;
        }

        // Quantize down
        res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

        // Clamp the result
        res00 = MAX(res00, activation_min);
        res00 = MIN(res00, activation_max);

        *dst = (int16_t)res00;
        dst += dst_offset;
        rhs += rhs_cols;
    }
#endif

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */