1 /*
2  * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s8
22  * Description:  s8 vector by matrix (transposed) multiplication
23  *
24  * $Date:        14 Feb 2023
25  * $Revision:    V.6.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup groupSupport
35  */
36 
37 /**
38  * @defgroup supportFC Fully Connected
39  *
40  * Support functions for Fully Connected
41  *
42  */
43 
44 /**
45  * @addtogroup supportFC
46  * @{
47  */
48 
49 /*
50  * s8 vector(lhs) by matrix (transposed) multiplication
51  *
52  * Refer header file for details.
53  *
54  */
55 #if defined(ARM_MATH_DSP) && !defined(__ARMCC_VERSION) && !defined(__ICCARM__)
56     #pragma GCC optimize("unroll-loops")
57 #endif
/**
 * @brief s8 vector (lhs) by transposed matrix (rhs) multiplication with
 *        per-tensor requantization.
 *
 * Computes, for each row r of rhs:
 *   dst[r * address_offset] = clamp(requantize(dot(lhs + lhs_offset,
 *                                   rhs_row_r + rhs_offset) + bias[r])
 *                                   + dst_offset, activation_min..max)
 *
 * @param lhs            Pointer to the input vector (rhs_cols elements).
 * @param rhs            Pointer to the transposed weight matrix
 *                       (rhs_rows x rhs_cols, row-major).
 * @param kernel_sum     Precomputed per-row sums of rhs. Used only on the
 *                       MVE path, where it replaces adding lhs_offset to
 *                       every lhs element; other paths ignore it.
 * @param bias           Optional per-row bias (may be NULL).
 * @param dst            Output buffer; consecutive results are written
 *                       address_offset elements apart (strided output).
 * @param lhs_offset     Offset added to each lhs element.
 * @param dst_offset     Offset added to each requantized result.
 * @param dst_multiplier Requantization multiplier.
 * @param dst_shift      Requantization shift.
 * @param rhs_cols       Number of columns (vector length).
 * @param rhs_rows       Number of rows (number of outputs).
 * @param activation_min Minimum clamp value for the output.
 * @param activation_max Maximum clamp value for the output.
 * @param address_offset Element stride between consecutive outputs in dst.
 * @param rhs_offset     Offset added to each rhs element.
 *
 * @return ARM_CMSIS_NN_SUCCESS (always).
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                             const int8_t *rhs,
                                             const int32_t *kernel_sum,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max,
                                             const int32_t address_offset,
                                             const int32_t rhs_offset)
{
    /* Two top-level variants: with a non-zero rhs_offset extra correction
       terms (rhs_offset * sum(lhs) and rhs_offset * lhs_offset * rhs_cols)
       must be folded into each accumulator; the rhs_offset == 0 variant
       below skips that work. */
    if (rhs_offset)
    {
#if defined(ARM_MATH_MVEI)
        /* MVE path: process 4 rhs rows per iteration, 16 int8 columns per
           vector operation. Raw s8 dot products are accumulated and the
           quantization offsets are applied afterwards using the identity:
             dot(lhs+lo, rhs+ro) = dot(lhs, rhs) + lo*sum(rhs)
                                   + ro*sum(lhs) + lo*ro*rhs_cols
           where sum(rhs) comes precomputed via kernel_sum. */
        const int32_t row_loop_cnt = rhs_rows / 4;
        /* Byte offsets for scattering the 4 results into the strided dst. */
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;

            /* Ceil(rhs_cols / 16): last iteration uses tail predication. */
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            /* Running sum of raw lhs bytes, needed for the rhs_offset
               correction term. */
            int32_t lhs_sum = 0;

            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
                acc_2 = *bias++;
                acc_3 = *bias++;
            }

            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                /* Predicate masks out lanes beyond the remaining columns;
                   zeroed (inactive) lanes contribute nothing to the sums. */
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            /* Apply the three offset correction terms (see identity above):
               lhs_offset * sum(rhs_row), rhs_offset * sum(lhs), and the
               constant lhs_offset * rhs_offset * rhs_cols. */
            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
            kernel_sum += 4;

            acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
            acc += vdupq_n_s32(rhs_offset * lhs_offset) * vdupq_n_s32(rhs_cols);

            /* Requantize, add output offset, clamp, then scatter-store the
               four int8 results at the strided dst positions. */
            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

        /* Remaining 0..3 rows, processed one at a time. */
        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            int32_t lhs_sum = 0;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            if (bias)
            {
                acc_0 += *bias;
                bias++;
            }
            /* Same scalar offset corrections as the vectorized row loop. */
            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
            acc_0 += rhs_sum * lhs_offset;
            acc_0 += lhs_sum * rhs_offset;
            acc_0 += rhs_cols * lhs_offset * rhs_offset;

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
        /* DSP (Armv7E-M SIMD) path: process 2 rhs rows per iteration, 4 int8
           columns per SMLAD. Offsets are applied during widening: SXTAB16
           sign-extends two bytes to 16 bits and adds the duplicated offset
           halfwords, so the accumulators need no later correction. */
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 2;
        /* Duplicate each offset into both 16-bit halves for SXTAB16. */
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
        const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                /* vec_1 holds (offset + bytes 1,3), vec_0 (offset + bytes
                   0,2); together they cover all four loaded lanes. */
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

            /* Scalar tail for the last rhs_cols % 4 columns. */
            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

        /* Remaining row when rhs_rows is odd. */
        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
        /* Pure C fallback: process 3 rhs rows per iteration so the single
           offset-adjusted lhs value is reused across three accumulators. */
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        /* Remaining 0..2 rows. */
        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }

    else
    {
        /* rhs_offset == 0: identical structure to the branch above but the
           rhs_offset correction terms drop out, saving the lhs_sum tracking
           (MVE) and the per-element rhs offset additions (DSP / pure C). */

#if defined(ARM_MATH_MVEI)
        /* MVE path, 4 rows per iteration; only the lhs_offset * kernel_sum
           correction is needed. */
        const int32_t row_loop_cnt = rhs_rows / 4;
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;

            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
                acc_2 = *bias++;
                acc_3 = *bias++;
            }

            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                /* Tail predication zeroes out-of-range lanes. */
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            /* Fold in lhs_offset * sum(rhs_row) using precomputed sums. */
            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
            kernel_sum += 4;

            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

        /* Remaining 0..3 rows. */
        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            if (bias)
            {
                acc_0 += *bias;
                bias++;
            }
            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
            const int32_t offsets = rhs_sum * lhs_offset;
            acc_0 += offsets;
            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
        /* DSP path, 2 rows per iteration. lhs_offset is applied during
           widening via SXTAB16; rhs bytes are sign-extended with plain
           SXTB16 since rhs_offset is zero here. */
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 2;
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

            /* Scalar tail for the last rhs_cols % 4 columns. */
            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

        /* Remaining row when rhs_rows is odd. */
        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
        /* Pure C fallback, 3 rows per iteration; rhs bytes used directly
           since rhs_offset is zero. */
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }
            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        /* Remaining 0..2 rows. */
        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0];
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }
    return ARM_CMSIS_NN_SUCCESS;
}
760 
761 /**
762  * @} end of Doxygen group
763  */
764