/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s4
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed, and 4-bit rhs.
 *
 * $Date:        10 April 2024
 * $Revision:    V.1.1.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s4 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
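/*
 * Notes on the packed s4 format (a scalar sketch of the unpack used
 * throughout this file): each byte of packed_rhs carries two consecutive
 * elements of an rhs row, the earlier element in the low nibble and the
 * next one in the high nibble, so
 *
 *     lo = (int8_t)(b << 4) >> 4;   // sign-extend the low nibble
 *     hi = (int8_t)b >> 4;          // sign-extend the high nibble
 *
 * When rhs_cols is odd, row boundaries fall mid-byte and odd-indexed rows
 * start on a high nibble. The lhs offset is handled algebraically:
 *
 *     sum_j (lhs[j] + lhs_offset) * rhs[j]
 *         = sum_j lhs[j] * rhs[j] + lhs_offset * sum_j rhs[j]
 *
 * so the MVE kernels also accumulate sum_j rhs[j] (sum_tmp) and add
 * lhs_offset * sum_tmp once per output, while the DSP kernels instead fold
 * lhs_offset into each widened lhs halfword.
 */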
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
                                            const int8_t *packed_rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset)
{
#if defined(ARM_MATH_MVEI)
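    /* Four lhs rows are processed against one rhs row per inner iteration.
     * gather_offset loads each of 8 packed bytes into two adjacent vector
     * lanes so one 16-lane load yields 16 nibbles; lower_nibble_mask (0x5555)
     * predicates the even lanes, which hold the low nibbles. scatter_offset
     * spreads the four results over four consecutive dst rows. */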
    int i_items = 0;
    const int rhs_cols_offset = rhs_cols % 16;
    const int32_t blk_cnt = rhs_cols >> 4;
    const mve_pred16_t lower_nibble_mask = 21845; // 0101010101010101
    const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
    const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};

    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        int8_t const *col_base = packed_rhs;
        for (int i = 0; i < rhs_rows; i++)
        {

            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            int8_t const *ip_row_0 = lhs;
            int8_t const *ip_row_1 = lhs + lhs_cols_offset;
            int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
            int8_t const *ip_row_3 = lhs + (3 * lhs_cols_offset);
            int32_t sum_tmp = 0;

            mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);

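            /* With an odd rhs_cols, every odd-indexed rhs row starts in the
             * middle of a byte: take that first element from the high nibble
             * and shrink the remainder predicate by one lane. */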
            if ((rhs_cols & 0x1) & (i & 0x1))
            {
                rmdr_mask >>= 1;
                int32_t col = col_base[0] >> 4;
                sum_tmp = col;
                acc_n0 += ip_row_0[0] * col;
                acc_n1 += ip_row_1[0] * col;
                acc_n2 += ip_row_2[0] * col;
                acc_n3 += ip_row_3[0] * col;

                ++col_base;
                ++ip_row_0;
                ++ip_row_1;
                ++ip_row_2;
                ++ip_row_3;
            }

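            /* Main block: 8 packed bytes -> 16 sign-extended s4 values per
             * iteration. The masked left shift moves the low nibble to the
             * top of each even lane; the arithmetic right shift by 4 then
             * sign-extends both nibble positions in place. */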
            for (int j = blk_cnt; j > 0; --j)
            {
                int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                col_base += 8;

                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                ip_row_0 += 16;
                acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_1);
                ip_row_1 += 16;
                acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_2);
                ip_row_2 += 16;
                acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_3);
                ip_row_3 += 16;
                acc_n3 = vmladavaq_s8(acc_n3, col_vec, lhs_vec);
            }

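            /* Row tail: same unpack as above, with loads and MACs predicated
             * on the remaining rhs_cols % 16 elements. */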
            if (rmdr_mask)
            {
                int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                col_base += rhs_cols_offset >> 1;
                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
                acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
                acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_3, rmdr_mask);
                acc_n3 = vmladavaq_p_s8(acc_n3, col_vec, lhs_vec, rmdr_mask);
            }

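            /* Fold in lhs_offset * sum(rhs row) plus the bias, requantize,
             * then scatter one byte per output row at column i. */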
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += (3 * rhs_rows);
    }
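    /* 0-3 lhs rows remain. A 3-row tail keeps the vector kernel, predicating
     * the requantize and store on 3 lanes; 1 or 2 remaining rows fall through
     * to the one-row-at-a-time loop below. */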
    if (lhs_rows % 4 == 3)
    {
        int8_t const *col_base = packed_rhs;
        const mve_pred16_t requant_mask = vctp32q(3);
        for (int i = 0; i < rhs_rows; i++)
        {

            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;

            int8_t const *ip_row_0 = lhs;
            int8_t const *ip_row_1 = lhs + lhs_cols_offset;
            int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
            int32_t sum_tmp = 0;

            mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);

            if ((rhs_cols & 0x1) & (i & 0x1))
            {
                rmdr_mask >>= 1;
                int32_t col = col_base[0] >> 4;
                sum_tmp = col;
                acc_n0 += ip_row_0[0] * col;
                acc_n1 += ip_row_1[0] * col;
                acc_n2 += ip_row_2[0] * col;

                ++col_base;
                ++ip_row_0;
                ++ip_row_1;
                ++ip_row_2;
            }

            for (int j = blk_cnt; j > 0; --j)
            {
                int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                col_base += 8;

                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                ip_row_0 += 16;
                acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_1);
                ip_row_1 += 16;
                acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_2);
                ip_row_2 += 16;
                acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
            }

            if (rmdr_mask)
            {
                int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                col_base += rhs_cols_offset >> 1;
                col_vec = vrshlq_m_n_s8(col_vec, 4, (lower_nibble_mask & rmdr_mask));
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
                acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
                acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
            }

            int32x4_t res = {acc_n0, acc_n1, acc_n2, 0};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }

            res = vaddq_x_n_s32(res, sum_tmp, requant_mask);

            res = arm_requantize_mve_pred(res, dst_multipliers[i], dst_shifts[i], requant_mask);
            res = vaddq_x_n_s32(res, dst_offset, requant_mask);

            res = vmaxq_x_s32(res, vdupq_n_s32(activation_min), requant_mask);
            res = vminq_x_s32(res, vdupq_n_s32(activation_max), requant_mask);

            vstrbq_scatter_offset_p_s32(dst, scatter_offset, res, requant_mask);
            dst++;
        }
        lhs += 3 * lhs_cols_offset;
        dst += (2 * rhs_rows);
    }
    else
    {
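        /* One lhs row at a time: accumulate a scalar result per rhs row and
         * batch results in groups of four so requantization, offset, clamp
         * and store stay vectorized. */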
        for (; i_items < lhs_rows; i_items++)
        {
            int32_t acc[4];
            const int32_t *multipliers = dst_multipliers;
            const int32_t *shifts = dst_shifts;
            const int8_t *col_base = packed_rhs;
            int col_inc = rhs_cols_offset >> 1;

            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                const int8_t *ip_row_0 = lhs;
                int32_t sum_tmp = 0;
                mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);

                if ((rhs_cols & 0x1) & (i & 0x1))
                {
                    rmdr_mask >>= 1;
                    int32_t col = col_base[0] >> 4;
                    sum_tmp += col;
                    acc_n0 += ip_row_0[0] * col;

                    ++col_base;
                    ++ip_row_0;
                }

                for (int j = blk_cnt; j > 0; --j)
                {
                    int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                    col_base += 8;

                    col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                    col_vec = vshrq_n_s8(col_vec, 4);

                    sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                    int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                    ip_row_0 += 16;
                    acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
                }

                if (rmdr_mask)
                {
                    int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                    col_base += col_inc;
                    col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                    col_vec = vshrq_n_s8(col_vec, 4);

                    sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                    int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                    acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
                }

                sum_tmp *= lhs_offset;
                sum_tmp += acc_n0;
                if (bias)
                {
                    sum_tmp += bias[i];
                }
                const int32_t index = i & 0x3;
                acc[index] = sum_tmp;

                if (index == 3)
                {
                    int32x4_t res = vldrwq_s32(acc);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                    multipliers += 4;
                    shifts += 4;
                    res = vaddq_n_s32(res, dst_offset);
                    res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                    res = vminq_s32(res, vdupq_n_s32(activation_max));
                    vstrbq_s32(dst, res);
                    dst += 4;
                }
            }
            lhs += lhs_cols_offset;
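            /* Requantize and store the rhs_rows % 4 leftover results one at
             * a time. */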
            const int32_t tail_rows = rhs_rows & 0x3;
            for (int i = 0; i < tail_rows; i++)
            {
                int32_t acc_n0 = acc[i];
                acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
                acc_n0 += dst_offset;
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int8_t)acc_n0;
            }
        }
    }
#elif defined(ARM_MATH_DSP)
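    /* DSP path: read_and_pad_s4 expands four packed nibbles into two words
     * of paired int16 values (even elements, odd elements) so SMLAD performs
     * two MACs per instruction. rhs rows are walked in pairs (idx, idx + 2)
     * first and (idx + 1, idx + 3) second; with an odd rhs_cols the second
     * pair starts on a high nibble, which is why its partial results are
     * kept in the spillover accumulators. */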
    const int32_t lhs_cols_off1 = lhs_cols_offset - 4;
    const int16_t i16_lhs_offset = (int16_t)lhs_offset;
    const uint32_t ui32_lhs_offset_i16x2 = PKHBT(i16_lhs_offset, i16_lhs_offset, 16);
    const int32_t rhs_cols_int4 = rhs_cols >> 1;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
    {

        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;
        while (lhs_rows_idx)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res10 = 0;
            int32_t res11 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;
            int32_t spillover10 = 0;
            int32_t spillover11 = 0;

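            /* res* accumulate rhs rows idx and idx + 2 (first pass), while
             * spillover* seed rows idx + 1 and idx + 3, whose first product
             * is computed early when an odd rhs_cols splits a byte between
             * the two row pairs. */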
            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                res01 = bias[rhs_rows_idx + 2];
                res10 = bias[rhs_rows_idx];
                res11 = bias[rhs_rows_idx + 2];
                spillover00 = bias[rhs_rows_idx + 1];
                spillover01 = bias[rhs_rows_idx + 3];
                spillover10 = bias[rhs_rows_idx + 1];
                spillover11 = bias[rhs_rows_idx + 3];
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;

            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
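            /* Odd rhs_cols: the final low nibble holds the last element of
             * rows idx / idx + 2, and the shared high nibble is already the
             * first element of rows idx + 1 / idx + 3. The latter pairs with
             * the first lhs element, hence the rewind of lhs_ptr before the
             * spillover MACs. */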
            if (rhs_cols % 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;

                res10 += lhs_high * rhs_low0;
                res11 += lhs_high * rhs_low1;

                lhs_ptr -= rhs_cols - 1;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

                spillover00 += lhs_low * rhs_high0;
                spillover01 += lhs_low * rhs_high1;

                spillover10 += lhs_high * rhs_high0;
                spillover11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[2] = (int8_t)res11;
            dst_ptr -= rhs_rows;

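            /* Second pass over the same two lhs rows, now against rhs rows
             * idx + 1 and idx + 3; their results go to the odd dst columns
             * [1] and [3]. */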
            res00 = spillover00;
            res01 = spillover01;
            res10 = spillover10;
            res11 = spillover11;

            rhs_cols_idx = 0;

            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[1] = (int8_t)res10;
            dst_ptr[3] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                spillover00 = bias[rhs_rows_idx + 1];
                res01 = bias[rhs_rows_idx + 2];
                spillover01 = bias[rhs_rows_idx + 3];
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;

            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_ptr -= rhs_cols - 1;
                lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                spillover00 += lhs_high * rhs_high0;
                spillover01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;

            res00 = spillover00;
            res01 = spillover01;

            rhs_cols_idx = 0;

            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
        }

        packed_rhs += 2 * rhs_cols;
        dst += 4;
    }

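    /* With an odd rhs_cols, leftover row 1 starts on a high nibble; its
     * first element is carried between iterations in rhs_spilled_col. */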
1089     int8_t rhs_spilled_col = 0;
1090     const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
1091     // Left over rhs rows will be in the range 0 -> 3
1092     for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
1093     {
1094         const int8_t *lhs_ptr = &lhs[0];
1095         int8_t *dst_ptr = &dst[0];
1096 
1097         int32_t lhs_rows_idx = lhs_rows >> 1;
1098         while (lhs_rows_idx)
1099         {
1100             const int8_t *packed_rhs_ptr = &packed_rhs[0];
1101 
1102             int32_t res00 = 0;
1103             int32_t res10 = 0;
1104 
1105             if (bias)
1106             {
1107                 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1108                 res10 = bias[rhs_rows_finished + rhs_rows_idx];
1109             }
1110 
1111             // Since there can only be 3 rhs rows here we only need treat rhs_row_idx[1]
1112             // differently by dealing with the leftover column from rhs_row_idx[0]
1113             if (rhs_cols % 2 && rhs_rows_idx == 1)
1114             {
1115                 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1116                 res00 += lhs_low * rhs_spilled_col;
1117 
1118                 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1119                 res10 += lhs_low * rhs_spilled_col;
1120 
1121                 ++lhs_ptr;
1122             }
1123 
1124             int32_t rhs_cols_idx = 0;
1125 
1126             int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
1127 
1128             for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1129             {
1130                 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1131                 packed_rhs_ptr += 2;
1132 
1133                 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1134                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1135                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1136 
1137                 // 2 x MAC res00
1138                 res00 = SMLAD(rhs_low0, lhs_low, res00);
1139                 res00 = SMLAD(rhs_high0, lhs_high, res00);
1140 
1141                 // 2 x MAC res10
1142                 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1143                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1144                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1145 
1146                 res10 = SMLAD(rhs_low0, lhs_low, res10);
1147                 res10 = SMLAD(rhs_high0, lhs_high, res10);
1148 
1149                 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1150                 packed_rhs_ptr += 2;
1151 
1152                 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1153                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1154                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1155 
1156                 // 2 x MAC res00
1157                 res00 = SMLAD(rhs_low0, lhs_low, res00);
1158                 res00 = SMLAD(rhs_high0, lhs_high, res00);
1159 
1160                 // 2 x MAC res10
1161                 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1162                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1163                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1164 
1165                 res10 = SMLAD(rhs_low0, lhs_low, res10);
1166                 res10 = SMLAD(rhs_high0, lhs_high, res10);
1167 
1168                 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1169                 packed_rhs_ptr += 2;
1170 
1171                 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1172                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1173                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1174 
1175                 // 2 x MAC res00
1176                 res00 = SMLAD(rhs_low0, lhs_low, res00);
1177                 res00 = SMLAD(rhs_high0, lhs_high, res00);
1178 
1179                 // 2 x MAC res10
1180                 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1181                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1182                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1183 
1184                 res10 = SMLAD(rhs_low0, lhs_low, res10);
1185                 res10 = SMLAD(rhs_high0, lhs_high, res10);
1186 
1187                 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1188                 packed_rhs_ptr += 2;
1189 
1190                 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1191                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1192                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1193 
1194                 // 2 x MAC res00
1195                 res00 = SMLAD(rhs_low0, lhs_low, res00);
1196                 res00 = SMLAD(rhs_high0, lhs_high, res00);
1197 
1198                 // 2 x MAC res10
1199                 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1200                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1201                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1202 
1203                 res10 = SMLAD(rhs_low0, lhs_low, res10);
1204                 res10 = SMLAD(rhs_high0, lhs_high, res10);
1205             }
1206 
1207             for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1208             {
1209                 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1210                 packed_rhs_ptr += 2;
1211 
1212                 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1213                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1214                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1215 
1216                 // 2 x MAC res00
1217                 res00 = SMLAD(rhs_low0, lhs_low, res00);
1218                 res00 = SMLAD(rhs_high0, lhs_high, res00);
1219 
1220                 // 2 x MAC res10
1221                 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1222                 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1223                 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1224                 res10 = SMLAD(rhs_low0, lhs_low, res10);
1225                 res10 = SMLAD(rhs_high0, lhs_high, res10);
1226             }
1227 
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

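            // With an odd rhs_cols and an even leftover-row index, the final low nibble
            // of the current row sits in a byte shared with the next rhs row.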
            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_low0;

                ++lhs_ptr;
            }

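            // Rewind lhs to the start of the row pair, then step down two lhs rows.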
            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
            res10 = arm_nn_requantize(
                res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;
            res10 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr += rhs_rows;

            lhs_rows_idx--;
        }
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
            }

            // Since there can be at most 3 leftover rhs rows here, only rhs_rows_idx == 1
            // needs special treatment: it consumes the column spilled over from rhs_rows_idx == 0
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_spilled_col;

                ++lhs_ptr;
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;

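            // Unrolled by 16 columns: four unpack + dual-MAC rounds per iteration.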
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
            }

            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low0;

                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
        }
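        // If rhs_cols is odd, even-indexed leftover rows end mid-byte: save the high
        // nibble (the first column of the next row) and skip past the shared byte.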
        if (rhs_cols % 2 && !(rhs_rows_idx % 2))
        {
            rhs_spilled_col = packed_rhs[rhs_cols_int4] >> 4;
            packed_rhs += rhs_cols_int4 + 1;
        }
        else
        {
            rhs_spilled_col = 0;
            packed_rhs += rhs_cols_int4;
        }

        ++dst;
    }
#else
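    // Plain C implementation used when neither the MVE nor the DSP extension path
    // above is compiled.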

    const int32_t rhs_cols_int4 = rhs_cols >> 1;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
        {
            // To avoid packed values leaking into the next rhs row, the rhs rows are
            // evaluated in pairs: rhs[0] with rhs[2], then rhs[1] with rhs[3].
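            // Illustrative example (hypothetical values, rhs_cols == 3), written
            // as [high nibble | low nibble]:
            //   byte 0: [r0c1 | r0c0], byte 1: [r1c0 | r0c2], byte 2: [r1c2 | r1c1], ...
            // Row 1 starts in the high nibble of a byte shared with row 0, while row 2
            // starts byte-aligned again; reading a second byte at offset rhs_cols
            // (two rows ahead) therefore stays nibble-aligned with the first.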

            // Start processing rhs_row[0] and rhs_row[2]
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res10 = 0;
            int32_t res11 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;
            int32_t spillover10 = 0;
            int32_t spillover11 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                res01 = bias[rhs_rows_idx + 2];
                res10 = bias[rhs_rows_idx];
                res11 = bias[rhs_rows_idx + 2];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;

                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;

                res10 += lhs_high * rhs_low0;
                res11 += lhs_high * rhs_low1;

                lhs_ptr -= rhs_cols - 1;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

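                // The high nibbles of the straddling bytes are column 0 of rhs rows
                // rhs_rows_idx + 1 and + 3; hold their products for the second pass below.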
                spillover00 += lhs_low * rhs_high0;
                spillover01 += lhs_low * rhs_high1;

                spillover10 += lhs_high * rhs_high0;
                spillover11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[2] = (int8_t)res11;
            dst_ptr -= rhs_rows;

            // Start processing rhs_row[1] and rhs_row[3]
            res00 = spillover00;
            res01 = spillover01;
            res10 = spillover10;
            res11 = spillover11;
            if (bias)
            {
                res00 += bias[rhs_rows_idx + 1];
                res01 += bias[rhs_rows_idx + 3];
                res10 += bias[rhs_rows_idx + 1];
                res11 += bias[rhs_rows_idx + 3];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;

                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[1] = (int8_t)res10;
            dst_ptr[3] = (int8_t)res11;
            dst_ptr += rhs_rows;

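            // Move lhs on to the next pair of rows.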
            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;
        }

        // Left-over row
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;

            if (bias)
            {
                res00 += bias[rhs_rows_idx];
                res01 += bias[rhs_rows_idx + 2];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;

                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_ptr -= rhs_cols - 1;
                int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                spillover00 = lhs_high * rhs_high0;
                spillover01 = lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;

            res00 = spillover00;
            res01 = spillover01;

            if (bias)
            {
                res00 += bias[rhs_rows_idx + 1];
                res01 += bias[rhs_rows_idx + 3];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;

                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
        }

        packed_rhs += 2 * rhs_cols;
        dst += 4;
    }

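    // rhs_rows_finished counts the rows handled in groups of four above; it indexes
    // the bias and per-row quantization parameters for the remaining rows.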
    int32_t spillover00 = 0;
    const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
    // Left-over rhs rows are in the range 0 to 3
    for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];
            int32_t res00 = 0;
            int32_t res10 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
                res10 = bias[rhs_rows_finished + rhs_rows_idx];
            }
            // Since there can be at most 3 leftover rhs rows here, only rhs_rows_idx == 1
            // needs special treatment: it consumes the column spilled over from rhs_rows_idx == 0
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * spillover00;

                lhs_low = lhs_ptr[lhs_cols_offset] + lhs_offset;
                res10 += lhs_low * spillover00;

                ++lhs_ptr;
            }
            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high = packed_rhs_ptr[0] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low;
                res00 += lhs_high * rhs_high;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;

                res10 += lhs_low * rhs_low;
                res10 += lhs_high * rhs_high;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;

                res10 += lhs_low * rhs_low;

                ++lhs_ptr;
            }

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
            res10 = arm_nn_requantize(
                res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;
            res10 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr += rhs_rows;
        }
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
            }
            // Since there can be at most 3 leftover rhs rows here, only rhs_rows_idx == 1
            // needs special treatment: it consumes the column spilled over from rhs_rows_idx == 0
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * spillover00;

                ++lhs_ptr;
            }
            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high = packed_rhs_ptr[0] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

                res00 += lhs_low * rhs_low;
                res00 += lhs_high * rhs_high;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && (rhs_rows_idx != 1))
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_low;

                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
        }
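        // Capture the nibble spilled into the next leftover row, if any, and advance
        // packed_rhs past the columns consumed for this row.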
        if (rhs_cols % 2 && !(rhs_rows_idx % 2))
        {
            spillover00 = packed_rhs[rhs_cols_int4] >> 4;
            packed_rhs += rhs_cols_int4 + (rhs_cols & 0x1);
        }
        else
        {
            spillover00 = 0;
            packed_rhs += rhs_cols_int4;
        }

        ++dst;
    }

#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */