1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_nt_t_s4
22 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed, and 4-bit rhs.
23 *
24 * $Date: 27 May 2024
25 * $Revision: V.1.2.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup groupSupport
35 */
36
37 /**
38 * @addtogroup supportConvolution
39 * @{
40 */
41
42 /*
43 * s4 matrix multiplication with the right-hand-side matrix transposed
44 *
45 * Refer to the header file for details.
46 *
47 */
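/*
 * Illustrative usage, not part of the library sources: a minimal sketch of how this
 * support function could be called once the rhs weights are packed as int4 (two
 * values per byte). The buffer names and the choice of passing rhs_cols as
 * lhs_cols_offset (i.e. an unpadded lhs) are assumptions made for this example only;
 * the header file remains the authoritative parameter description.
 *
 *   arm_cmsis_nn_status status = arm_nn_mat_mult_nt_t_s4(lhs,             // s8, [lhs_rows x rhs_cols]
 *                                                        packed_rhs,      // s4 packed, [rhs_rows x rhs_cols]
 *                                                        bias,            // s32, [rhs_rows] or NULL
 *                                                        dst,             // s8, [lhs_rows x rhs_rows]
 *                                                        dst_multipliers, // per-channel requantization multipliers
 *                                                        dst_shifts,      // per-channel requantization shifts
 *                                                        lhs_rows,
 *                                                        rhs_rows,
 *                                                        rhs_cols,
 *                                                        lhs_offset,
 *                                                        dst_offset,
 *                                                        activation_min,
 *                                                        activation_max,
 *                                                        rhs_cols);       // lhs_cols_offset: assumes no row padding
 */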
48 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
49 const int8_t *packed_rhs,
50 const int32_t *bias,
51 int8_t *dst,
52 const int32_t *dst_multipliers,
53 const int32_t *dst_shifts,
54 const int32_t lhs_rows,
55 const int32_t rhs_rows,
56 const int32_t rhs_cols,
57 const int32_t lhs_offset,
58 const int32_t dst_offset,
59 const int32_t activation_min,
60 const int32_t activation_max,
61 const int32_t lhs_cols_offset)
62 {
63 #if defined(ARM_MATH_MVEI)
64 int i_items = 0;
65 const int rhs_cols_offset = rhs_cols % 16;
66 const mve_pred16_t lower_nibble_mask = 21845; // 0101010101010101
67 const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
68 const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};
69 const int I6_elements_spill = rhs_cols & 0x10;
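    /* Packed s4 handling: each rhs byte holds two 4-bit values. gather_offset loads every
     * packed byte into two adjacent lanes; a predicated left shift by 4 (lower_nibble_mask
     * selects the even lanes) followed by an arithmetic right shift by 4 leaves the
     * sign-extended low nibble in the even lanes and the high nibble in the odd lanes.
     * scatter_offset spreads one int32x4 of results over four consecutive dst rows. */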
70
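    /* Main MVE path: four lhs rows per pass, with 3-, 2- and 1-row tail loops below. */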
71 for (; i_items <= (lhs_rows - 4); i_items += 4)
72 {
73 const int32_t blk_cnt = rhs_cols >> 4;
74 int8_t const *col_base = packed_rhs;
75
76 for (int i = 0; i < rhs_rows; i++)
77 {
78
79 int32_t acc_n0 = 0;
80 int32_t acc_n1 = 0;
81 int32_t acc_n2 = 0;
82 int32_t acc_n3 = 0;
83
84 int8_t const *ip_row_0 = lhs;
85 int8_t const *ip_row_1 = lhs + lhs_cols_offset;
86 int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
87 int8_t const *ip_row_3 = lhs + (3 * lhs_cols_offset);
88 int32_t sum_tmp = 0;
89
90 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
91
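            /* With an odd rhs_cols every second rhs row starts in the high nibble of a byte
             * shared with the previous row: consume that first element separately and drop
             * one lane from the remainder predicate. */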
92 if ((rhs_cols & 0x1) & (i & 0x1))
93 {
94 rmdr_mask >>= 1;
95 int32_t col = col_base[0] >> 4;
96 sum_tmp = col;
97 acc_n0 += ip_row_0[0] * col;
98 acc_n1 += ip_row_1[0] * col;
99 acc_n2 += ip_row_2[0] * col;
100 acc_n3 += ip_row_3[0] * col;
101
102 ++col_base;
103 ++ip_row_0;
104 ++ip_row_1;
105 ++ip_row_2;
106 ++ip_row_3;
107 }
108
109 for (int j = blk_cnt; j > 0; --j)
110 {
111 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
112 col_base += 8;
113
114 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
115 col_vec = vshrq_n_s8(col_vec, 4);
116
117 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
118
119 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
120 ip_row_0 += 16;
121 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
122
123 lhs_vec = vldrbq_s8(ip_row_1);
124 ip_row_1 += 16;
125 acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);
126
127 lhs_vec = vldrbq_s8(ip_row_2);
128 ip_row_2 += 16;
129 acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
130
131 lhs_vec = vldrbq_s8(ip_row_3);
132 ip_row_3 += 16;
133 acc_n3 = vmladavaq_s8(acc_n3, col_vec, lhs_vec);
134 }
135
136 if (rmdr_mask)
137 {
138 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
139 col_base += rhs_cols_offset >> 1;
140 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
141 col_vec = vshrq_n_s8(col_vec, 4);
142
143 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
144
145 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
146 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
147
148 lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
149 acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);
150
151 lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
152 acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
153
154 lhs_vec = vldrbq_z_s8(ip_row_3, rmdr_mask);
155 acc_n3 = vmladavaq_p_s8(acc_n3, col_vec, lhs_vec, rmdr_mask);
156 }
157
158 int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
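            /* sum_tmp is the plain sum of this rhs row; scaling it by lhs_offset adds the input
             * offset contribution, since (lhs + lhs_offset) * rhs = lhs * rhs + lhs_offset * sum(rhs). */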
159 sum_tmp *= lhs_offset;
160 if (bias)
161 {
162 sum_tmp += bias[i];
163 }
164 res = vaddq_n_s32(res, sum_tmp);
165
166 res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
167 res = vaddq_n_s32(res, dst_offset);
168
169 res = vmaxq_s32(res, vdupq_n_s32(activation_min));
170 res = vminq_s32(res, vdupq_n_s32(activation_max));
171
172 vstrbq_scatter_offset_s32(dst, scatter_offset, res);
173 dst++;
174 }
175 lhs += 4 * lhs_cols_offset;
176 dst += (3 * rhs_rows);
177 }
178
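    /* Tail: three lhs rows left. Results are stored with a predicated scatter (vctp32q(3)). */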
179 for (; i_items <= (lhs_rows - 3); i_items += 3)
180 {
181 int8_t const *col_base = packed_rhs;
182 const mve_pred16_t requant_mask = vctp32q(3);
183 const int32_t blk_cnt = rhs_cols >> 4;
184
185 for (int i = 0; i < rhs_rows; i++)
186 {
187 int32_t acc_n0 = 0;
188 int32_t acc_n1 = 0;
189 int32_t acc_n2 = 0;
190
191 int8_t const *ip_row_0 = lhs;
192 int8_t const *ip_row_1 = lhs + lhs_cols_offset;
193 int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
194 int32_t sum_tmp = 0;
195
196 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
197
198 if ((rhs_cols & 0x1) & (i & 0x1))
199 {
200 rmdr_mask >>= 1;
201 int32_t col = col_base[0] >> 4;
202 sum_tmp = col;
203 acc_n0 += ip_row_0[0] * col;
204 acc_n1 += ip_row_1[0] * col;
205 acc_n2 += ip_row_2[0] * col;
206
207 ++col_base;
208 ++ip_row_0;
209 ++ip_row_1;
210 ++ip_row_2;
211 }
212
213 for (int j = blk_cnt; j > 0; --j)
214 {
215 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
216 col_base += 8;
217
218 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
219 col_vec = vshrq_n_s8(col_vec, 4);
220
221 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
222
223 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
224 ip_row_0 += 16;
225 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
226
227 lhs_vec = vldrbq_s8(ip_row_1);
228 ip_row_1 += 16;
229 acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);
230
231 lhs_vec = vldrbq_s8(ip_row_2);
232 ip_row_2 += 16;
233 acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
234 }
235
236 if (rmdr_mask)
237 {
238 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
239 col_base += rhs_cols_offset >> 1;
240 col_vec = vrshlq_m_n_s8(col_vec, 4, (lower_nibble_mask & rmdr_mask));
241 col_vec = vshrq_n_s8(col_vec, 4);
242
243 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
244
245 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
246 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
247
248 lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
249 acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);
250
251 lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
252 acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
253 }
254
255 int32x4_t res = {acc_n0, acc_n1, acc_n2, 0};
256 sum_tmp *= lhs_offset;
257 if (bias)
258 {
259 sum_tmp += bias[i];
260 }
261
262 res = vaddq_x_n_s32(res, sum_tmp, requant_mask);
263
264 res = arm_requantize_mve_pred(res, dst_multipliers[i], dst_shifts[i], requant_mask);
265 res = vaddq_x_n_s32(res, dst_offset, requant_mask);
266
267 res = vmaxq_x_s32(res, vdupq_n_s32(activation_min), requant_mask);
268 res = vminq_x_s32(res, vdupq_n_s32(activation_max), requant_mask);
269
270 vstrbq_scatter_offset_p_s32(dst, scatter_offset, res, requant_mask);
271 dst++;
272 }
273 lhs += 3 * lhs_cols_offset;
274 dst += (2 * rhs_rows);
275 }
276
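    /* Tail: two lhs rows left. 32 packed columns are consumed per inner-loop iteration. */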
277 for (; i_items <= (lhs_rows - 2); i_items += 2)
278 {
279 int8_t const *col_base = packed_rhs;
280 const mve_pred16_t requant_mask = vctp32q(2);
281 const int32_t blk_cnt = rhs_cols >> 5;
282
283 for (int i = 0; i < rhs_rows; i++)
284 {
285 int32_t acc_n0 = 0;
286 int32_t acc_n1 = 0;
287
288 int8_t const *ip_row_0 = lhs;
289 int8_t const *ip_row_1 = lhs + lhs_cols_offset;
290
291 int32_t sum_tmp = 0;
292
293 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
294
295 if ((rhs_cols & 0x1) & (i & 0x1))
296 {
297 rmdr_mask >>= 1;
298 int32_t col = col_base[0] >> 4;
299 sum_tmp = col;
300 acc_n0 += ip_row_0[0] * col;
301 acc_n1 += ip_row_1[0] * col;
302
303 ++col_base;
304 ++ip_row_0;
305 ++ip_row_1;
306 }
307
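            /* 32-column block: 16 packed bytes are split into low/high nibbles and vld2q
             * de-interleaves the lhs so that even elements multiply the low nibbles and odd
             * elements the high nibbles. */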
308 for (int j = blk_cnt; j > 0; --j)
309 {
310 const int8x16_t col_vec = vldrbq_s8(col_base);
311
312 int8x16_t ker_low = vrshlq_n_s8(col_vec, 4);
313 ker_low = vshrq_n_s8(ker_low, 4);
314 int8x16_t ker_high = vshrq_n_s8(col_vec, 4);
315
316 sum_tmp = vaddvaq_s8(sum_tmp, ker_low);
317 sum_tmp = vaddvaq_s8(sum_tmp, ker_high);
318
319 int8x16x2_t lhs_x2 = vld2q_s8(ip_row_0);
320
321 acc_n0 = vmladavaq_s8(acc_n0, ker_low, lhs_x2.val[0]);
322 acc_n0 = vmladavaq_s8(acc_n0, ker_high, lhs_x2.val[1]);
323
324 lhs_x2 = vld2q_s8(ip_row_1);
325 acc_n1 = vmladavaq_s8(acc_n1, ker_low, lhs_x2.val[0]);
326 acc_n1 = vmladavaq_s8(acc_n1, ker_high, lhs_x2.val[1]);
327
328 ip_row_0 += 32;
329 ip_row_1 += 32;
330 col_base += 16;
331 }
332
333 if (I6_elements_spill)
334 {
335 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
336 col_base += 8;
337
338 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
339 col_vec = vshrq_n_s8(col_vec, 4);
340
341 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
342
343 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
344 ip_row_0 += 16;
345 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
346
347 lhs_vec = vldrbq_s8(ip_row_1);
348 ip_row_1 += 16;
349 acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);
350 }
351
352 if (rmdr_mask)
353 {
354 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
355 col_base += rhs_cols_offset >> 1;
356 col_vec = vrshlq_m_n_s8(col_vec, 4, (lower_nibble_mask & rmdr_mask));
357 col_vec = vshrq_n_s8(col_vec, 4);
358
359 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
360
361 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
362 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
363
364 lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
365 acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);
366 }
367
368 int32x4_t res = {acc_n0, acc_n1, 0, 0};
369 sum_tmp *= lhs_offset;
370 if (bias)
371 {
372 sum_tmp += bias[i];
373 }
374
375 res = vaddq_x_n_s32(res, sum_tmp, requant_mask);
376
377 res = arm_requantize_mve_pred(res, dst_multipliers[i], dst_shifts[i], requant_mask);
378 res = vaddq_x_n_s32(res, dst_offset, requant_mask);
379
380 res = vmaxq_x_s32(res, vdupq_n_s32(activation_min), requant_mask);
381 res = vminq_x_s32(res, vdupq_n_s32(activation_max), requant_mask);
382
383 vstrbq_scatter_offset_p_s32(dst, scatter_offset, res, requant_mask);
384 dst++;
385 }
386 lhs += 2 * lhs_cols_offset;
387 dst += (2 * rhs_rows);
388 }
389
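    /* Tail: a single lhs row. Accumulators are buffered four at a time so requantization can
     * still be done as one int32x4 vector; rhs_rows % 4 leftovers are handled scalarly. */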
390 for (; i_items < lhs_rows; i_items++)
391 {
392 int32_t acc[4];
393 const int32_t *multipliers = dst_multipliers;
394 const int32_t *shifts = dst_shifts;
395 const int8_t *col_base = packed_rhs;
396 const int32_t blk_cnt = rhs_cols >> 5;
397 int col_inc = rhs_cols_offset >> 1;
398
399 for (int i = 0; i < rhs_rows; i++)
400 {
401 int32_t acc_n0 = 0;
402 const int8_t *ip_row_0 = lhs;
403 int32_t sum_tmp = 0;
404 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
405
406 if ((rhs_cols & 0x1) & (i & 0x1))
407 {
408 rmdr_mask >>= 1;
409 int32_t col = col_base[0] >> 4;
410 sum_tmp += col;
411 acc_n0 += ip_row_0[0] * col;
412
413 ++col_base;
414 ++ip_row_0;
415 }
416
417 for (int j = blk_cnt; j > 0; --j)
418 {
419 const int8x16_t col_vec = vldrbq_s8(col_base);
420
421 int8x16_t ker_low = vrshlq_n_s8(col_vec, 4);
422 ker_low = vshrq_n_s8(ker_low, 4);
423 int8x16_t ker_high = vshrq_n_s8(col_vec, 4);
424
425 sum_tmp = vaddvaq_s8(sum_tmp, ker_low);
426 sum_tmp = vaddvaq_s8(sum_tmp, ker_high);
427
428 int8x16x2_t lhs_x2 = vld2q_s8(ip_row_0);
429
430 acc_n0 = vmladavaq_s8(acc_n0, ker_low, lhs_x2.val[0]);
431 acc_n0 = vmladavaq_s8(acc_n0, ker_high, lhs_x2.val[1]);
432
433 ip_row_0 += 32;
434 col_base += 16;
435 }
436
437 if (I6_elements_spill)
438 {
439 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
440 col_base += 8;
441
442 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
443 col_vec = vshrq_n_s8(col_vec, 4);
444
445 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
446
447 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
448 ip_row_0 += 16;
449 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
450 }
451
452 if (rmdr_mask)
453 {
454 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
455 col_base += col_inc;
456 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
457 col_vec = vshrq_n_s8(col_vec, 4);
458
459 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
460
461 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
462 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
463 }
464
465 sum_tmp *= lhs_offset;
466 sum_tmp += acc_n0;
467 if (bias)
468 {
469 sum_tmp += bias[i];
470 }
471 const int32_t index = i & 0x3;
472 acc[index] = sum_tmp;
473
474 if (index == 3)
475 {
476 int32x4_t res = vldrwq_s32(acc);
477 res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
478 multipliers += 4;
479 shifts += 4;
480 res = vaddq_n_s32(res, dst_offset);
481 res = vmaxq_s32(res, vdupq_n_s32(activation_min));
482 res = vminq_s32(res, vdupq_n_s32(activation_max));
483 vstrbq_s32(dst, res);
484 dst += 4;
485 }
486 }
487 lhs += lhs_cols_offset;
488 const int32_t tail_rows = rhs_rows & 0x3;
489 for (int i = 0; i < tail_rows; i++)
490 {
491 int32_t acc_n0 = acc[i];
492 acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
493 acc_n0 += dst_offset;
494 acc_n0 = MAX(acc_n0, activation_min);
495 acc_n0 = MIN(acc_n0, activation_max);
496 *dst++ = (int8_t)acc_n0;
497 }
498 }
499
500 #elif defined(ARM_MATH_DSP)
501 const int32_t lhs_cols_off1 = lhs_cols_offset - 4;
502 const int16_t i16_lhs_offset = (int16_t)lhs_offset;
503 const uint32_t ui32_lhs_offset_i16x2 = PKHBT(i16_lhs_offset, i16_lhs_offset, 16);
504 const int32_t rhs_cols_int4 = rhs_cols >> 1;
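    /* DSP (non-MVE) path: read_and_pad_s4 expands four packed s4 values into two words of
     * two sign-extended 16-bit values each so that SMLAD performs two MACs per instruction.
     * The lhs offset is folded in while widening the s8 inputs via SXTAB16 against the
     * packed lhs_offset constant. */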
505
506 for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
507 {
508
509 const int8_t *lhs_ptr = &lhs[0];
510 int8_t *dst_ptr = &dst[0];
511
512 int32_t lhs_rows_idx = lhs_rows >> 1;
513 while (lhs_rows_idx)
514 {
515 const int8_t *packed_rhs_ptr = &packed_rhs[0];
516
517 int32_t res00 = 0;
518 int32_t res01 = 0;
519 int32_t res10 = 0;
520 int32_t res11 = 0;
521
522 int32_t spillover00 = 0;
523 int32_t spillover01 = 0;
524 int32_t spillover10 = 0;
525 int32_t spillover11 = 0;
526
527 if (bias)
528 {
529 res00 = bias[rhs_rows_idx];
530 res01 = bias[rhs_rows_idx + 2];
531 res10 = bias[rhs_rows_idx];
532 res11 = bias[rhs_rows_idx + 2];
533 spillover00 = bias[rhs_rows_idx + 1];
534 spillover01 = bias[rhs_rows_idx + 3];
535 spillover10 = bias[rhs_rows_idx + 1];
536 spillover11 = bias[rhs_rows_idx + 3];
537 }
538
539 int32_t rhs_cols_idx = 0;
540
541 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
542
543 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
544 {
545 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
546 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
547 packed_rhs_ptr += 2;
548
549 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
550 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
551 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
552
553 // 4 x MAC res00, res01
554 res00 = SMLAD(rhs_low0, lhs_low, res00);
555 res00 = SMLAD(rhs_high0, lhs_high, res00);
556 res01 = SMLAD(rhs_low1, lhs_low, res01);
557 res01 = SMLAD(rhs_high1, lhs_high, res01);
558
559 // 4 x MAC res10, res11
560 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
561 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
562 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
563
564 res10 = SMLAD(rhs_low0, lhs_low, res10);
565 res11 = SMLAD(rhs_low1, lhs_low, res11);
566 res10 = SMLAD(rhs_high0, lhs_high, res10);
567 res11 = SMLAD(rhs_high1, lhs_high, res11);
568
569 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
570 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
571 packed_rhs_ptr += 2;
572
573 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
574 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
575 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
576
577 // 4 x MAC res00, res01
578 res00 = SMLAD(rhs_low0, lhs_low, res00);
579 res00 = SMLAD(rhs_high0, lhs_high, res00);
580 res01 = SMLAD(rhs_low1, lhs_low, res01);
581 res01 = SMLAD(rhs_high1, lhs_high, res01);
582
583 // 4 x MAC res10, res11
584 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
585 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
586 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
587
588 res10 = SMLAD(rhs_low0, lhs_low, res10);
589 res11 = SMLAD(rhs_low1, lhs_low, res11);
590 res10 = SMLAD(rhs_high0, lhs_high, res10);
591 res11 = SMLAD(rhs_high1, lhs_high, res11);
592
593 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
594 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
595 packed_rhs_ptr += 2;
596
597 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
598 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
599 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
600
601 // 4 x MAC res00, res01
602 res00 = SMLAD(rhs_low0, lhs_low, res00);
603 res00 = SMLAD(rhs_high0, lhs_high, res00);
604 res01 = SMLAD(rhs_low1, lhs_low, res01);
605 res01 = SMLAD(rhs_high1, lhs_high, res01);
606
607 // 4 x MAC res10, res11
608 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
609 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
610 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
611
612 res10 = SMLAD(rhs_low0, lhs_low, res10);
613 res11 = SMLAD(rhs_low1, lhs_low, res11);
614 res10 = SMLAD(rhs_high0, lhs_high, res10);
615 res11 = SMLAD(rhs_high1, lhs_high, res11);
616
617 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
618 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
619 packed_rhs_ptr += 2;
620
621 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
622 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
623 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
624
625 // 4 x MAC res00, res01
626 res00 = SMLAD(rhs_low0, lhs_low, res00);
627 res00 = SMLAD(rhs_high0, lhs_high, res00);
628 res01 = SMLAD(rhs_low1, lhs_low, res01);
629 res01 = SMLAD(rhs_high1, lhs_high, res01);
630
631 // 4 x MAC res10, res11
632 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
633 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
634 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
635
636 res10 = SMLAD(rhs_low0, lhs_low, res10);
637 res11 = SMLAD(rhs_low1, lhs_low, res11);
638 res10 = SMLAD(rhs_high0, lhs_high, res10);
639 res11 = SMLAD(rhs_high1, lhs_high, res11);
640 }
641 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
642 {
643 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
644 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
645 packed_rhs_ptr += 2;
646
647 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
648 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
649 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
650
651 // 4 x MAC res00, res01
652 res00 = SMLAD(rhs_low0, lhs_low, res00);
653 res00 = SMLAD(rhs_high0, lhs_high, res00);
654 res01 = SMLAD(rhs_low1, lhs_low, res01);
655 res01 = SMLAD(rhs_high1, lhs_high, res01);
656
657 // 4 x MAC res10, res11
658 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
659 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
660 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
661 res10 = SMLAD(rhs_low0, lhs_low, res10);
662 res11 = SMLAD(rhs_low1, lhs_low, res11);
663 res10 = SMLAD(rhs_high0, lhs_high, res10);
664 res11 = SMLAD(rhs_high1, lhs_high, res11);
665 }
666 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
667 {
668 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
669 rhs_high0 = packed_rhs_ptr[0] >> 4;
670
671 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
672 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
673
674 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
675 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
676
677 res00 += lhs_low * rhs_low0;
678 res00 += lhs_high * rhs_high0;
679 res01 += lhs_low * rhs_low1;
680 res01 += lhs_high * rhs_high1;
681
682 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
683 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
684 res10 += lhs_low * rhs_low0;
685 res10 += lhs_high * rhs_high0;
686 res11 += lhs_low * rhs_low1;
687 res11 += lhs_high * rhs_high1;
688
689 ++packed_rhs_ptr;
690 lhs_ptr += 2;
691 }
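            /* Odd rhs_cols: the last packed byte holds the low nibble of the current rhs row and
             * the high nibble of the following row. The high-nibble products go into the spillover
             * accumulators, which seed rows rhs_rows_idx + 1 and + 3 processed next. */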
692 if (rhs_cols % 2)
693 {
694 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
695 rhs_high0 = packed_rhs_ptr[0] >> 4;
696 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
697 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
698
699 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
700 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
701
702 res00 += lhs_low * rhs_low0;
703 res01 += lhs_low * rhs_low1;
704
705 res10 += lhs_high * rhs_low0;
706 res11 += lhs_high * rhs_low1;
707
708 lhs_ptr -= rhs_cols - 1;
709
710 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
711 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
712
713 spillover00 += lhs_low * rhs_high0;
714 spillover01 += lhs_low * rhs_high1;
715
716 spillover10 += lhs_high * rhs_high0;
717 spillover11 += lhs_high * rhs_high1;
718
719 ++packed_rhs_ptr;
720 ++lhs_ptr;
721 }
722 else
723 {
724 lhs_ptr -= rhs_cols;
725 }
726
727 // Quantize down
728 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
729 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
730 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
731 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
732
733 // Add offset
734 res00 += dst_offset;
735 res01 += dst_offset;
736 res10 += dst_offset;
737 res11 += dst_offset;
738
739 // Clamp the result
740 res00 = MAX(res00, activation_min);
741 res00 = MIN(res00, activation_max);
742 res01 = MAX(res01, activation_min);
743 res01 = MIN(res01, activation_max);
744 res10 = MAX(res10, activation_min);
745 res10 = MIN(res10, activation_max);
746 res11 = MAX(res11, activation_min);
747 res11 = MIN(res11, activation_max);
748
749 dst_ptr[0] = (int8_t)res00;
750 dst_ptr[2] = (int8_t)res01;
751 dst_ptr += rhs_rows;
752 dst_ptr[0] = (int8_t)res10;
753 dst_ptr[2] = (int8_t)res11;
754 dst_ptr -= rhs_rows;
755
756 res00 = spillover00;
757 res01 = spillover01;
758 res10 = spillover10;
759 res11 = spillover11;
760
761 rhs_cols_idx = 0;
762
763 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
764 {
765 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
766 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
767 packed_rhs_ptr += 2;
768
769 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
770 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
771 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
772
773 // 4 x MAC res00, res01
774 res00 = SMLAD(rhs_low0, lhs_low, res00);
775 res00 = SMLAD(rhs_high0, lhs_high, res00);
776 res01 = SMLAD(rhs_low1, lhs_low, res01);
777 res01 = SMLAD(rhs_high1, lhs_high, res01);
778
779 // 4 x MAC res10, res11
780 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
781 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
782 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
783
784 res10 = SMLAD(rhs_low0, lhs_low, res10);
785 res11 = SMLAD(rhs_low1, lhs_low, res11);
786 res10 = SMLAD(rhs_high0, lhs_high, res10);
787 res11 = SMLAD(rhs_high1, lhs_high, res11);
788
789 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
790 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
791 packed_rhs_ptr += 2;
792
793 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
794 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
795 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
796
797 // 4 x MAC res00, res01
798 res00 = SMLAD(rhs_low0, lhs_low, res00);
799 res00 = SMLAD(rhs_high0, lhs_high, res00);
800 res01 = SMLAD(rhs_low1, lhs_low, res01);
801 res01 = SMLAD(rhs_high1, lhs_high, res01);
802
803 // 4 x MAC res10, res11
804 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
805 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
806 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
807
808 res10 = SMLAD(rhs_low0, lhs_low, res10);
809 res11 = SMLAD(rhs_low1, lhs_low, res11);
810 res10 = SMLAD(rhs_high0, lhs_high, res10);
811 res11 = SMLAD(rhs_high1, lhs_high, res11);
812
813 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
814 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
815 packed_rhs_ptr += 2;
816
817 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
818 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
819 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
820
821 // 4 x MAC res00, res01
822 res00 = SMLAD(rhs_low0, lhs_low, res00);
823 res00 = SMLAD(rhs_high0, lhs_high, res00);
824 res01 = SMLAD(rhs_low1, lhs_low, res01);
825 res01 = SMLAD(rhs_high1, lhs_high, res01);
826
827 // 4 x MAC res10, res11
828 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
829 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
830 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
831
832 res10 = SMLAD(rhs_low0, lhs_low, res10);
833 res11 = SMLAD(rhs_low1, lhs_low, res11);
834 res10 = SMLAD(rhs_high0, lhs_high, res10);
835 res11 = SMLAD(rhs_high1, lhs_high, res11);
836
837 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
838 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
839 packed_rhs_ptr += 2;
840
841 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
842 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
843 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
844
845 // 4 x MAC res00, res01
846 res00 = SMLAD(rhs_low0, lhs_low, res00);
847 res00 = SMLAD(rhs_high0, lhs_high, res00);
848 res01 = SMLAD(rhs_low1, lhs_low, res01);
849 res01 = SMLAD(rhs_high1, lhs_high, res01);
850
851 // 4 x MAC res10, res11
852 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
853 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
854 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
855
856 res10 = SMLAD(rhs_low0, lhs_low, res10);
857 res11 = SMLAD(rhs_low1, lhs_low, res11);
858 res10 = SMLAD(rhs_high0, lhs_high, res10);
859 res11 = SMLAD(rhs_high1, lhs_high, res11);
860 }
861 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
862 {
863 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
864 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
865 packed_rhs_ptr += 2;
866
867 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
868 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
869 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
870
871 // 4 x MAC res00, res01
872 res00 = SMLAD(rhs_low0, lhs_low, res00);
873 res00 = SMLAD(rhs_high0, lhs_high, res00);
874 res01 = SMLAD(rhs_low1, lhs_low, res01);
875 res01 = SMLAD(rhs_high1, lhs_high, res01);
876
877 // 4 x MAC res10, res11
878 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
879 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
880 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
881 res10 = SMLAD(rhs_low0, lhs_low, res10);
882 res11 = SMLAD(rhs_low1, lhs_low, res11);
883 res10 = SMLAD(rhs_high0, lhs_high, res10);
884 res11 = SMLAD(rhs_high1, lhs_high, res11);
885 }
886 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
887 {
888 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
889 rhs_high0 = packed_rhs_ptr[0] >> 4;
890
891 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
892 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
893
894 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
895 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
896
897 res00 += lhs_low * rhs_low0;
898 res00 += lhs_high * rhs_high0;
899 res01 += lhs_low * rhs_low1;
900 res01 += lhs_high * rhs_high1;
901
902 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
903 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
904 res10 += lhs_low * rhs_low0;
905 res10 += lhs_high * rhs_high0;
906 res11 += lhs_low * rhs_low1;
907 res11 += lhs_high * rhs_high1;
908
909 ++packed_rhs_ptr;
910 lhs_ptr += 2;
911 }
912
913 // Quantize down
914 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
915 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
916 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
917 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
918
919 // Add offset
920 res00 += dst_offset;
921 res01 += dst_offset;
922 res10 += dst_offset;
923 res11 += dst_offset;
924
925 // Clamp the result
926 res00 = MAX(res00, activation_min);
927 res00 = MIN(res00, activation_max);
928 res01 = MAX(res01, activation_min);
929 res01 = MIN(res01, activation_max);
930 res10 = MAX(res10, activation_min);
931 res10 = MIN(res10, activation_max);
932 res11 = MAX(res11, activation_min);
933 res11 = MIN(res11, activation_max);
934
935 dst_ptr[1] = (int8_t)res00;
936 dst_ptr[3] = (int8_t)res01;
937 dst_ptr += rhs_rows;
938 dst_ptr[1] = (int8_t)res10;
939 dst_ptr[3] = (int8_t)res11;
940 dst_ptr += rhs_rows;
941
942 lhs_ptr -= rhs_cols;
943 lhs_ptr += 2 * lhs_cols_offset;
944
945 lhs_rows_idx--;
946 }
947
948 // Left-over rows
949 if (lhs_rows % 2)
950 {
951 const int8_t *packed_rhs_ptr = &packed_rhs[0];
952
953 int32_t res00 = 0;
954 int32_t res01 = 0;
955
956 int32_t spillover00 = 0;
957 int32_t spillover01 = 0;
958
959 if (bias)
960 {
961 res00 = bias[rhs_rows_idx];
962 spillover00 = bias[rhs_rows_idx + 1];
963 res01 = bias[rhs_rows_idx + 2];
964 spillover01 = bias[rhs_rows_idx + 3];
965 }
966
967 int32_t rhs_cols_idx = 0;
968
969 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
970
971 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
972 {
973 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
974 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
975 packed_rhs_ptr += 2;
976
977 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
978 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
979 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
980
981 // 4 x MAC res00, res01
982 res00 = SMLAD(rhs_low0, lhs_low, res00);
983 res00 = SMLAD(rhs_high0, lhs_high, res00);
984 res01 = SMLAD(rhs_low1, lhs_low, res01);
985 res01 = SMLAD(rhs_high1, lhs_high, res01);
986
987 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
988 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
989 packed_rhs_ptr += 2;
990
991 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
992 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
993 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
994
995 // 4 x MAC res00, res01
996 res00 = SMLAD(rhs_low0, lhs_low, res00);
997 res00 = SMLAD(rhs_high0, lhs_high, res00);
998 res01 = SMLAD(rhs_low1, lhs_low, res01);
999 res01 = SMLAD(rhs_high1, lhs_high, res01);
1000
1001 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1002 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1003 packed_rhs_ptr += 2;
1004
1005 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1006 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1007 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1008
1009 // 4 x MAC res00, res01
1010 res00 = SMLAD(rhs_low0, lhs_low, res00);
1011 res00 = SMLAD(rhs_high0, lhs_high, res00);
1012 res01 = SMLAD(rhs_low1, lhs_low, res01);
1013 res01 = SMLAD(rhs_high1, lhs_high, res01);
1014
1015 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1016 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1017 packed_rhs_ptr += 2;
1018
1019 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1020 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1021 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1022
1023 // 4 x MAC res00, res01
1024 res00 = SMLAD(rhs_low0, lhs_low, res00);
1025 res00 = SMLAD(rhs_high0, lhs_high, res00);
1026 res01 = SMLAD(rhs_low1, lhs_low, res01);
1027 res01 = SMLAD(rhs_high1, lhs_high, res01);
1028 }
1029 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1030 {
1031 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1032 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1033 packed_rhs_ptr += 2;
1034
1035 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1036 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1037 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1038
1039 // 4 x MAC res00, res01
1040 res00 = SMLAD(rhs_low0, lhs_low, res00);
1041 res00 = SMLAD(rhs_high0, lhs_high, res00);
1042 res01 = SMLAD(rhs_low1, lhs_low, res01);
1043 res01 = SMLAD(rhs_high1, lhs_high, res01);
1044 }
1045 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1046 {
1047 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1048 rhs_high0 = packed_rhs_ptr[0] >> 4;
1049
1050 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1051 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1052
1053 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1054 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1055
1056 res00 += lhs_low * rhs_low0;
1057 res00 += lhs_high * rhs_high0;
1058 res01 += lhs_low * rhs_low1;
1059 res01 += lhs_high * rhs_high1;
1060
1061 ++packed_rhs_ptr;
1062 lhs_ptr += 2;
1063 }
1064 if (rhs_cols % 2)
1065 {
1066 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1067 rhs_high0 = packed_rhs_ptr[0] >> 4;
1068 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1069 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1070
1071 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1072 lhs_ptr -= rhs_cols - 1;
1073 lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
1074
1075 res00 += lhs_low * rhs_low0;
1076 res01 += lhs_low * rhs_low1;
1077 spillover00 += lhs_high * rhs_high0;
1078 spillover01 += lhs_high * rhs_high1;
1079
1080 ++packed_rhs_ptr;
1081 ++lhs_ptr;
1082 }
1083 else
1084 {
1085 lhs_ptr -= rhs_cols;
1086 }
1087
1088 // Quantize down
1089 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1090 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1091
1092 // Add offset
1093 res00 += dst_offset;
1094 res01 += dst_offset;
1095
1096 // Clamp the result
1097 res00 = MAX(res00, activation_min);
1098 res00 = MIN(res00, activation_max);
1099 res01 = MAX(res01, activation_min);
1100 res01 = MIN(res01, activation_max);
1101
1102 dst_ptr[0] = (int8_t)res00;
1103 dst_ptr[2] = (int8_t)res01;
1104
1105 res00 = spillover00;
1106 res01 = spillover01;
1107
1108 rhs_cols_idx = 0;
1109
1110 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1111 {
1112 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1113 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1114 packed_rhs_ptr += 2;
1115
1116 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1117 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1118 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1119
1120 // 4 x MAC res00, res01
1121 res00 = SMLAD(rhs_low0, lhs_low, res00);
1122 res00 = SMLAD(rhs_high0, lhs_high, res00);
1123 res01 = SMLAD(rhs_low1, lhs_low, res01);
1124 res01 = SMLAD(rhs_high1, lhs_high, res01);
1125
1126 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1127 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1128 packed_rhs_ptr += 2;
1129
1130 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1131 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1132 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1133
1134 // 4 x MAC res00, res01
1135 res00 = SMLAD(rhs_low0, lhs_low, res00);
1136 res00 = SMLAD(rhs_high0, lhs_high, res00);
1137 res01 = SMLAD(rhs_low1, lhs_low, res01);
1138 res01 = SMLAD(rhs_high1, lhs_high, res01);
1139
1140 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1141 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1142 packed_rhs_ptr += 2;
1143
1144 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1145 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1146 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1147
1148 // 4 x MAC res00, res01
1149 res00 = SMLAD(rhs_low0, lhs_low, res00);
1150 res00 = SMLAD(rhs_high0, lhs_high, res00);
1151 res01 = SMLAD(rhs_low1, lhs_low, res01);
1152 res01 = SMLAD(rhs_high1, lhs_high, res01);
1153
1154 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1155 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1156 packed_rhs_ptr += 2;
1157
1158 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1159 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1160 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1161
1162 // 4 x MAC res00, res01
1163 res00 = SMLAD(rhs_low0, lhs_low, res00);
1164 res00 = SMLAD(rhs_high0, lhs_high, res00);
1165 res01 = SMLAD(rhs_low1, lhs_low, res01);
1166 res01 = SMLAD(rhs_high1, lhs_high, res01);
1167 }
1168 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1169 {
1170 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1171 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1172 packed_rhs_ptr += 2;
1173
1174 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1175 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1176 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1177
1178 // 4 x MAC res00, res01
1179 res00 = SMLAD(rhs_low0, lhs_low, res00);
1180 res00 = SMLAD(rhs_high0, lhs_high, res00);
1181 res01 = SMLAD(rhs_low1, lhs_low, res01);
1182 res01 = SMLAD(rhs_high1, lhs_high, res01);
1183 }
1184 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1185 {
1186 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1187 rhs_high0 = packed_rhs_ptr[0] >> 4;
1188
1189 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1190 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1191
1192 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1193 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1194
1195 res00 += lhs_low * rhs_low0;
1196 res00 += lhs_high * rhs_high0;
1197 res01 += lhs_low * rhs_low1;
1198 res01 += lhs_high * rhs_high1;
1199
1200 ++packed_rhs_ptr;
1201 lhs_ptr += 2;
1202 }
1203
1204 // Quantize down
1205 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1206 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1207
1208 // Add offset
1209 res00 += dst_offset;
1210 res01 += dst_offset;
1211
1212 // Clamp the result
1213 res00 = MAX(res00, activation_min);
1214 res00 = MIN(res00, activation_max);
1215 res01 = MAX(res01, activation_min);
1216 res01 = MIN(res01, activation_max);
1217
1218 dst_ptr[1] = (int8_t)res00;
1219 dst_ptr[3] = (int8_t)res01;
1220 }
1221
1222 packed_rhs += 2 * rhs_cols;
1223 dst += 4;
1224 }
1225
1226 int8_t rhs_spilled_col = 0;
1227 const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
1228 // Left over rhs rows will be in the range 0 -> 3
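    /* rhs_spilled_col carries the high nibble left over from the previous rhs row when rhs_cols
     * is odd; it is consumed as the first column of the row that follows. */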
1229 for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
1230 {
1231 const int8_t *lhs_ptr = &lhs[0];
1232 int8_t *dst_ptr = &dst[0];
1233
1234 int32_t lhs_rows_idx = lhs_rows >> 1;
1235 while (lhs_rows_idx)
1236 {
1237 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1238
1239 int32_t res00 = 0;
1240 int32_t res10 = 0;
1241
1242 if (bias)
1243 {
1244 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1245 res10 = bias[rhs_rows_finished + rhs_rows_idx];
1246 }
1247
1248            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1249            // differently by dealing with the leftover column from rhs_row_idx[0]
1250 if (rhs_cols % 2 && rhs_rows_idx == 1)
1251 {
1252 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1253 res00 += lhs_low * rhs_spilled_col;
1254
1255 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1256 res10 += lhs_low * rhs_spilled_col;
1257
1258 ++lhs_ptr;
1259 }
1260
1261 int32_t rhs_cols_idx = 0;
1262
1263 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
1264
1265 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1266 {
1267 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1268 packed_rhs_ptr += 2;
1269
1270 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1271 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1272 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1273
1274 // 2 x MAC res00
1275 res00 = SMLAD(rhs_low0, lhs_low, res00);
1276 res00 = SMLAD(rhs_high0, lhs_high, res00);
1277
1278 // 2 x MAC res10
1279 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1280 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1281 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1282
1283 res10 = SMLAD(rhs_low0, lhs_low, res10);
1284 res10 = SMLAD(rhs_high0, lhs_high, res10);
1285
1286 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1287 packed_rhs_ptr += 2;
1288
1289 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1290 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1291 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1292
1293 // 2 x MAC res00
1294 res00 = SMLAD(rhs_low0, lhs_low, res00);
1295 res00 = SMLAD(rhs_high0, lhs_high, res00);
1296
1297 // 2 x MAC res10
1298 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1299 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1300 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1301
1302 res10 = SMLAD(rhs_low0, lhs_low, res10);
1303 res10 = SMLAD(rhs_high0, lhs_high, res10);
1304
1305 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1306 packed_rhs_ptr += 2;
1307
1308 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1309 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1310 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1311
1312 // 2 x MAC res00
1313 res00 = SMLAD(rhs_low0, lhs_low, res00);
1314 res00 = SMLAD(rhs_high0, lhs_high, res00);
1315
1316 // 2 x MAC res10
1317 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1318 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1319 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1320
1321 res10 = SMLAD(rhs_low0, lhs_low, res10);
1322 res10 = SMLAD(rhs_high0, lhs_high, res10);
1323
1324 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1325 packed_rhs_ptr += 2;
1326
1327 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1328 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1329 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1330
1331 // 2 x MAC res00
1332 res00 = SMLAD(rhs_low0, lhs_low, res00);
1333 res00 = SMLAD(rhs_high0, lhs_high, res00);
1334
1335 // 2 x MAC res10
1336 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1337 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1338 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1339
1340 res10 = SMLAD(rhs_low0, lhs_low, res10);
1341 res10 = SMLAD(rhs_high0, lhs_high, res10);
1342 }
1343
1344 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1345 {
1346 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1347 packed_rhs_ptr += 2;
1348
1349 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1350 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1351 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1352
1353 // 2 x MAC res00
1354 res00 = SMLAD(rhs_low0, lhs_low, res00);
1355 res00 = SMLAD(rhs_high0, lhs_high, res00);
1356
1357 // 2 x MAC res10
1358 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1359 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1360 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1361 res10 = SMLAD(rhs_low0, lhs_low, res10);
1362 res10 = SMLAD(rhs_high0, lhs_high, res10);
1363 }
1364
1365 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1366 {
1367 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1368 rhs_high0 = packed_rhs_ptr[0] >> 4;
1369
1370 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1371 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1372
1373 res00 += lhs_low * rhs_low0;
1374 res00 += lhs_high * rhs_high0;
1375
1376 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1377 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1378 res10 += lhs_low * rhs_low0;
1379 res10 += lhs_high * rhs_high0;
1380
1381 ++packed_rhs_ptr;
1382 lhs_ptr += 2;
1383 }
1384
1385 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1386 {
1387 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1388
1389 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1390 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1391
1392 res00 += lhs_low * rhs_low0;
1393 res10 += lhs_high * rhs_low0;
1394
1395 ++lhs_ptr;
1396 }
1397
1398 lhs_ptr -= rhs_cols;
1399 lhs_ptr += 2 * lhs_cols_offset;
1400
1401 // Quantize down
1402 res00 = arm_nn_requantize(
1403 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1404 res10 = arm_nn_requantize(
1405 res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1406
1407 // Add offset
1408 res00 += dst_offset;
1409 res10 += dst_offset;
1410
1411 // Clamp the result
1412 res00 = MAX(res00, activation_min);
1413 res00 = MIN(res00, activation_max);
1414 res10 = MAX(res10, activation_min);
1415 res10 = MIN(res10, activation_max);
1416
1417 dst_ptr[0] = (int8_t)res00;
1418 dst_ptr += rhs_rows;
1419 dst_ptr[0] = (int8_t)res10;
1420 dst_ptr += rhs_rows;
1421
1422 lhs_rows_idx--;
1423 }
1424 if (lhs_rows % 2)
1425 {
1426 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1427
1428 int32_t res00 = 0;
1429
1430 if (bias)
1431 {
1432 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1433 }
1434
1435            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1436            // differently by dealing with the leftover column from rhs_row_idx[0]
1437 if (rhs_cols % 2 && rhs_rows_idx == 1)
1438 {
1439 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1440 res00 += lhs_low * rhs_spilled_col;
1441
1442 ++lhs_ptr;
1443 }
1444
1445 int32_t rhs_cols_idx = 0;
1446
1447 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
1448
1449 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1450 {
1451 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1452 packed_rhs_ptr += 2;
1453
1454 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1455 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1456 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1457
1458 // 2 x MAC res00
1459 res00 = SMLAD(rhs_low0, lhs_low, res00);
1460 res00 = SMLAD(rhs_high0, lhs_high, res00);
1461
1462 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1463 packed_rhs_ptr += 2;
1464
1465 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1466 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1467 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1468
1469 // 2 x MAC res00
1470 res00 = SMLAD(rhs_low0, lhs_low, res00);
1471 res00 = SMLAD(rhs_high0, lhs_high, res00);
1472
1473 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1474 packed_rhs_ptr += 2;
1475
1476 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1477 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1478 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1479
1480 // 2 x MAC res00
1481 res00 = SMLAD(rhs_low0, lhs_low, res00);
1482 res00 = SMLAD(rhs_high0, lhs_high, res00);
1483
1484 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1485 packed_rhs_ptr += 2;
1486
1487 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1488 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1489 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1490
1491 // 2 x MAC res00
1492 res00 = SMLAD(rhs_low0, lhs_low, res00);
1493 res00 = SMLAD(rhs_high0, lhs_high, res00);
1494 }
1495
1496 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1497 {
1498 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1499 packed_rhs_ptr += 2;
1500
1501 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1502 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1503 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1504
1505 // 2 x MAC res00
1506 res00 = SMLAD(rhs_low0, lhs_low, res00);
1507 res00 = SMLAD(rhs_high0, lhs_high, res00);
1508 }
1509
1510 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1511 {
1512 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1513 rhs_high0 = packed_rhs_ptr[0] >> 4;
1514
1515 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1516 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1517
1518 res00 += lhs_low * rhs_low0;
1519 res00 += lhs_high * rhs_high0;
1520
1521 ++packed_rhs_ptr;
1522 lhs_ptr += 2;
1523 }
1524
1525 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1526 {
1527 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1528
1529 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1530 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1531
1532 res00 += lhs_low * rhs_low0;
1533
1534 ++lhs_ptr;
1535 }
1536
1537 // Quantize down
1538 res00 = arm_nn_requantize(
1539 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1540
1541 // Add offset
1542 res00 += dst_offset;
1543
1544 // Clamp the result
1545 res00 = MAX(res00, activation_min);
1546 res00 = MIN(res00, activation_max);
1547
1548 dst_ptr[0] = (int8_t)res00;
1549 }
1550 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1551 {
1552 rhs_spilled_col = packed_rhs[rhs_cols_int4] >> 4;
1553 packed_rhs += rhs_cols_int4 + 1;
1554 }
1555 else
1556 {
1557 rhs_spilled_col = 0;
1558 packed_rhs += rhs_cols_int4;
1559 }
1560
1561 ++dst;
1562 }
1563 #else
1564
1565 const int32_t rhs_cols_int4 = rhs_cols >> 1;
1566
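    /* Pure C fallback: (int8_t)(byte << 4) >> 4 sign-extends the low nibble and byte >> 4 on a
     * signed byte gives the sign-extended high nibble. rhs rows are evaluated in the pairs
     * (idx, idx + 2) and (idx + 1, idx + 3) so that both rows of a pair share the same nibble
     * alignment within the packed buffer. */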
1567 for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
1568 {
1569 const int8_t *lhs_ptr = &lhs[0];
1570 int8_t *dst_ptr = &dst[0];
1571
1572 for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
1573 {
1574 // To avoid the issue of packed values leaking into the next rhs row
1575 // we instead evaluate the rhs rows in pairs like so:
1576 // rhs[0] and rhs[2], rhs[1] and rhs[3]
1577
1578 // Start processing rhs_row[0] and rhs_row[2]
1579 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1580
1581 int32_t res00 = 0;
1582 int32_t res01 = 0;
1583 int32_t res10 = 0;
1584 int32_t res11 = 0;
1585
1586 int32_t spillover00 = 0;
1587 int32_t spillover01 = 0;
1588 int32_t spillover10 = 0;
1589 int32_t spillover11 = 0;
1590
1591 if (bias)
1592 {
1593 res00 = bias[rhs_rows_idx];
1594 res01 = bias[rhs_rows_idx + 2];
1595 res10 = bias[rhs_rows_idx];
1596 res11 = bias[rhs_rows_idx + 2];
1597 }
1598
1599 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1600 {
1601 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1602 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1603 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1604 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1605
1606 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1607 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1608
1609 res00 += lhs_low * rhs_low0;
1610 res00 += lhs_high * rhs_high0;
1611 res01 += lhs_low * rhs_low1;
1612 res01 += lhs_high * rhs_high1;
1613
1614 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1615 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1616
1617 res10 += lhs_low * rhs_low0;
1618 res10 += lhs_high * rhs_high0;
1619 res11 += lhs_low * rhs_low1;
1620 res11 += lhs_high * rhs_high1;
1621
1622 ++packed_rhs_ptr;
1623 lhs_ptr += 2;
1624 }
1625 if (rhs_cols % 2)
1626 {
1627 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1628 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1629 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1630 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1631
1632 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1633 int32_t lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1634
1635 res00 += lhs_low * rhs_low0;
1636 res01 += lhs_low * rhs_low1;
1637
1638 res10 += lhs_high * rhs_low0;
1639 res11 += lhs_high * rhs_low1;
1640
1641 lhs_ptr -= rhs_cols - 1;
1642
1643 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1644 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1645
1646 spillover00 += lhs_low * rhs_high0;
1647 spillover01 += lhs_low * rhs_high1;
1648
1649 spillover10 += lhs_high * rhs_high0;
1650 spillover11 += lhs_high * rhs_high1;
1651
1652 ++packed_rhs_ptr;
1653 ++lhs_ptr;
1654 }
1655 else
1656 {
1657 lhs_ptr -= rhs_cols;
1658 }
1659
1660 // Quantize down
1661 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1662 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1663 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1664 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1665
1666 // Add offset
1667 res00 += dst_offset;
1668 res01 += dst_offset;
1669 res10 += dst_offset;
1670 res11 += dst_offset;
1671
1672 // Clamp the result
1673 res00 = MAX(res00, activation_min);
1674 res00 = MIN(res00, activation_max);
1675 res01 = MAX(res01, activation_min);
1676 res01 = MIN(res01, activation_max);
1677 res10 = MAX(res10, activation_min);
1678 res10 = MIN(res10, activation_max);
1679 res11 = MAX(res11, activation_min);
1680 res11 = MIN(res11, activation_max);
1681
1682 dst_ptr[0] = (int8_t)res00;
1683 dst_ptr[2] = (int8_t)res01;
1684 dst_ptr += rhs_rows;
1685 dst_ptr[0] = (int8_t)res10;
1686 dst_ptr[2] = (int8_t)res11;
1687 dst_ptr -= rhs_rows;
1688
1689 // Start processing rhs_row[1] and rhs_row[3]
1690 res00 = spillover00;
1691 res01 = spillover01;
1692 res10 = spillover10;
1693 res11 = spillover11;
1694 if (bias)
1695 {
1696 res00 += bias[rhs_rows_idx + 1];
1697 res01 += bias[rhs_rows_idx + 3];
1698 res10 += bias[rhs_rows_idx + 1];
1699 res11 += bias[rhs_rows_idx + 3];
1700 }
1701
1702 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1703 {
1704 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1705 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1706 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1707 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1708
1709 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1710 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1711
1712 res00 += lhs_low * rhs_low0;
1713 res00 += lhs_high * rhs_high0;
1714 res01 += lhs_low * rhs_low1;
1715 res01 += lhs_high * rhs_high1;
1716
1717 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1718 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1719
1720 res10 += lhs_low * rhs_low0;
1721 res10 += lhs_high * rhs_high0;
1722 res11 += lhs_low * rhs_low1;
1723 res11 += lhs_high * rhs_high1;
1724
1725 ++packed_rhs_ptr;
1726 lhs_ptr += 2;
1727 }
1728
1729 // Quantize down
1730 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1731 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1732 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1733 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1734
1735 // Add offset
1736 res00 += dst_offset;
1737 res01 += dst_offset;
1738 res10 += dst_offset;
1739 res11 += dst_offset;
1740
1741 // Clamp the result
1742 res00 = MAX(res00, activation_min);
1743 res00 = MIN(res00, activation_max);
1744 res01 = MAX(res01, activation_min);
1745 res01 = MIN(res01, activation_max);
1746 res10 = MAX(res10, activation_min);
1747 res10 = MIN(res10, activation_max);
1748 res11 = MAX(res11, activation_min);
1749 res11 = MIN(res11, activation_max);
1750
1751 dst_ptr[1] = (int8_t)res00;
1752 dst_ptr[3] = (int8_t)res01;
1753 dst_ptr += rhs_rows;
1754 dst_ptr[1] = (int8_t)res10;
1755 dst_ptr[3] = (int8_t)res11;
1756 dst_ptr += rhs_rows;
1757
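            // Rewind lhs_ptr to the start of the current lhs row, then step down two lhs rows for the next iteration.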
1758 lhs_ptr -= rhs_cols;
1759 lhs_ptr += 2 * lhs_cols_offset;
1760 }
1761
1762        // Left-over lhs row (lhs_rows is odd)
1763 if (lhs_rows % 2)
1764 {
1765 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1766
1767 int32_t res00 = 0;
1768 int32_t res01 = 0;
1769
1770 int32_t spillover00 = 0;
1771 int32_t spillover01 = 0;
1772
1773 if (bias)
1774 {
1775 res00 += bias[rhs_rows_idx];
1776 res01 += bias[rhs_rows_idx + 2];
1777 }
1778
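            // Single leftover lhs row: same scheme as above, but only res00/res01 (channels rhs_rows_idx and
            // rhs_rows_idx + 2) are accumulated.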
1779 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1780 {
1781 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1782 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1783
1784 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1785 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1786
1787 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1788 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1789
1790 res00 += lhs_low * rhs_low0;
1791 res00 += lhs_high * rhs_high0;
1792
1793 res01 += lhs_low * rhs_low1;
1794 res01 += lhs_high * rhs_high1;
1795
1796 ++packed_rhs_ptr;
1797 lhs_ptr += 2;
1798 }
1799 if (rhs_cols % 2)
1800 {
1801 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1802 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1803 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1804 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1805
1806 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1807 lhs_ptr -= rhs_cols - 1;
1808 int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
1809
1810 res00 += lhs_low * rhs_low0;
1811 res01 += lhs_low * rhs_low1;
1812 spillover00 = lhs_high * rhs_high0;
1813 spillover01 = lhs_high * rhs_high1;
1814
1815 ++packed_rhs_ptr;
1816 ++lhs_ptr;
1817 }
1818 else
1819 {
1820 lhs_ptr -= rhs_cols;
1821 }
1822
1823 // Quantize down
1824 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1825 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1826
1827 // Add offset
1828 res00 += dst_offset;
1829 res01 += dst_offset;
1830
1831 // Clamp the result
1832 res00 = MAX(res00, activation_min);
1833 res00 = MIN(res00, activation_max);
1834 res01 = MAX(res01, activation_min);
1835 res01 = MIN(res01, activation_max);
1836
1837 dst_ptr[0] = (int8_t)res00;
1838 dst_ptr[2] = (int8_t)res01;
1839
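            // Second pass for the leftover lhs row: channels rhs_rows_idx + 1 and rhs_rows_idx + 3.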
1840 res00 = spillover00;
1841 res01 = spillover01;
1842
1843 if (bias)
1844 {
1845 res00 += bias[rhs_rows_idx + 1];
1846 res01 += bias[rhs_rows_idx + 3];
1847 }
1848
1849 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1850 {
1851 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1852 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1853
1854 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1855 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1856
1857 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1858 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1859
1860 res00 += lhs_low * rhs_low0;
1861 res00 += lhs_high * rhs_high0;
1862
1863 res01 += lhs_low * rhs_low1;
1864 res01 += lhs_high * rhs_high1;
1865
1866 ++packed_rhs_ptr;
1867 lhs_ptr += 2;
1868 }
1869
1870 // Quantize down
1871 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1872 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1873
1874 // Add offset
1875 res00 += dst_offset;
1876 res01 += dst_offset;
1877
1878 // Clamp the result
1879 res00 = MAX(res00, activation_min);
1880 res00 = MIN(res00, activation_max);
1881 res01 = MAX(res01, activation_min);
1882 res01 = MIN(res01, activation_max);
1883
1884 dst_ptr[1] = (int8_t)res00;
1885 dst_ptr[3] = (int8_t)res01;
1886 }
1887
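        // Four rhs rows of 4-bit weights occupy 2 * rhs_cols bytes; step to the next group of four output channels.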
1888 packed_rhs += 2 * rhs_cols;
1889 dst += 4;
1890 }
1891
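    // spillover00 carries the high nibble of a byte that straddles a row boundary when rhs_cols is odd.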
1892 int32_t spillover00 = 0;
1893 const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
1894    // Left-over rhs rows will be in the range 0 to 3
1895 for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
1896 {
1897 const int8_t *lhs_ptr = &lhs[0];
1898 int8_t *dst_ptr = &dst[0];
1899
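        // Walk the lhs two rows at a time against this single rhs row.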
1900 for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
1901 {
1902 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1903 int32_t res00 = 0;
1904 int32_t res10 = 0;
1905 if (bias)
1906 {
1907 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1908 res10 = bias[rhs_rows_finished + rhs_rows_idx];
1909 }
1910            // Since there can only be 3 rhs rows here, we only need to treat rhs_row_idx[1]
1911            // differently by dealing with the leftover column from rhs_row_idx[0]
1912 if (rhs_cols % 2 && rhs_rows_idx == 1)
1913 {
1914 int32_t lhs_low = lhs_ptr[0] + lhs_offset;
1915 res00 += lhs_low * spillover00;
1916
1917 lhs_low = lhs_ptr[lhs_cols_offset] + lhs_offset;
1918 res10 += lhs_low * spillover00;
1919
1920 ++lhs_ptr;
1921 }
1922 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1923 {
1924 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1925 int8_t rhs_high = packed_rhs_ptr[0] >> 4;
1926
1927 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1928 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1929
1930 res00 += lhs_low * rhs_low;
1931 res00 += lhs_high * rhs_high;
1932
1933 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1934 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1935
1936 res10 += lhs_low * rhs_low;
1937 res10 += lhs_high * rhs_high;
1938
1939 ++packed_rhs_ptr;
1940 lhs_ptr += 2;
1941 }
1942 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1943 {
1944 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1945
1946 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1947
1948 res00 += lhs_low * rhs_low;
1949
1950 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1951
1952 res10 += lhs_low * rhs_low;
1953
1954 ++lhs_ptr;
1955 }
1956
1957 lhs_ptr -= rhs_cols;
1958 lhs_ptr += 2 * lhs_cols_offset;
1959
1960 // Quantize down
1961 res00 = arm_nn_requantize(
1962 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1963 res10 = arm_nn_requantize(
1964 res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1965
1966 // Add offset
1967 res00 += dst_offset;
1968 res10 += dst_offset;
1969
1970 // Clamp the result
1971 res00 = MAX(res00, activation_min);
1972 res00 = MIN(res00, activation_max);
1973 res10 = MAX(res10, activation_min);
1974 res10 = MIN(res10, activation_max);
1975
1976 dst_ptr[0] = (int8_t)res00;
1977 dst_ptr += rhs_rows;
1978 dst_ptr[0] = (int8_t)res10;
1979 dst_ptr += rhs_rows;
1980 }
1981 if (lhs_rows % 2)
1982 {
1983 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1984 int32_t res00 = 0;
1985 if (bias)
1986 {
1987 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1988 }
1989            // Since there can only be 3 rhs rows here, we only need to treat rhs_row_idx[1]
1990            // differently by dealing with the leftover column from rhs_row_idx[0]
1991 if (rhs_cols % 2 && rhs_rows_idx == 1)
1992 {
1993 int32_t lhs_low = lhs_ptr[0] + lhs_offset;
1994 res00 += lhs_low * spillover00;
1995
1996 ++lhs_ptr;
1997 }
1998 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1999 {
2000 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
2001 int8_t rhs_high = packed_rhs_ptr[0] >> 4;
2002
2003 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
2004 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
2005
2006 res00 += lhs_low * rhs_low;
2007 res00 += lhs_high * rhs_high;
2008
2009 ++packed_rhs_ptr;
2010 lhs_ptr += 2;
2011 }
2012 if (rhs_cols % 2 && (rhs_rows_idx != 1))
2013 {
2014 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
2015
2016 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
2017 res00 += lhs_low * rhs_low;
2018
2019 ++lhs_ptr;
2020 }
2021
2022 // Quantize down
2023 res00 = arm_nn_requantize(
2024 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
2025
2026 // Add offset
2027 res00 += dst_offset;
2028
2029 // Clamp the result
2030 res00 = MAX(res00, activation_min);
2031 res00 = MIN(res00, activation_max);
2032
2033 dst_ptr[0] = (int8_t)res00;
2034 }
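        // When rhs_cols is odd and this row started on a byte boundary, the high nibble of its final byte is already
        // column 0 of the next leftover row: save it and include the shared byte in the advance.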
2035 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
2036 {
2037 spillover00 = packed_rhs[rhs_cols_int4] >> 4;
2038 packed_rhs += rhs_cols_int4 + (rhs_cols & 0x1);
2039 }
2040 else
2041 {
2042 spillover00 = 0;
2043 packed_rhs += rhs_cols_int4;
2044 }
2045
2046 ++dst;
2047 }
2048
2049 #endif
2050
2051 return ARM_CMSIS_NN_SUCCESS;
2052 }
2053
2054 /**
2055 * @} end of Doxygen group
2056 */
2057