/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        04 January 2024
 * $Revision:    V.3.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */
/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t row_address_offset,
                                            const int32_t lhs_cols_offset)
{

#if defined(ARM_MATH_MVEI)
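    /* Process four lhs rows per iteration: the same rhs row is multiplied against four
     * consecutive lhs rows, producing one int8 result per output row. */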
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int8_t *lhs_vec = lhs;
            const int8_t *ip_row_1 = lhs + lhs_cols_offset;
            const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
            const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

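            /* Either let the compiler auto-vectorize the plain C loop, or use the
             * hand-written MVE inline assembly below, which relies on tail-predicated
             * low-overhead loops (wlstp/letp). */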
    #if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
    #else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
            __ASM volatile(" .p2align 2                             \n"
                           "   wlstp.8         lr, %[cnt], 1f       \n"
                           "   mov             %[sum], 0            \n"
                           "   mov             %[out0], 0           \n"
                           "   mov             %[out1], 0           \n"
                           "   mov             %[out2], 0           \n"
                           "   mov             %[out3], 0           \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "2:                                      \n"
                           "   vaddva.s8       %[sum], q0           \n"
                           "   vldrb.8         q1, [%[row0]], #16   \n"
                           "   vmladava.s8     %[out0], q0, q1      \n"
                           "   vldrb.8         q2, [%[row1]], #16   \n"
                           "   vmladava.s8     %[out1], q0, q2      \n"
                           "   vldrb.8         q3, [%[row2]], #16   \n"
                           "   vmladava.s8     %[out2], q0, q3      \n"
                           "   vldrb.8         q4, [%[row3]], #16   \n"
                           "   vmladava.s8     %[out3], q0, q4      \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "   letp            lr, 2b               \n"
                           "1:                                      \n"
                           : [col] "+r"(col_base),
                             [sum] "=Te"(sum_tmp),
                             [row0] "+r"(lhs_vec),
                             [row1] "+r"(ip_row_1),
                             [row2] "+r"(ip_row_2),
                             [row3] "+r"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
    #endif
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
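            /* sum_tmp holds the plain sum of the rhs row; scaling it by lhs_offset folds
             * the lhs zero-point contribution (plus the bias, when present) into all four
             * lanes at once. */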
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

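            /* Scatter one s8 result per lhs row; row_address_offset is the distance
             * between consecutive output rows of dst. */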
            const uint32x4_t scatter_offset = {
                0, (uint32_t)row_address_offset, (uint32_t)row_address_offset * 2, (uint32_t)row_address_offset * 3};
            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += 4 * row_address_offset - rhs_rows;
    }

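    /* Handle the remaining 1-3 lhs rows one at a time, buffering four requantized
     * results at a time so the stores can still use vector operations. */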
    for (; i_items < lhs_rows; i_items++)
    {
        int32_t acc[4];
        const int32_t *multipliers = dst_multipliers;
        const int32_t *shifts = dst_shifts;
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            const int8_t *lhs_vec = lhs;
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

    #if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
            }
    #else
            __ASM volatile(" .p2align 2                             \n"
                           "   wlstp.8         lr, %[cnt], 1f       \n"
                           "   mov             %[sum], 0            \n"
                           "   mov             %[out0], 0           \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "2:                                      \n"
                           "   vaddva.s8       %[sum], q0           \n"
                           "   vldrb.8         q1, [%[row0]], #16   \n"
                           "   vmladava.s8     %[out0], q0, q1      \n"
                           "   vldrb.8         q0, [%[col]], #16    \n"
                           "   letp            lr, 2b               \n"
                           "1:                                      \n"
                           : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "memory", "r14");
    #endif
            sum_tmp *= lhs_offset;
            sum_tmp += acc_n0;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            const int32_t index = i & 0x3;
            acc[index] = sum_tmp;

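            /* Once four accumulators are buffered, requantize and store them as one vector. */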
            if (index == 3)
            {
                int32x4_t res = vldrwq_s32(acc);
                res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                multipliers += 4;
                shifts += 4;
                res = vaddq_n_s32(res, dst_offset);
                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));
                vstrbq_s32(dst, res);
                dst += 4;
            }
        }
        lhs += lhs_cols_offset;
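        /* Requantize and store the final partial group (rhs_rows % 4) one scalar at a time. */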
        const int32_t tail_rows = rhs_rows & 0x3;
        for (int i = 0; i < tail_rows; i++)
        {
            int32_t acc_n0 = acc[i];
            acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
            acc_n0 += dst_offset;
            acc_n0 = MAX(acc_n0, activation_min);
            acc_n0 = MIN(acc_n0, activation_max);
            *dst++ = (int8_t)acc_n0;
        }
        dst += row_address_offset - rhs_rows;
    }

#elif defined(ARM_MATH_DSP)
    (void)row_address_offset;
    const int32_t rhs_off0 = rhs_cols - 4;
    const int32_t lhs_off0 = lhs_cols_offset - 4;

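    /* Compute a 2x2 output block per iteration: two rhs rows against two lhs rows, using
     * SXTB16/SMLAD so each instruction pair performs four 8-bit multiply-accumulates. */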
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

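        /* Precompute lhs_offset * (sum of each rhs row) so the zero-point correction and
         * the bias can seed the accumulators instead of being applied per column. */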
        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

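        /* Process lhs rows in pairs; a possible odd row is handled after this loop. */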
        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;

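            /* Inner loop unrolled to consume 16 columns per iteration: SXTB16 extracts the
             * even/odd 8-bit lanes of each 32-bit load into 16-bit pairs, and SMLAD
             * performs a dual 16-bit multiply-accumulate on them. */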
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

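            /* Residual loop consuming four columns per iteration. */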
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

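            // Left-over accumulations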
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
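            /* Same unrolled 16-column inner loop as above, for a single lhs row. */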
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

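    /* Compute the final rhs row separately when rhs_rows is odd. */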
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#else
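    /* Plain C fallback for targets without DSP or MVE extensions: same blocking scheme
     * as the DSP path, but with scalar multiply-accumulate loops. */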
    (void)row_address_offset;
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

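    /* As above: compute the final rhs row separately when rhs_rows is odd. */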
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */