1 /*
2  * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s8
22  * Description:  s8 vector by matrix (transposed) multiplication
23  *
24  * $Date:        5 Sep 2024
25  * $Revision:    V.6.2.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup groupSupport
35  */
36 
37 /**
38  * @defgroup supportFC Fully Connected
39  *
40  * Support functions for Fully Connected
41  *
42  */
43 
44 /**
45  * @addtogroup supportFC
46  * @{
47  */
48 
/*
 * s8 vector (lhs) by matrix (transposed) multiplication
 *
 * Refer to the header file for details.
 *
 */
55 #if !defined(ARM_MATH_MVEI) && defined(ARM_MATH_DSP) && !defined(__ARMCC_VERSION) && !defined(__ICCARM__)
56     #pragma GCC optimize("unroll-loops")
57 #endif
arm_nn_vec_mat_mult_t_s8(const int8_t * lhs,const int8_t * rhs,const int32_t * kernel_sum,const int32_t * bias,int8_t * dst,const int32_t lhs_offset,const int32_t dst_offset,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max,const int32_t address_offset,const int32_t rhs_offset)58 arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
59                                              const int8_t *rhs,
60                                              const int32_t *kernel_sum,
61                                              const int32_t *bias,
62                                              int8_t *dst,
63                                              const int32_t lhs_offset,
64                                              const int32_t dst_offset,
65                                              const int32_t dst_multiplier,
66                                              const int32_t dst_shift,
67                                              const int32_t rhs_cols,
68                                              const int32_t rhs_rows,
69                                              const int32_t activation_min,
70                                              const int32_t activation_max,
71                                              const int32_t address_offset,
72                                              const int32_t rhs_offset)
73 {
74     if (rhs_offset)
75     {
76 #if defined(ARM_MATH_MVEI)
77         (void)bias;
78         (void)lhs_offset;
79         const int32_t row_loop_cnt = rhs_rows / 4;
80         const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
81 
82         for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
83         {
84             int32_t acc_0 = *kernel_sum++;
85             int32_t acc_1 = *kernel_sum++;
86             int32_t acc_2 = *kernel_sum++;
87             int32_t acc_3 = *kernel_sum++;
88 
89             const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
90 
91             const int8_t *lhs_vec = lhs;
92             const int8_t *rhs_0_ptr = rhs;
93             const int8_t *rhs_1_ptr = rhs + rhs_cols;
94             const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
95             const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
96 
97             int32_t lhs_sum = 0;
98 
99             uint32_t col_cnt = (uint32_t)rhs_cols;
100 
101             for (int32_t i = 0; i < col_loop_cnt; i++)
102             {
103                 mve_pred16_t p = vctp8q(col_cnt);
104                 col_cnt -= 16;
105 
106                 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
107                 lhs_sum = vaddvaq_s8(lhs_sum, input);
108 
109                 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
110                 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
111 
112                 const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
113                 acc_1 = vmladavaq_s8(acc_1, ker_1, input);
114 
115                 const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
116                 acc_2 = vmladavaq_s8(acc_2, ker_2, input);
117 
118                 const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
119                 acc_3 = vmladavaq_s8(acc_3, ker_3, input);
120 
121                 lhs_vec += 16;
122                 rhs_0_ptr += 16;
123                 rhs_1_ptr += 16;
124                 rhs_2_ptr += 16;
125                 rhs_3_ptr += 16;
126             }
127             rhs += 4 * rhs_cols;
128 
129             int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
130 
131             acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
132 
133             acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
134             acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
135             acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
136             acc = vminq_s32(acc, vdupq_n_s32(activation_max));
137 
138             vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
139 
140             dst += 4 * address_offset;
141         }
142 
143         const int loop_cnt = rhs_rows % 4;
144         for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
145         {
146             int32_t acc_0 = *kernel_sum++;
147             const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
148             const int8_t *lhs_vec = lhs;
149             const int8_t *rhs_ptr = rhs;
150             int32_t lhs_sum = 0;
151             uint32_t col_cnt = (uint32_t)rhs_cols;
152 
153             for (int32_t i = 0; i < col_loop_cnt; i++)
154             {
155                 mve_pred16_t p = vctp8q(col_cnt);
156                 col_cnt -= 16;
157                 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
158                 lhs_sum = vaddvaq_s8(lhs_sum, input);
159 
160                 const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
161                 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
162 
163                 lhs_vec += 16;
164                 rhs_ptr += 16;
165             }
166             rhs += rhs_cols;
167 
168             acc_0 += lhs_sum * rhs_offset;
169 
170             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
171             acc_0 += dst_offset;
172 
173             // Clamp the result
174             acc_0 = MAX(acc_0, activation_min);
175             *dst = MIN(acc_0, activation_max);
176             dst += address_offset;
177         }
178 
179 #elif defined(ARM_MATH_DSP)
180         (void)kernel_sum;
181 
182         const int32_t row_loop_cnt = rhs_rows / 2;
183         const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
184         const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
185 
186         const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
187         const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);
188 
189         for (int32_t i = 0; i < row_loop_cnt; i++)
190         {
191             int32_t acc_0 = 0;
192             int32_t acc_1 = 0;
193             if (bias)
194             {
195                 acc_0 = *bias++;
196                 acc_1 = *bias++;
197             }
198 
199             const int32_t col_loop_cnt = rhs_cols / 4;
200 
201             const int8_t *lhs_vec = lhs;
202             const int8_t *rhs_0_ptr = rhs;
203             const int8_t *rhs_1_ptr = rhs + rhs_cols;
204             rhs += 2 * rhs_cols;
205 
206             for (int32_t j = col_loop_cnt; j != 0; j--)
207             {
208                 int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
209                 int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
210 
211                 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
212 
213                 int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
214                 int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
215                 ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
216 
217                 acc_0 = SMLAD(ker_1, vec_1, acc_0);
218                 acc_0 = SMLAD(ker_0, vec_0, acc_0);
219 
220                 ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
221                 ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
222                 ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
223 
224                 acc_1 = SMLAD(ker_1, vec_1, acc_1);
225                 acc_1 = SMLAD(ker_0, vec_0, acc_1);
226             }
227 
228             for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
229             {
230                 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
231                 lhs_vec++;
232                 acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
233                 rhs_0_ptr++;
234                 acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
235                 rhs_1_ptr++;
236             }
237 
238             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
239             acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
240 
241             // Add offset
242             acc_0 += dst_offset;
243             acc_1 += dst_offset;
244             // Clamp the result
245             acc_0 = MAX(acc_0, activation_min);
246             acc_0 = MIN(acc_0, activation_max);
247             acc_1 = MAX(acc_1, activation_min);
248             acc_1 = MIN(acc_1, activation_max);
249             *dst = (int8_t)acc_0;
250             *(dst + address_offset) = (int8_t)acc_1;
251             dst += 2 * address_offset;
252         }
253 
254         if (rhs_rows & 0x1)
255         {
256             int32_t acc_0 = 0;
257             if (bias)
258             {
259                 acc_0 = *bias++;
260             }
261             const int32_t col_loop_cnt = rhs_cols / 4;
262 
263             const int8_t *lhs_vec = lhs;
264             const int8_t *rhs_ptr = rhs;
265 
266             for (int32_t i = col_loop_cnt; i != 0; i--)
267             {
268                 int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
269                 int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
270                 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
271 
272                 int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
273                 int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
274                 ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);
275 
276                 acc_0 = SMLAD(ker_1, vec_1, acc_0);
277                 acc_0 = SMLAD(ker_0, vec_0, acc_0);
278             }
279 
280             for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
281             {
282                 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
283                 lhs_vec++;
284                 acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
285                 rhs_ptr++;
286             }
287 
288             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
289 
290             // Add offset
291             acc_0 += dst_offset;
292             // Clamp the result
293             acc_0 = MAX(acc_0, activation_min);
294             acc_0 = MIN(acc_0, activation_max);
295             *dst = (int8_t)acc_0;
296             dst += address_offset;
297         }
298 
299 #else
300         (void)kernel_sum;
301 
302         const int32_t row_loop_cnt = rhs_rows / 3;
303 
304         for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
305         {
306             const int8_t *lhs_ptr = lhs;
307             const int8_t *rhs_ptr_0 = &rhs[0];
308             const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
309             const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
310 
311             int32_t res00 = 0;
312             int32_t res01 = 0;
313             int32_t res02 = 0;
314             if (bias)
315             {
316                 res00 = *bias++;
317                 res01 = *bias++;
318                 res02 = *bias++;
319             }
320             for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
321             {
322                 const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
323                 const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
324                 const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
325                 const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
326 
327                 res00 += lhs_value * rhs_value0;
328                 res01 += lhs_value * rhs_value1;
329                 res02 += lhs_value * rhs_value2;
330 
331                 ++rhs_ptr_0;
332                 ++rhs_ptr_1;
333                 ++rhs_ptr_2;
334                 ++lhs_ptr;
335             }
336 
337             // Quantize down
338             res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
339             res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
340             res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
341 
342             // Add offset
343             res00 += dst_offset;
344             res01 += dst_offset;
345             res02 += dst_offset;
346 
347             // Clamp the result
348             res00 = MAX(res00, activation_min);
349             res00 = MIN(res00, activation_max);
350             res01 = MAX(res01, activation_min);
351             res01 = MIN(res01, activation_max);
352             res02 = MAX(res02, activation_min);
353             res02 = MIN(res02, activation_max);
354 
355             *dst = (int8_t)res00;
356             *(dst + address_offset) = (int8_t)res01;
357             *(dst + 2 * address_offset) = (int8_t)res02;
358             dst += 3 * address_offset;
359 
360             rhs += 3 * rhs_cols;
361         }
362 
363         const int loop_cnt = rhs_rows % 3;
364 
365         for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
366         {
367             const int8_t *lhs_ptr = &lhs[0];
368             const int8_t *rhs_ptr = &rhs[0];
369 
370             int32_t res00 = 0;
371             if (bias)
372             {
373                 res00 = *bias++;
374             }
375 
376             for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
377             {
378                 int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
379                 int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
380 
381                 res00 += lhs_value * rhs_value0;
382 
383                 ++rhs_ptr;
384                 ++lhs_ptr;
385             }
386 
387             // Quantize down
388             res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
389 
390             // Add offset
391             res00 += dst_offset;
392 
393             // Clamp the result
394             res00 = MAX(res00, activation_min);
395             res00 = MIN(res00, activation_max);
396 
397             *dst = (int8_t)res00;
398             dst += address_offset;
399             rhs += rhs_cols;
400         }
401 #endif
402     }
403 
404     else
405     {
406 #if defined(ARM_MATH_MVEI)
407         const int32_t row_loop_cnt = rhs_rows / 4;
408         const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};
409 
410         for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
411         {
412             int32_t acc_0 = *kernel_sum++;
413             int32_t acc_1 = *kernel_sum++;
414             int32_t acc_2 = *kernel_sum++;
415             int32_t acc_3 = *kernel_sum++;
416 
417             const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
418 
419             const int8_t *lhs_vec = lhs;
420             const int8_t *rhs_0_ptr = rhs;
421             const int8_t *rhs_1_ptr = rhs + rhs_cols;
422             const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
423             const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;
424 
425             uint32_t col_cnt = (uint32_t)rhs_cols;
426 
427             for (int32_t i = 0; i < col_loop_cnt; i++)
428             {
429                 mve_pred16_t p = vctp8q(col_cnt);
430                 col_cnt -= 16;
431 
432                 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
433 
434                 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
435                 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
436 
437                 const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
438                 acc_1 = vmladavaq_s8(acc_1, ker_1, input);
439 
440                 const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
441                 acc_2 = vmladavaq_s8(acc_2, ker_2, input);
442 
443                 const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
444                 acc_3 = vmladavaq_s8(acc_3, ker_3, input);
445 
446                 lhs_vec += 16;
447                 rhs_0_ptr += 16;
448                 rhs_1_ptr += 16;
449                 rhs_2_ptr += 16;
450                 rhs_3_ptr += 16;
451             }
452             rhs += 4 * rhs_cols;
453 
454             int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};
455 
456             acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
457             acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
458             acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
459             acc = vminq_s32(acc, vdupq_n_s32(activation_max));
460 
461             vstrbq_scatter_offset_s32(dst, address_offset_array, acc);
462 
463             dst += 4 * address_offset;
464         }
465 
466         const int loop_cnt = rhs_rows % 4;
467         for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
468         {
469             int32_t acc_0 = *kernel_sum++;
470             const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
471             const int8_t *lhs_vec = lhs;
472             const int8_t *rhs_ptr = rhs;
473             uint32_t col_cnt = (uint32_t)rhs_cols;
474 
475             for (int32_t i = 0; i < col_loop_cnt; i++)
476             {
477                 mve_pred16_t p = vctp8q(col_cnt);
478                 col_cnt -= 16;
479                 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
480 
481                 const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
482                 acc_0 = vmladavaq_s8(acc_0, ker_0, input);
483 
484                 lhs_vec += 16;
485                 rhs_ptr += 16;
486             }
487             rhs += rhs_cols;
488 
489             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
490             acc_0 += dst_offset;
491 
492             // Clamp the result
493             acc_0 = MAX(acc_0, activation_min);
494             *dst = MIN(acc_0, activation_max);
495             dst += address_offset;
496         }
497 
498 #elif defined(ARM_MATH_DSP)
499         (void)kernel_sum;
500 
501         const int32_t row_loop_cnt = rhs_rows / 2;
502         const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
503         const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
504 
505         for (int32_t i = 0; i < row_loop_cnt; i++)
506         {
507             int32_t acc_0 = 0;
508             int32_t acc_1 = 0;
509             if (bias)
510             {
511                 acc_0 = *bias++;
512                 acc_1 = *bias++;
513             }
514 
515             const int32_t col_loop_cnt = rhs_cols / 4;
516 
517             const int8_t *lhs_vec = lhs;
518             const int8_t *rhs_0_ptr = rhs;
519             const int8_t *rhs_1_ptr = rhs + rhs_cols;
520             rhs += 2 * rhs_cols;
521 
522             for (int32_t j = col_loop_cnt; j != 0; j--)
523             {
524                 int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
525                 int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
526 
527                 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
528 
529                 int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
530                 int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
531                 ker_0 = SXTB16(ker_0);
532 
533                 acc_0 = SMLAD(ker_1, vec_1, acc_0);
534                 acc_0 = SMLAD(ker_0, vec_0, acc_0);
535 
536                 ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
537                 ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
538                 ker_0 = SXTB16(ker_0);
539 
540                 acc_1 = SMLAD(ker_1, vec_1, acc_1);
541                 acc_1 = SMLAD(ker_0, vec_0, acc_1);
542             }
543 
544             for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
545             {
546                 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
547                 lhs_vec++;
548                 acc_0 += lhs_temp * (*rhs_0_ptr);
549                 rhs_0_ptr++;
550                 acc_1 += lhs_temp * (*rhs_1_ptr);
551                 rhs_1_ptr++;
552             }
553 
554             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
555             acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
556 
557             // Add offset
558             acc_0 += dst_offset;
559             acc_1 += dst_offset;
560             // Clamp the result
561             acc_0 = MAX(acc_0, activation_min);
562             acc_0 = MIN(acc_0, activation_max);
563             acc_1 = MAX(acc_1, activation_min);
564             acc_1 = MIN(acc_1, activation_max);
565             *dst = (int8_t)acc_0;
566             *(dst + address_offset) = (int8_t)acc_1;
567             dst += 2 * address_offset;
568         }
569 
570         if (rhs_rows & 0x1)
571         {
572             int32_t acc_0 = 0;
573             if (bias)
574             {
575                 acc_0 = *bias++;
576             }
577             const int32_t col_loop_cnt = rhs_cols / 4;
578 
579             const int8_t *lhs_vec = lhs;
580             const int8_t *rhs_ptr = rhs;
581 
582             for (int32_t i = col_loop_cnt; i != 0; i--)
583             {
584                 int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
585                 int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
586                 vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);
587 
588                 int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
589                 int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
590                 ker_0 = SXTB16(ker_0);
591 
592                 acc_0 = SMLAD(ker_1, vec_1, acc_0);
593                 acc_0 = SMLAD(ker_0, vec_0, acc_0);
594             }
595 
596             for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
597             {
598                 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
599                 lhs_vec++;
600                 acc_0 += lhs_temp * (*rhs_ptr);
601                 rhs_ptr++;
602             }
603 
604             acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
605 
606             // Add offset
607             acc_0 += dst_offset;
608             // Clamp the result
609             acc_0 = MAX(acc_0, activation_min);
610             acc_0 = MIN(acc_0, activation_max);
611             *dst = (int8_t)acc_0;
612             dst += address_offset;
613         }
614 
615 #else
616         (void)kernel_sum;
617 
618         const int32_t row_loop_cnt = rhs_rows / 3;
619 
620         for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
621         {
622             const int8_t *lhs_ptr = lhs;
623             const int8_t *rhs_ptr_0 = &rhs[0];
624             const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
625             const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
626 
627             int32_t res00 = 0;
628             int32_t res01 = 0;
629             int32_t res02 = 0;
630             if (bias)
631             {
632                 res00 = *bias++;
633                 res01 = *bias++;
634                 res02 = *bias++;
635             }
636             for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
637             {
638                 const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
639                 const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
640                 const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
641                 const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
642 
643                 res00 += lhs_value * rhs_value0;
644                 res01 += lhs_value * rhs_value1;
645                 res02 += lhs_value * rhs_value2;
646 
647                 ++rhs_ptr_0;
648                 ++rhs_ptr_1;
649                 ++rhs_ptr_2;
650                 ++lhs_ptr;
651             }
652             // Quantize down
653             res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
654             res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
655             res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
656 
657             // Add offset
658             res00 += dst_offset;
659             res01 += dst_offset;
660             res02 += dst_offset;
661 
662             // Clamp the result
663             res00 = MAX(res00, activation_min);
664             res00 = MIN(res00, activation_max);
665             res01 = MAX(res01, activation_min);
666             res01 = MIN(res01, activation_max);
667             res02 = MAX(res02, activation_min);
668             res02 = MIN(res02, activation_max);
669 
670             *dst = (int8_t)res00;
671             *(dst + address_offset) = (int8_t)res01;
672             *(dst + 2 * address_offset) = (int8_t)res02;
673             dst += 3 * address_offset;
674 
675             rhs += 3 * rhs_cols;
676         }
677 
678         const int loop_cnt = rhs_rows % 3;
679 
680         for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
681         {
682             const int8_t *lhs_ptr = &lhs[0];
683             const int8_t *rhs_ptr = &rhs[0];
684 
685             int32_t res00 = 0;
686             if (bias)
687             {
688                 res00 = *bias++;
689             }
690 
691             for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
692             {
693                 int32_t rhs_value0 = (int8_t)rhs_ptr[0];
694                 int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
695 
696                 res00 += lhs_value * rhs_value0;
697 
698                 ++rhs_ptr;
699                 ++lhs_ptr;
700             }
701 
702             // Quantize down
703             res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
704 
705             // Add offset
706             res00 += dst_offset;
707 
708             // Clamp the result
709             res00 = MAX(res00, activation_min);
710             res00 = MIN(res00, activation_max);
711 
712             *dst = (int8_t)res00;
713             dst += address_offset;
714             rhs += rhs_cols;
715         }
716 #endif
717     }
718     return ARM_CMSIS_NN_SUCCESS;
719 }
720 
721 /**
722  * @} end of Doxygen group
723  */
724