/*
 * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        22 March 2023
 * $Revision:    V.2.1.2
 *
 * Target :      Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
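/*
 * Illustrative call site (a sketch: the shapes and quantization values below
 * are assumptions chosen for the example, not values mandated by this file).
 * The function computes, for each output element,
 * dst[m][n] = clamp(requantize(sum_k (lhs[m][k] + lhs_offset) * rhs[n][k]
 *                              + bias[n]) + dst_offset):
 *
 *   int8_t lhs[4 * 8];        // 4 x 8 input, s8 (assumed shape)
 *   int8_t rhs[2 * 8];        // 2 x 8 weights, stored row-major transposed
 *   int32_t bias[2] = {0, 0};
 *   int8_t dst[4 * 2];        // lhs_rows x rhs_rows output
 *   int32_t mult[2] = {1073741824, 1073741824};  // ~0.5 in Q31, per rhs row
 *   int32_t shift[2] = {0, 0};
 *
 *   arm_cmsis_nn_status status = arm_nn_mat_mult_nt_t_s8(lhs,
 *                                                        rhs,
 *                                                        bias,
 *                                                        dst,
 *                                                        mult,
 *                                                        shift,
 *                                                        4,    // lhs_rows
 *                                                        2,    // rhs_rows
 *                                                        8,    // rhs_cols
 *                                                        0,    // lhs_offset
 *                                                        0,    // dst_offset
 *                                                        -128, // activation_min
 *                                                        127,  // activation_max
 *                                                        8);   // lhs_cols_offset (lhs row stride)
 */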
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset)
{

#if defined(ARM_MATH_MVEI)
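    /* MVE path: four lhs rows are processed per outer iteration. Each rhs row
       is streamed once while four lhs row pointers advance in lock-step, and
       the rhs column sum needed for the lhs zero-point correction is
       accumulated alongside the dot products. */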
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int8_t *lhs_vec = lhs;
            const int8_t *ip_row_1 = lhs + lhs_cols_offset;
            const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
            const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

    #if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
    #else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
            __ASM volatile(" .p2align 2                             \n"
                           "   wlstp.8         lr, %[cnt], 1f       \n"
                           "   mov             %[sum], 0            \n"
                           "   mov             %[out0], 0           \n"
                           "   mov             %[out1], 0           \n"
                           "   mov             %[out2], 0           \n"
                           "   mov             %[out3], 0           \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "2:                                      \n"
                           "   vaddva.s8       %[sum], q0           \n"
                           "   vldrb.8         q1, [%[row0]], #16   \n"
                           "   vmladava.s8     %[out0], q0, q1      \n"
                           "   vldrb.8         q2, [%[row1]], #16   \n"
                           "   vmladava.s8     %[out1], q0, q2      \n"
                           "   vldrb.8         q3, [%[row2]], #16   \n"
                           "   vmladava.s8     %[out2], q0, q3      \n"
                           "   vldrb.8         q4, [%[row3]], #16   \n"
                           "   vmladava.s8     %[out3], q0, q4      \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "   letp            lr, 2b               \n"
                           "1:                                      \n"
                           : [col] "+r"(col_base),
                             [sum] "=Te"(sum_tmp),
                             [row0] "+r"(lhs_vec),
                             [row1] "+r"(ip_row_1),
                             [row2] "+r"(ip_row_2),
                             [row3] "+r"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
    #endif
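            /* Fold in the lhs zero-point after the MAC loop:
             *   sum_j (lhs[j] + lhs_offset) * rhs[j]
             *     = sum_j lhs[j] * rhs[j] + lhs_offset * sum_j rhs[j]
             * so the loop above only needs the raw dot products plus the rhs
             * column sum held in sum_tmp. */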
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

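            /* One result per lhs row: scatter the four lanes to the same
               column of four consecutive dst rows (dst is lhs_rows x rhs_rows,
               so the row stride is rhs_rows bytes). */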
            const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};
            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += (3 * rhs_rows);
    }

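    /* Tail: process any remaining lhs rows one at a time. Results are
       buffered in acc[] so that requantization and the store can still be
       done four outputs at a time. */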
    for (; i_items < lhs_rows; i_items++)
    {
        int32_t acc[4];
        const int32_t *multipliers = dst_multipliers;
        const int32_t *shifts = dst_shifts;
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            const int8_t *lhs_vec = lhs;
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

    #if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
            }
    #else
            __ASM volatile(" .p2align 2                             \n"
                           "   wlstp.8         lr, %[cnt], 1f       \n"
                           "   mov             %[sum], 0            \n"
                           "   mov             %[out0], 0           \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "2:                                      \n"
                           "   vaddva.s8       %[sum], q0           \n"
                           "   vldrb.8         q1, [%[row0]], #16   \n"
                           "   vmladava.s8     %[out0], q0, q1      \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "   letp            lr, 2b               \n"
                           "1:                                      \n"
                           : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "memory", "r14");
    #endif
            sum_tmp *= lhs_offset;
            sum_tmp += acc_n0;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            const int32_t index = i & 0x3;
            acc[index] = sum_tmp;

            if (index == 3)
            {
                int32x4_t res = vldrwq_s32(acc);
                res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                multipliers += 4;
                shifts += 4;
                res = vaddq_n_s32(res, dst_offset);
                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));
                vstrbq_s32(dst, res);
                dst += 4;
            }
        }
        lhs += lhs_cols_offset;
        const int32_t tail_rows = rhs_rows & 0x3;
        for (int i = 0; i < tail_rows; i++)
        {
            int32_t acc_n0 = acc[i];
            acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
            acc_n0 += dst_offset;
            acc_n0 = MAX(acc_n0, activation_min);
            acc_n0 = MIN(acc_n0, activation_max);
            *dst++ = (int8_t)acc_n0;
        }
    }

#elif defined(ARM_MATH_DSP)
    const int32_t rhs_off0 = rhs_cols - 4;
    const int32_t lhs_off0 = lhs_cols_offset - 4;
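    /* After a post-incremented 4-byte read, ptr[rhs_off0] addresses the same
       four columns in the second rhs row of the pair, and ptr[lhs_off0] the
       same columns in the second lhs row, so one pointer can serve two rows. */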

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }
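
        /* lhs_offset_contribution{0,1} now hold every term of the dot product
           that is independent of the actual lhs values, i.e.
           lhs_offset * sum(rhs row) plus the bias; each accumulator below
           starts from this value. */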

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;

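            /* Main loop, unrolled to 16 columns: each 4-byte load packs four
               s8 values; SXTB16/SXTB16_RORn unpack them into two s16 pairs so
               that SMLAD performs two MACs per instruction. */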
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

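    /* Left-over rhs row when rhs_rows is odd: a plain scalar loop over all
       lhs rows produces the last output column. */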
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#else
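    /* Pure C fallback for cores without the DSP or MVE extensions: the same
       two-rhs-row by two-lhs-row blocking as above, with scalar MACs. */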
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */