1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_nt_t_s4
22  * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed, and 4-bit rhs.
23 *
24 * $Date: 10 April 2024
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup groupSupport
35 */
36
37 /**
38 * @addtogroup supportConvolution
39 * @{
40 */
41
42 /*
43 * s4 matrix multiplication with the right-hand-side matrix transposed
44 *
45  * Refer to the header file for details.
46 *
47 */
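/*
 * Packed s4 layout used throughout this file: each byte of packed_rhs holds two signed
 * 4-bit weights, with the earlier column in the low nibble and the following column in
 * the high nibble. When rhs_cols is odd, the last byte of a row is shared: its high
 * nibble is the first weight of the next rhs row.
 *
 * Illustrative sketch (not part of the kernel) of the sign-extension idiom used by the
 * scalar paths below:
 *
 *   int8_t  packed = packed_rhs_ptr[0];
 *   int32_t low    = (int8_t)(packed << 4) >> 4; // low nibble, sign-extended
 *   int32_t high   = packed >> 4;                // high nibble, arithmetic shift keeps the sign
 */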
48 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
49 const int8_t *packed_rhs,
50 const int32_t *bias,
51 int8_t *dst,
52 const int32_t *dst_multipliers,
53 const int32_t *dst_shifts,
54 const int32_t lhs_rows,
55 const int32_t rhs_rows,
56 const int32_t rhs_cols,
57 const int32_t lhs_offset,
58 const int32_t dst_offset,
59 const int32_t activation_min,
60 const int32_t activation_max,
61 const int32_t lhs_cols_offset)
62 {
63 #if defined(ARM_MATH_MVEI)
64 int i_items = 0;
65 const int rhs_cols_offset = rhs_cols % 16;
66 const int32_t blk_cnt = rhs_cols >> 4;
67 const mve_pred16_t lower_nibble_mask = 21845; // 0101010101010101
68 const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
69 const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};
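    /* gather_offset loads each packed byte into two adjacent lanes. Shifting the even lanes
     * (lower_nibble_mask = 0x5555) left by 4 and then arithmetically shifting all lanes right
     * by 4 leaves lane 2k holding the sign-extended low nibble and lane 2k+1 the sign-extended
     * high nibble of byte k, i.e. 16 unpacked s4 weights per load. scatter_offset spaces the
     * four per-row results rhs_rows bytes apart so one scatter store writes the same output
     * column for four consecutive lhs rows. */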
70
71 for (; i_items <= (lhs_rows - 4); i_items += 4)
72 {
73 int8_t const *col_base = packed_rhs;
74 for (int i = 0; i < rhs_rows; i++)
75 {
76
77 int32_t acc_n0 = 0;
78 int32_t acc_n1 = 0;
79 int32_t acc_n2 = 0;
80 int32_t acc_n3 = 0;
81
82 int8_t const *ip_row_0 = lhs;
83 int8_t const *ip_row_1 = lhs + lhs_cols_offset;
84 int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
85 int8_t const *ip_row_3 = lhs + (3 * lhs_cols_offset);
86 int32_t sum_tmp = 0;
87
88 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
89
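            /* With an odd rhs_cols, odd-indexed rhs rows start in the high nibble of a byte
             * shared with the previous row. Consume that first weight here so the vector loop
             * below can keep reading col_base low-nibble-first, and drop one lane from the
             * remainder predicate to compensate. */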
90 if ((rhs_cols & 0x1) & (i & 0x1))
91 {
92 rmdr_mask >>= 1;
93 int32_t col = col_base[0] >> 4;
94 sum_tmp = col;
95 acc_n0 += ip_row_0[0] * col;
96 acc_n1 += ip_row_1[0] * col;
97 acc_n2 += ip_row_2[0] * col;
98 acc_n3 += ip_row_3[0] * col;
99
100 ++col_base;
101 ++ip_row_0;
102 ++ip_row_1;
103 ++ip_row_2;
104 ++ip_row_3;
105 }
106
107 for (int j = blk_cnt; j > 0; --j)
108 {
109 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
110 col_base += 8;
111
112 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
113 col_vec = vshrq_n_s8(col_vec, 4);
114
115 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
116
117 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
118 ip_row_0 += 16;
119 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
120
121 lhs_vec = vldrbq_s8(ip_row_1);
122 ip_row_1 += 16;
123 acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);
124
125 lhs_vec = vldrbq_s8(ip_row_2);
126 ip_row_2 += 16;
127 acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
128
129 lhs_vec = vldrbq_s8(ip_row_3);
130 ip_row_3 += 16;
131 acc_n3 = vmladavaq_s8(acc_n3, col_vec, lhs_vec);
132 }
133
134 if (rmdr_mask)
135 {
136 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
137 col_base += rhs_cols_offset >> 1;
138 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
139 col_vec = vshrq_n_s8(col_vec, 4);
140
141 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
142
143 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
144 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
145
146 lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
147 acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);
148
149 lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
150 acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
151
152 lhs_vec = vldrbq_z_s8(ip_row_3, rmdr_mask);
153 acc_n3 = vmladavaq_p_s8(acc_n3, col_vec, lhs_vec, rmdr_mask);
154 }
155
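            /* acc_n0..acc_n3 hold sum(lhs * rhs) without the lhs offset. Since
             * sum((lhs + lhs_offset) * rhs) = sum(lhs * rhs) + lhs_offset * sum(rhs),
             * the offset contribution is added once below via sum_tmp, the running sum
             * of the rhs weights, together with the optional bias. */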
156 int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
157 sum_tmp *= lhs_offset;
158 if (bias)
159 {
160 sum_tmp += bias[i];
161 }
162 res = vaddq_n_s32(res, sum_tmp);
163
164 res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
165 res = vaddq_n_s32(res, dst_offset);
166
167 res = vmaxq_s32(res, vdupq_n_s32(activation_min));
168 res = vminq_s32(res, vdupq_n_s32(activation_max));
169
170 vstrbq_scatter_offset_s32(dst, scatter_offset, res);
171 dst++;
172 }
173 lhs += 4 * lhs_cols_offset;
174 dst += (3 * rhs_rows);
175 }
176 if (lhs_rows % 4 == 3)
177 {
178 int8_t const *col_base = packed_rhs;
179 const mve_pred16_t requant_mask = vctp32q(3);
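        /* Exactly three lhs rows remain, so requantization, clamping and the scatter store
         * are predicated on the first three 32-bit lanes. */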
180 for (int i = 0; i < rhs_rows; i++)
181 {
182
183 int32_t acc_n0 = 0;
184 int32_t acc_n1 = 0;
185 int32_t acc_n2 = 0;
186
187 int8_t const *ip_row_0 = lhs;
188 int8_t const *ip_row_1 = lhs + lhs_cols_offset;
189 int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
190 int32_t sum_tmp = 0;
191
192 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
193
194 if ((rhs_cols & 0x1) & (i & 0x1))
195 {
196 rmdr_mask >>= 1;
197 int32_t col = col_base[0] >> 4;
198 sum_tmp = col;
199 acc_n0 += ip_row_0[0] * col;
200 acc_n1 += ip_row_1[0] * col;
201 acc_n2 += ip_row_2[0] * col;
202
203 ++col_base;
204 ++ip_row_0;
205 ++ip_row_1;
206 ++ip_row_2;
207 }
208
209 for (int j = blk_cnt; j > 0; --j)
210 {
211 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
212 col_base += 8;
213
214 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
215 col_vec = vshrq_n_s8(col_vec, 4);
216
217 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
218
219 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
220 ip_row_0 += 16;
221 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
222
223 lhs_vec = vldrbq_s8(ip_row_1);
224 ip_row_1 += 16;
225 acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);
226
227 lhs_vec = vldrbq_s8(ip_row_2);
228 ip_row_2 += 16;
229 acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
230 }
231
232 if (rmdr_mask)
233 {
234 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
235 col_base += rhs_cols_offset >> 1;
236 col_vec = vrshlq_m_n_s8(col_vec, 4, (lower_nibble_mask & rmdr_mask));
237 col_vec = vshrq_n_s8(col_vec, 4);
238
239 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
240
241 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
242 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
243
244 lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
245 acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);
246
247 lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
248 acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
249 }
250
251 int32x4_t res = {acc_n0, acc_n1, acc_n2, 0};
252 sum_tmp *= lhs_offset;
253 if (bias)
254 {
255 sum_tmp += bias[i];
256 }
257
258 res = vaddq_x_n_s32(res, sum_tmp, requant_mask);
259
260 res = arm_requantize_mve_pred(res, dst_multipliers[i], dst_shifts[i], requant_mask);
261 res = vaddq_x_n_s32(res, dst_offset, requant_mask);
262
263 res = vmaxq_x_s32(res, vdupq_n_s32(activation_min), requant_mask);
264 res = vminq_x_s32(res, vdupq_n_s32(activation_max), requant_mask);
265
266 vstrbq_scatter_offset_p_s32(dst, scatter_offset, res, requant_mask);
267 dst++;
268 }
269 lhs += 3 * lhs_cols_offset;
270 dst += (2 * rhs_rows);
271 }
272 else
273 {
274 for (; i_items < lhs_rows; i_items++)
275 {
276 int32_t acc[4];
277 const int32_t *multipliers = dst_multipliers;
278 const int32_t *shifts = dst_shifts;
279 const int8_t *col_base = packed_rhs;
280 int col_inc = rhs_cols_offset >> 1;
281
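            /* The remaining lhs rows (at most two, since a remainder of three is handled
             * above) are processed one at a time. Per-rhs-row results are collected four at
             * a time in acc[] so requantization and clamping can still run as a 32x4 vector
             * operation; the last rhs_rows % 4 outputs are requantized individually. */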
282 for (int i = 0; i < rhs_rows; i++)
283 {
284 int32_t acc_n0 = 0;
285 const int8_t *ip_row_0 = lhs;
286 int32_t sum_tmp = 0;
287 mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
288
289 if ((rhs_cols & 0x1) & (i & 0x1))
290 {
291 rmdr_mask >>= 1;
292 int32_t col = col_base[0] >> 4;
293 sum_tmp += col;
294 acc_n0 += ip_row_0[0] * col;
295
296 ++col_base;
297 ++ip_row_0;
298 }
299
300 for (int j = blk_cnt; j > 0; --j)
301 {
302 int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
303 col_base += 8;
304
305 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
306 col_vec = vshrq_n_s8(col_vec, 4);
307
308 sum_tmp = vaddvaq_s8(sum_tmp, col_vec);
309
310 int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
311 ip_row_0 += 16;
312 acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
313 }
314
315 if (rmdr_mask)
316 {
317 int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
318 col_base += col_inc;
319 col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
320 col_vec = vshrq_n_s8(col_vec, 4);
321
322 sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);
323
324 int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
325 acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
326 }
327
328 sum_tmp *= lhs_offset;
329 sum_tmp += acc_n0;
330 if (bias)
331 {
332 sum_tmp += bias[i];
333 }
334 const int32_t index = i & 0x3;
335 acc[index] = sum_tmp;
336
337 if (index == 3)
338 {
339 int32x4_t res = vldrwq_s32(acc);
340 res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
341 multipliers += 4;
342 shifts += 4;
343 res = vaddq_n_s32(res, dst_offset);
344 res = vmaxq_s32(res, vdupq_n_s32(activation_min));
345 res = vminq_s32(res, vdupq_n_s32(activation_max));
346 vstrbq_s32(dst, res);
347 dst += 4;
348 }
349 }
350 lhs += lhs_cols_offset;
351 const int32_t tail_rows = rhs_rows & 0x3;
352 for (int i = 0; i < tail_rows; i++)
353 {
354 int32_t acc_n0 = acc[i];
355 acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
356 acc_n0 += dst_offset;
357 acc_n0 = MAX(acc_n0, activation_min);
358 acc_n0 = MIN(acc_n0, activation_max);
359 *dst++ = (int8_t)acc_n0;
360 }
361 }
362 }
363 #elif defined(ARM_MATH_DSP)
364 const int32_t lhs_cols_off1 = lhs_cols_offset - 4;
365 const int16_t i16_lhs_offset = (int16_t)lhs_offset;
366 const uint32_t ui32_lhs_offset_i16x2 = PKHBT(i16_lhs_offset, i16_lhs_offset, 16);
367 const int32_t rhs_cols_int4 = rhs_cols >> 1;
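    /* PKHBT packs lhs_offset into both halfwords so SXTAB16 / SXTAB16_RORn can widen four
     * lhs bytes into two int16x2 words with the offset already added: lhs_low holds columns
     * {0, 2} and lhs_high columns {1, 3}. read_and_pad_s4 expands four packed s4 weights into
     * two int16x2 words in the matching lane order, so every SMLAD below performs a dual
     * 16-bit multiply-accumulate. */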
368
369 for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
370 {
371
372 const int8_t *lhs_ptr = &lhs[0];
373 int8_t *dst_ptr = &dst[0];
374
375 int32_t lhs_rows_idx = lhs_rows >> 1;
376 while (lhs_rows_idx)
377 {
378 const int8_t *packed_rhs_ptr = &packed_rhs[0];
379
380 int32_t res00 = 0;
381 int32_t res01 = 0;
382 int32_t res10 = 0;
383 int32_t res11 = 0;
384
385 int32_t spillover00 = 0;
386 int32_t spillover01 = 0;
387 int32_t spillover10 = 0;
388 int32_t spillover11 = 0;
389
390 if (bias)
391 {
392 res00 = bias[rhs_rows_idx];
393 res01 = bias[rhs_rows_idx + 2];
394 res10 = bias[rhs_rows_idx];
395 res11 = bias[rhs_rows_idx + 2];
396 spillover00 = bias[rhs_rows_idx + 1];
397 spillover01 = bias[rhs_rows_idx + 3];
398 spillover10 = bias[rhs_rows_idx + 1];
399 spillover11 = bias[rhs_rows_idx + 3];
400 }
401
402 int32_t rhs_cols_idx = 0;
403
404 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
405
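            /* Main loop: 16 columns per iteration, unrolled four times. Each step reads two
             * packed bytes (four s4 weights) from the current rhs row and from the row two
             * rows ahead (rhs_cols bytes further into the packed stream), plus four lhs
             * bytes from each of the two lhs rows, and issues eight SMLAD dual MACs. */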
406 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
407 {
408 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
409 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
410 packed_rhs_ptr += 2;
411
412 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
413 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
414 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
415
416 // 4 x MAC res00, res01
417 res00 = SMLAD(rhs_low0, lhs_low, res00);
418 res00 = SMLAD(rhs_high0, lhs_high, res00);
419 res01 = SMLAD(rhs_low1, lhs_low, res01);
420 res01 = SMLAD(rhs_high1, lhs_high, res01);
421
422 // 4 x MAC res10, res11
423 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
424 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
425 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
426
427 res10 = SMLAD(rhs_low0, lhs_low, res10);
428 res11 = SMLAD(rhs_low1, lhs_low, res11);
429 res10 = SMLAD(rhs_high0, lhs_high, res10);
430 res11 = SMLAD(rhs_high1, lhs_high, res11);
431
432 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
433 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
434 packed_rhs_ptr += 2;
435
436 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
437 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
438 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
439
440 // 4 x MAC res00, res01
441 res00 = SMLAD(rhs_low0, lhs_low, res00);
442 res00 = SMLAD(rhs_high0, lhs_high, res00);
443 res01 = SMLAD(rhs_low1, lhs_low, res01);
444 res01 = SMLAD(rhs_high1, lhs_high, res01);
445
446 // 4 x MAC res10, res11
447 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
448 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
449 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
450
451 res10 = SMLAD(rhs_low0, lhs_low, res10);
452 res11 = SMLAD(rhs_low1, lhs_low, res11);
453 res10 = SMLAD(rhs_high0, lhs_high, res10);
454 res11 = SMLAD(rhs_high1, lhs_high, res11);
455
456 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
457 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
458 packed_rhs_ptr += 2;
459
460 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
461 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
462 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
463
464 // 4 x MAC res00, res01
465 res00 = SMLAD(rhs_low0, lhs_low, res00);
466 res00 = SMLAD(rhs_high0, lhs_high, res00);
467 res01 = SMLAD(rhs_low1, lhs_low, res01);
468 res01 = SMLAD(rhs_high1, lhs_high, res01);
469
470 // 4 x MAC res10, res11
471 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
472 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
473 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
474
475 res10 = SMLAD(rhs_low0, lhs_low, res10);
476 res11 = SMLAD(rhs_low1, lhs_low, res11);
477 res10 = SMLAD(rhs_high0, lhs_high, res10);
478 res11 = SMLAD(rhs_high1, lhs_high, res11);
479
480 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
481 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
482 packed_rhs_ptr += 2;
483
484 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
485 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
486 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
487
488 // 4 x MAC res00, res01
489 res00 = SMLAD(rhs_low0, lhs_low, res00);
490 res00 = SMLAD(rhs_high0, lhs_high, res00);
491 res01 = SMLAD(rhs_low1, lhs_low, res01);
492 res01 = SMLAD(rhs_high1, lhs_high, res01);
493
494 // 4 x MAC res10, res11
495 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
496 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
497 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
498
499 res10 = SMLAD(rhs_low0, lhs_low, res10);
500 res11 = SMLAD(rhs_low1, lhs_low, res11);
501 res10 = SMLAD(rhs_high0, lhs_high, res10);
502 res11 = SMLAD(rhs_high1, lhs_high, res11);
503 }
504 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
505 {
506 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
507 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
508 packed_rhs_ptr += 2;
509
510 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
511 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
512 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
513
514 // 4 x MAC res00, res01
515 res00 = SMLAD(rhs_low0, lhs_low, res00);
516 res00 = SMLAD(rhs_high0, lhs_high, res00);
517 res01 = SMLAD(rhs_low1, lhs_low, res01);
518 res01 = SMLAD(rhs_high1, lhs_high, res01);
519
520 // 4 x MAC res10, res11
521 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
522 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
523 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
524 res10 = SMLAD(rhs_low0, lhs_low, res10);
525 res11 = SMLAD(rhs_low1, lhs_low, res11);
526 res10 = SMLAD(rhs_high0, lhs_high, res10);
527 res11 = SMLAD(rhs_high1, lhs_high, res11);
528 }
529 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
530 {
531 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
532 rhs_high0 = packed_rhs_ptr[0] >> 4;
533
534 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
535 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
536
537 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
538 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
539
540 res00 += lhs_low * rhs_low0;
541 res00 += lhs_high * rhs_high0;
542 res01 += lhs_low * rhs_low1;
543 res01 += lhs_high * rhs_high1;
544
545 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
546 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
547 res10 += lhs_low * rhs_low0;
548 res10 += lhs_high * rhs_high0;
549 res11 += lhs_low * rhs_low1;
550 res11 += lhs_high * rhs_high1;
551
552 ++packed_rhs_ptr;
553 lhs_ptr += 2;
554 }
555 if (rhs_cols % 2)
556 {
557 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
558 rhs_high0 = packed_rhs_ptr[0] >> 4;
559 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
560 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
561
562 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
563 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
564
565 res00 += lhs_low * rhs_low0;
566 res01 += lhs_low * rhs_low1;
567
568 res10 += lhs_high * rhs_low0;
569 res11 += lhs_high * rhs_low1;
570
571 lhs_ptr -= rhs_cols - 1;
572
573 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
574 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
575
576 spillover00 += lhs_low * rhs_high0;
577 spillover01 += lhs_low * rhs_high1;
578
579 spillover10 += lhs_high * rhs_high0;
580 spillover11 += lhs_high * rhs_high1;
581
582 ++packed_rhs_ptr;
583 ++lhs_ptr;
584 }
585 else
586 {
587 lhs_ptr -= rhs_cols;
588 }
589
590 // Quantize down
591 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
592 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
593 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
594 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
595
596 // Add offset
597 res00 += dst_offset;
598 res01 += dst_offset;
599 res10 += dst_offset;
600 res11 += dst_offset;
601
602 // Clamp the result
603 res00 = MAX(res00, activation_min);
604 res00 = MIN(res00, activation_max);
605 res01 = MAX(res01, activation_min);
606 res01 = MIN(res01, activation_max);
607 res10 = MAX(res10, activation_min);
608 res10 = MIN(res10, activation_max);
609 res11 = MAX(res11, activation_min);
610 res11 = MIN(res11, activation_max);
611
612 dst_ptr[0] = (int8_t)res00;
613 dst_ptr[2] = (int8_t)res01;
614 dst_ptr += rhs_rows;
615 dst_ptr[0] = (int8_t)res10;
616 dst_ptr[2] = (int8_t)res11;
617 dst_ptr -= rhs_rows;
618
619 res00 = spillover00;
620 res01 = spillover01;
621 res10 = spillover10;
622 res11 = spillover11;
623
624 rhs_cols_idx = 0;
625
626 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
627 {
628 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
629 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
630 packed_rhs_ptr += 2;
631
632 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
633 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
634 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
635
636 // 4 x MAC res00, res01
637 res00 = SMLAD(rhs_low0, lhs_low, res00);
638 res00 = SMLAD(rhs_high0, lhs_high, res00);
639 res01 = SMLAD(rhs_low1, lhs_low, res01);
640 res01 = SMLAD(rhs_high1, lhs_high, res01);
641
642 // 4 x MAC res10, res11
643 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
644 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
645 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
646
647 res10 = SMLAD(rhs_low0, lhs_low, res10);
648 res11 = SMLAD(rhs_low1, lhs_low, res11);
649 res10 = SMLAD(rhs_high0, lhs_high, res10);
650 res11 = SMLAD(rhs_high1, lhs_high, res11);
651
652 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
653 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
654 packed_rhs_ptr += 2;
655
656 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
657 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
658 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
659
660 // 4 x MAC res00, res01
661 res00 = SMLAD(rhs_low0, lhs_low, res00);
662 res00 = SMLAD(rhs_high0, lhs_high, res00);
663 res01 = SMLAD(rhs_low1, lhs_low, res01);
664 res01 = SMLAD(rhs_high1, lhs_high, res01);
665
666 // 4 x MAC res10, res11
667 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
668 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
669 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
670
671 res10 = SMLAD(rhs_low0, lhs_low, res10);
672 res11 = SMLAD(rhs_low1, lhs_low, res11);
673 res10 = SMLAD(rhs_high0, lhs_high, res10);
674 res11 = SMLAD(rhs_high1, lhs_high, res11);
675
676 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
677 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
678 packed_rhs_ptr += 2;
679
680 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
681 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
682 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
683
684 // 4 x MAC res00, res01
685 res00 = SMLAD(rhs_low0, lhs_low, res00);
686 res00 = SMLAD(rhs_high0, lhs_high, res00);
687 res01 = SMLAD(rhs_low1, lhs_low, res01);
688 res01 = SMLAD(rhs_high1, lhs_high, res01);
689
690 // 4 x MAC res10, res11
691 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
692 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
693 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
694
695 res10 = SMLAD(rhs_low0, lhs_low, res10);
696 res11 = SMLAD(rhs_low1, lhs_low, res11);
697 res10 = SMLAD(rhs_high0, lhs_high, res10);
698 res11 = SMLAD(rhs_high1, lhs_high, res11);
699
700 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
701 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
702 packed_rhs_ptr += 2;
703
704 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
705 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
706 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
707
708 // 4 x MAC res00, res01
709 res00 = SMLAD(rhs_low0, lhs_low, res00);
710 res00 = SMLAD(rhs_high0, lhs_high, res00);
711 res01 = SMLAD(rhs_low1, lhs_low, res01);
712 res01 = SMLAD(rhs_high1, lhs_high, res01);
713
714 // 4 x MAC res10, res11
715 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
716 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
717 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
718
719 res10 = SMLAD(rhs_low0, lhs_low, res10);
720 res11 = SMLAD(rhs_low1, lhs_low, res11);
721 res10 = SMLAD(rhs_high0, lhs_high, res10);
722 res11 = SMLAD(rhs_high1, lhs_high, res11);
723 }
724 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
725 {
726 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
727 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
728 packed_rhs_ptr += 2;
729
730 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
731 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
732 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
733
734 // 4 x MAC res00, res01
735 res00 = SMLAD(rhs_low0, lhs_low, res00);
736 res00 = SMLAD(rhs_high0, lhs_high, res00);
737 res01 = SMLAD(rhs_low1, lhs_low, res01);
738 res01 = SMLAD(rhs_high1, lhs_high, res01);
739
740 // 4 x MAC res10, res11
741 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
742 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
743 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
744 res10 = SMLAD(rhs_low0, lhs_low, res10);
745 res11 = SMLAD(rhs_low1, lhs_low, res11);
746 res10 = SMLAD(rhs_high0, lhs_high, res10);
747 res11 = SMLAD(rhs_high1, lhs_high, res11);
748 }
749 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
750 {
751 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
752 rhs_high0 = packed_rhs_ptr[0] >> 4;
753
754 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
755 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
756
757 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
758 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
759
760 res00 += lhs_low * rhs_low0;
761 res00 += lhs_high * rhs_high0;
762 res01 += lhs_low * rhs_low1;
763 res01 += lhs_high * rhs_high1;
764
765 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
766 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
767 res10 += lhs_low * rhs_low0;
768 res10 += lhs_high * rhs_high0;
769 res11 += lhs_low * rhs_low1;
770 res11 += lhs_high * rhs_high1;
771
772 ++packed_rhs_ptr;
773 lhs_ptr += 2;
774 }
775
776 // Quantize down
777 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
778 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
779 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
780 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
781
782 // Add offset
783 res00 += dst_offset;
784 res01 += dst_offset;
785 res10 += dst_offset;
786 res11 += dst_offset;
787
788 // Clamp the result
789 res00 = MAX(res00, activation_min);
790 res00 = MIN(res00, activation_max);
791 res01 = MAX(res01, activation_min);
792 res01 = MIN(res01, activation_max);
793 res10 = MAX(res10, activation_min);
794 res10 = MIN(res10, activation_max);
795 res11 = MAX(res11, activation_min);
796 res11 = MIN(res11, activation_max);
797
798 dst_ptr[1] = (int8_t)res00;
799 dst_ptr[3] = (int8_t)res01;
800 dst_ptr += rhs_rows;
801 dst_ptr[1] = (int8_t)res10;
802 dst_ptr[3] = (int8_t)res11;
803 dst_ptr += rhs_rows;
804
805 lhs_ptr -= rhs_cols;
806 lhs_ptr += 2 * lhs_cols_offset;
807
808 lhs_rows_idx--;
809 }
810
811 // Left-over rows
812 if (lhs_rows % 2)
813 {
814 const int8_t *packed_rhs_ptr = &packed_rhs[0];
815
816 int32_t res00 = 0;
817 int32_t res01 = 0;
818
819 int32_t spillover00 = 0;
820 int32_t spillover01 = 0;
821
822 if (bias)
823 {
824 res00 = bias[rhs_rows_idx];
825 spillover00 = bias[rhs_rows_idx + 1];
826 res01 = bias[rhs_rows_idx + 2];
827 spillover01 = bias[rhs_rows_idx + 3];
828 }
829
830 int32_t rhs_cols_idx = 0;
831
832 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
833
834 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
835 {
836 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
837 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
838 packed_rhs_ptr += 2;
839
840 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
841 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
842 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
843
844 // 4 x MAC res00, res01
845 res00 = SMLAD(rhs_low0, lhs_low, res00);
846 res00 = SMLAD(rhs_high0, lhs_high, res00);
847 res01 = SMLAD(rhs_low1, lhs_low, res01);
848 res01 = SMLAD(rhs_high1, lhs_high, res01);
849
850 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
851 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
852 packed_rhs_ptr += 2;
853
854 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
855 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
856 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
857
858 // 4 x MAC res00, res01
859 res00 = SMLAD(rhs_low0, lhs_low, res00);
860 res00 = SMLAD(rhs_high0, lhs_high, res00);
861 res01 = SMLAD(rhs_low1, lhs_low, res01);
862 res01 = SMLAD(rhs_high1, lhs_high, res01);
863
864 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
865 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
866 packed_rhs_ptr += 2;
867
868 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
869 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
870 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
871
872 // 4 x MAC res00, res01
873 res00 = SMLAD(rhs_low0, lhs_low, res00);
874 res00 = SMLAD(rhs_high0, lhs_high, res00);
875 res01 = SMLAD(rhs_low1, lhs_low, res01);
876 res01 = SMLAD(rhs_high1, lhs_high, res01);
877
878 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
879 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
880 packed_rhs_ptr += 2;
881
882 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
883 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
884 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
885
886 // 4 x MAC res00, res01
887 res00 = SMLAD(rhs_low0, lhs_low, res00);
888 res00 = SMLAD(rhs_high0, lhs_high, res00);
889 res01 = SMLAD(rhs_low1, lhs_low, res01);
890 res01 = SMLAD(rhs_high1, lhs_high, res01);
891 }
892 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
893 {
894 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
895 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
896 packed_rhs_ptr += 2;
897
898 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
899 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
900 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
901
902 // 4 x MAC res00, res01
903 res00 = SMLAD(rhs_low0, lhs_low, res00);
904 res00 = SMLAD(rhs_high0, lhs_high, res00);
905 res01 = SMLAD(rhs_low1, lhs_low, res01);
906 res01 = SMLAD(rhs_high1, lhs_high, res01);
907 }
908 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
909 {
910 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
911 rhs_high0 = packed_rhs_ptr[0] >> 4;
912
913 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
914 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
915
916 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
917 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
918
919 res00 += lhs_low * rhs_low0;
920 res00 += lhs_high * rhs_high0;
921 res01 += lhs_low * rhs_low1;
922 res01 += lhs_high * rhs_high1;
923
924 ++packed_rhs_ptr;
925 lhs_ptr += 2;
926 }
927 if (rhs_cols % 2)
928 {
929 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
930 rhs_high0 = packed_rhs_ptr[0] >> 4;
931 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
932 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
933
934 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
935 lhs_ptr -= rhs_cols - 1;
936 lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
937
938 res00 += lhs_low * rhs_low0;
939 res01 += lhs_low * rhs_low1;
940 spillover00 += lhs_high * rhs_high0;
941 spillover01 += lhs_high * rhs_high1;
942
943 ++packed_rhs_ptr;
944 ++lhs_ptr;
945 }
946 else
947 {
948 lhs_ptr -= rhs_cols;
949 }
950
951 // Quantize down
952 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
953 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
954
955 // Add offset
956 res00 += dst_offset;
957 res01 += dst_offset;
958
959 // Clamp the result
960 res00 = MAX(res00, activation_min);
961 res00 = MIN(res00, activation_max);
962 res01 = MAX(res01, activation_min);
963 res01 = MIN(res01, activation_max);
964
965 dst_ptr[0] = (int8_t)res00;
966 dst_ptr[2] = (int8_t)res01;
967
968 res00 = spillover00;
969 res01 = spillover01;
970
971 rhs_cols_idx = 0;
972
973 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
974 {
975 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
976 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
977 packed_rhs_ptr += 2;
978
979 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
980 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
981 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
982
983 // 4 x MAC res00, res01
984 res00 = SMLAD(rhs_low0, lhs_low, res00);
985 res00 = SMLAD(rhs_high0, lhs_high, res00);
986 res01 = SMLAD(rhs_low1, lhs_low, res01);
987 res01 = SMLAD(rhs_high1, lhs_high, res01);
988
989 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
990 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
991 packed_rhs_ptr += 2;
992
993 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
994 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
995 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
996
997 // 4 x MAC res00, res01
998 res00 = SMLAD(rhs_low0, lhs_low, res00);
999 res00 = SMLAD(rhs_high0, lhs_high, res00);
1000 res01 = SMLAD(rhs_low1, lhs_low, res01);
1001 res01 = SMLAD(rhs_high1, lhs_high, res01);
1002
1003 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1004 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1005 packed_rhs_ptr += 2;
1006
1007 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1008 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1009 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1010
1011 // 4 x MAC res00, res01
1012 res00 = SMLAD(rhs_low0, lhs_low, res00);
1013 res00 = SMLAD(rhs_high0, lhs_high, res00);
1014 res01 = SMLAD(rhs_low1, lhs_low, res01);
1015 res01 = SMLAD(rhs_high1, lhs_high, res01);
1016
1017 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1018 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1019 packed_rhs_ptr += 2;
1020
1021 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1022 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1023 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1024
1025 // 4 x MAC res00, res01
1026 res00 = SMLAD(rhs_low0, lhs_low, res00);
1027 res00 = SMLAD(rhs_high0, lhs_high, res00);
1028 res01 = SMLAD(rhs_low1, lhs_low, res01);
1029 res01 = SMLAD(rhs_high1, lhs_high, res01);
1030 }
1031 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1032 {
1033 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1034 read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
1035 packed_rhs_ptr += 2;
1036
1037 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1038 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1039 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1040
1041 // 4 x MAC res00, res01
1042 res00 = SMLAD(rhs_low0, lhs_low, res00);
1043 res00 = SMLAD(rhs_high0, lhs_high, res00);
1044 res01 = SMLAD(rhs_low1, lhs_low, res01);
1045 res01 = SMLAD(rhs_high1, lhs_high, res01);
1046 }
1047 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1048 {
1049 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1050 rhs_high0 = packed_rhs_ptr[0] >> 4;
1051
1052 rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1053 rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1054
1055 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1056 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1057
1058 res00 += lhs_low * rhs_low0;
1059 res00 += lhs_high * rhs_high0;
1060 res01 += lhs_low * rhs_low1;
1061 res01 += lhs_high * rhs_high1;
1062
1063 ++packed_rhs_ptr;
1064 lhs_ptr += 2;
1065 }
1066
1067 // Quantize down
1068 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1069 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1070
1071 // Add offset
1072 res00 += dst_offset;
1073 res01 += dst_offset;
1074
1075 // Clamp the result
1076 res00 = MAX(res00, activation_min);
1077 res00 = MIN(res00, activation_max);
1078 res01 = MAX(res01, activation_min);
1079 res01 = MIN(res01, activation_max);
1080
1081 dst_ptr[1] = (int8_t)res00;
1082 dst_ptr[3] = (int8_t)res01;
1083 }
1084
1085 packed_rhs += 2 * rhs_cols;
1086 dst += 4;
1087 }
1088
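    /* When rhs_cols is odd, the trailing high nibble of an even-indexed leftover rhs row is
     * the first weight of the next leftover row; it is saved in rhs_spilled_col at the end of
     * each even row and consumed at the start of the following one. */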
1089 int8_t rhs_spilled_col = 0;
1090 const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
1091     // Left-over rhs rows will be in the range 0 to 3
1092 for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
1093 {
1094 const int8_t *lhs_ptr = &lhs[0];
1095 int8_t *dst_ptr = &dst[0];
1096
1097 int32_t lhs_rows_idx = lhs_rows >> 1;
1098 while (lhs_rows_idx)
1099 {
1100 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1101
1102 int32_t res00 = 0;
1103 int32_t res10 = 0;
1104
1105 if (bias)
1106 {
1107 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1108 res10 = bias[rhs_rows_finished + rhs_rows_idx];
1109 }
1110
1111             // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1112 // differently by dealing with the leftover column from rhs_row_idx[0]
1113 if (rhs_cols % 2 && rhs_rows_idx == 1)
1114 {
1115 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1116 res00 += lhs_low * rhs_spilled_col;
1117
1118 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1119 res10 += lhs_low * rhs_spilled_col;
1120
1121 ++lhs_ptr;
1122 }
1123
1124 int32_t rhs_cols_idx = 0;
1125
1126 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
1127
1128 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1129 {
1130 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1131 packed_rhs_ptr += 2;
1132
1133 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1134 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1135 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1136
1137 // 2 x MAC res00
1138 res00 = SMLAD(rhs_low0, lhs_low, res00);
1139 res00 = SMLAD(rhs_high0, lhs_high, res00);
1140
1141 // 2 x MAC res10
1142 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1143 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1144 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1145
1146 res10 = SMLAD(rhs_low0, lhs_low, res10);
1147 res10 = SMLAD(rhs_high0, lhs_high, res10);
1148
1149 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1150 packed_rhs_ptr += 2;
1151
1152 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1153 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1154 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1155
1156 // 2 x MAC res00
1157 res00 = SMLAD(rhs_low0, lhs_low, res00);
1158 res00 = SMLAD(rhs_high0, lhs_high, res00);
1159
1160 // 2 x MAC res10
1161 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1162 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1163 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1164
1165 res10 = SMLAD(rhs_low0, lhs_low, res10);
1166 res10 = SMLAD(rhs_high0, lhs_high, res10);
1167
1168 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1169 packed_rhs_ptr += 2;
1170
1171 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1172 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1173 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1174
1175 // 2 x MAC res00
1176 res00 = SMLAD(rhs_low0, lhs_low, res00);
1177 res00 = SMLAD(rhs_high0, lhs_high, res00);
1178
1179 // 2 x MAC res10
1180 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1181 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1182 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1183
1184 res10 = SMLAD(rhs_low0, lhs_low, res10);
1185 res10 = SMLAD(rhs_high0, lhs_high, res10);
1186
1187 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1188 packed_rhs_ptr += 2;
1189
1190 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1191 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1192 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1193
1194 // 2 x MAC res00
1195 res00 = SMLAD(rhs_low0, lhs_low, res00);
1196 res00 = SMLAD(rhs_high0, lhs_high, res00);
1197
1198 // 2 x MAC res10
1199 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1200 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1201 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1202
1203 res10 = SMLAD(rhs_low0, lhs_low, res10);
1204 res10 = SMLAD(rhs_high0, lhs_high, res10);
1205 }
1206
1207 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1208 {
1209 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1210 packed_rhs_ptr += 2;
1211
1212 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1213 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1214 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1215
1216 // 2 x MAC res00
1217 res00 = SMLAD(rhs_low0, lhs_low, res00);
1218 res00 = SMLAD(rhs_high0, lhs_high, res00);
1219
1220 // 2 x MAC res10
1221 lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
1222 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1223 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1224 res10 = SMLAD(rhs_low0, lhs_low, res10);
1225 res10 = SMLAD(rhs_high0, lhs_high, res10);
1226 }
1227
1228 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1229 {
1230 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1231 rhs_high0 = packed_rhs_ptr[0] >> 4;
1232
1233 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1234 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1235
1236 res00 += lhs_low * rhs_low0;
1237 res00 += lhs_high * rhs_high0;
1238
1239 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1240 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1241 res10 += lhs_low * rhs_low0;
1242 res10 += lhs_high * rhs_high0;
1243
1244 ++packed_rhs_ptr;
1245 lhs_ptr += 2;
1246 }
1247
1248 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1249 {
1250 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1251
1252 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1253 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1254
1255 res00 += lhs_low * rhs_low0;
1256 res10 += lhs_high * rhs_low0;
1257
1258 ++lhs_ptr;
1259 }
1260
1261 lhs_ptr -= rhs_cols;
1262 lhs_ptr += 2 * lhs_cols_offset;
1263
1264 // Quantize down
1265 res00 = arm_nn_requantize(
1266 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1267 res10 = arm_nn_requantize(
1268 res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1269
1270 // Add offset
1271 res00 += dst_offset;
1272 res10 += dst_offset;
1273
1274 // Clamp the result
1275 res00 = MAX(res00, activation_min);
1276 res00 = MIN(res00, activation_max);
1277 res10 = MAX(res10, activation_min);
1278 res10 = MIN(res10, activation_max);
1279
1280 dst_ptr[0] = (int8_t)res00;
1281 dst_ptr += rhs_rows;
1282 dst_ptr[0] = (int8_t)res10;
1283 dst_ptr += rhs_rows;
1284
1285 lhs_rows_idx--;
1286 }
1287 if (lhs_rows % 2)
1288 {
1289 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1290
1291 int32_t res00 = 0;
1292
1293 if (bias)
1294 {
1295 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1296 }
1297
1298             // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1299 // differently by dealing with the leftover column from rhs_row_idx[0]
1300 if (rhs_cols % 2 && rhs_rows_idx == 1)
1301 {
1302 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1303 res00 += lhs_low * rhs_spilled_col;
1304
1305 ++lhs_ptr;
1306 }
1307
1308 int32_t rhs_cols_idx = 0;
1309
1310 int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
1311
1312 for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
1313 {
1314 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1315 packed_rhs_ptr += 2;
1316
1317 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1318 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1319 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1320
1321 // 2 x MAC res00
1322 res00 = SMLAD(rhs_low0, lhs_low, res00);
1323 res00 = SMLAD(rhs_high0, lhs_high, res00);
1324
1325 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1326 packed_rhs_ptr += 2;
1327
1328 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1329 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1330 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1331
1332 // 2 x MAC res00
1333 res00 = SMLAD(rhs_low0, lhs_low, res00);
1334 res00 = SMLAD(rhs_high0, lhs_high, res00);
1335
1336 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1337 packed_rhs_ptr += 2;
1338
1339 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1340 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1341 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1342
1343 // 2 x MAC res00
1344 res00 = SMLAD(rhs_low0, lhs_low, res00);
1345 res00 = SMLAD(rhs_high0, lhs_high, res00);
1346
1347 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1348 packed_rhs_ptr += 2;
1349
1350 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1351 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1352 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1353
1354 // 2 x MAC res00
1355 res00 = SMLAD(rhs_low0, lhs_low, res00);
1356 res00 = SMLAD(rhs_high0, lhs_high, res00);
1357 }
1358
1359 for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
1360 {
1361 read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
1362 packed_rhs_ptr += 2;
1363
1364 lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
1365 lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
1366 lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
1367
1368 // 2 x MAC res00
1369 res00 = SMLAD(rhs_low0, lhs_low, res00);
1370 res00 = SMLAD(rhs_high0, lhs_high, res00);
1371 }
1372
1373 for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
1374 {
1375 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1376 rhs_high0 = packed_rhs_ptr[0] >> 4;
1377
1378 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1379 lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1380
1381 res00 += lhs_low * rhs_low0;
1382 res00 += lhs_high * rhs_high0;
1383
1384 ++packed_rhs_ptr;
1385 lhs_ptr += 2;
1386 }
1387
1388 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1389 {
1390 rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1391
1392 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1393 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1394
1395 res00 += lhs_low * rhs_low0;
1396
1397 ++lhs_ptr;
1398 }
1399
1400 // Quantize down
1401 res00 = arm_nn_requantize(
1402 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1403
1404 // Add offset
1405 res00 += dst_offset;
1406
1407 // Clamp the result
1408 res00 = MAX(res00, activation_min);
1409 res00 = MIN(res00, activation_max);
1410
1411 dst_ptr[0] = (int8_t)res00;
1412 }
1413 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1414 {
1415 rhs_spilled_col = packed_rhs[rhs_cols_int4] >> 4;
1416 packed_rhs += rhs_cols_int4 + 1;
1417 }
1418 else
1419 {
1420 rhs_spilled_col = 0;
1421 packed_rhs += rhs_cols_int4;
1422 }
1423
1424 ++dst;
1425 }
1426 #else
1427
1428 const int32_t rhs_cols_int4 = rhs_cols >> 1;
1429
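    /* Pure C fallback. rhs_cols_int4 is the number of whole packed bytes per rhs row; when
     * rhs_cols is odd, the high nibble of the last, shared byte belongs to the next rhs row
     * and is carried over via the spillover accumulators below. */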
1430 for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
1431 {
1432 const int8_t *lhs_ptr = &lhs[0];
1433 int8_t *dst_ptr = &dst[0];
1434
1435 for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
1436 {
1437 // To avoid the issue of packed values leaking into the next rhs row
1438 // we instead evaluate the rhs rows in pairs like so:
1439 // rhs[0] and rhs[2], rhs[1] and rhs[3]
1440
1441 // Start processing rhs_row[0] and rhs_row[2]
1442 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1443
1444 int32_t res00 = 0;
1445 int32_t res01 = 0;
1446 int32_t res10 = 0;
1447 int32_t res11 = 0;
1448
1449 int32_t spillover00 = 0;
1450 int32_t spillover01 = 0;
1451 int32_t spillover10 = 0;
1452 int32_t spillover11 = 0;
1453
1454 if (bias)
1455 {
1456 res00 = bias[rhs_rows_idx];
1457 res01 = bias[rhs_rows_idx + 2];
1458 res10 = bias[rhs_rows_idx];
1459 res11 = bias[rhs_rows_idx + 2];
1460 }
1461
1462 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1463 {
1464 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1465 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1466 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1467 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1468
1469 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1470 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1471
1472 res00 += lhs_low * rhs_low0;
1473 res00 += lhs_high * rhs_high0;
1474 res01 += lhs_low * rhs_low1;
1475 res01 += lhs_high * rhs_high1;
1476
1477 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1478 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1479
1480 res10 += lhs_low * rhs_low0;
1481 res10 += lhs_high * rhs_high0;
1482 res11 += lhs_low * rhs_low1;
1483 res11 += lhs_high * rhs_high1;
1484
1485 ++packed_rhs_ptr;
1486 lhs_ptr += 2;
1487 }
1488 if (rhs_cols % 2)
1489 {
1490 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1491 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1492 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1493 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1494
1495 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1496 int32_t lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1497
1498 res00 += lhs_low * rhs_low0;
1499 res01 += lhs_low * rhs_low1;
1500
1501 res10 += lhs_high * rhs_low0;
1502 res11 += lhs_high * rhs_low1;
1503
1504 lhs_ptr -= rhs_cols - 1;
1505
1506 lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1507 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1508
1509 spillover00 += lhs_low * rhs_high0;
1510 spillover01 += lhs_low * rhs_high1;
1511
1512 spillover10 += lhs_high * rhs_high0;
1513 spillover11 += lhs_high * rhs_high1;
1514
1515 ++packed_rhs_ptr;
1516 ++lhs_ptr;
1517 }
1518 else
1519 {
1520 lhs_ptr -= rhs_cols;
1521 }
1522
1523 // Quantize down
1524 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1525 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1526 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1527 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1528
1529 // Add offset
1530 res00 += dst_offset;
1531 res01 += dst_offset;
1532 res10 += dst_offset;
1533 res11 += dst_offset;
1534
1535 // Clamp the result
1536 res00 = MAX(res00, activation_min);
1537 res00 = MIN(res00, activation_max);
1538 res01 = MAX(res01, activation_min);
1539 res01 = MIN(res01, activation_max);
1540 res10 = MAX(res10, activation_min);
1541 res10 = MIN(res10, activation_max);
1542 res11 = MAX(res11, activation_min);
1543 res11 = MIN(res11, activation_max);
1544
1545 dst_ptr[0] = (int8_t)res00;
1546 dst_ptr[2] = (int8_t)res01;
1547 dst_ptr += rhs_rows;
1548 dst_ptr[0] = (int8_t)res10;
1549 dst_ptr[2] = (int8_t)res11;
1550 dst_ptr -= rhs_rows;
1551
1552 // Start processing rhs_row[1] and rhs_row[3]
1553 res00 = spillover00;
1554 res01 = spillover01;
1555 res10 = spillover10;
1556 res11 = spillover11;
1557 if (bias)
1558 {
1559 res00 += bias[rhs_rows_idx + 1];
1560 res01 += bias[rhs_rows_idx + 3];
1561 res10 += bias[rhs_rows_idx + 1];
1562 res11 += bias[rhs_rows_idx + 3];
1563 }
1564
1565 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1566 {
1567 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1568 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1569 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1570 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1571
1572 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1573 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1574
1575 res00 += lhs_low * rhs_low0;
1576 res00 += lhs_high * rhs_high0;
1577 res01 += lhs_low * rhs_low1;
1578 res01 += lhs_high * rhs_high1;
1579
1580 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1581 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1582
1583 res10 += lhs_low * rhs_low0;
1584 res10 += lhs_high * rhs_high0;
1585 res11 += lhs_low * rhs_low1;
1586 res11 += lhs_high * rhs_high1;
1587
1588 ++packed_rhs_ptr;
1589 lhs_ptr += 2;
1590 }
1591
1592 // Quantize down
1593 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1594 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1595 res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1596 res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1597
1598 // Add offset
1599 res00 += dst_offset;
1600 res01 += dst_offset;
1601 res10 += dst_offset;
1602 res11 += dst_offset;
1603
1604 // Clamp the result
1605 res00 = MAX(res00, activation_min);
1606 res00 = MIN(res00, activation_max);
1607 res01 = MAX(res01, activation_min);
1608 res01 = MIN(res01, activation_max);
1609 res10 = MAX(res10, activation_min);
1610 res10 = MIN(res10, activation_max);
1611 res11 = MAX(res11, activation_min);
1612 res11 = MIN(res11, activation_max);
1613
1614 dst_ptr[1] = (int8_t)res00;
1615 dst_ptr[3] = (int8_t)res01;
1616 dst_ptr += rhs_rows;
1617 dst_ptr[1] = (int8_t)res10;
1618 dst_ptr[3] = (int8_t)res11;
1619 dst_ptr += rhs_rows;
1620
1621 lhs_ptr -= rhs_cols;
1622 lhs_ptr += 2 * lhs_cols_offset;
1623 }
1624
1625 // Left-over row
1626 if (lhs_rows % 2)
1627 {
1628 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1629
1630 int32_t res00 = 0;
1631 int32_t res01 = 0;
1632
1633 int32_t spillover00 = 0;
1634 int32_t spillover01 = 0;
1635
1636 if (bias)
1637 {
1638 res00 += bias[rhs_rows_idx];
1639 res01 += bias[rhs_rows_idx + 2];
1640 }
1641
1642 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1643 {
1644 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1645 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1646
1647 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1648 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1649
1650 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1651 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1652
1653 res00 += lhs_low * rhs_low0;
1654 res00 += lhs_high * rhs_high0;
1655
1656 res01 += lhs_low * rhs_low1;
1657 res01 += lhs_high * rhs_high1;
1658
1659 ++packed_rhs_ptr;
1660 lhs_ptr += 2;
1661 }
1662 if (rhs_cols % 2)
1663 {
1664 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1665 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1666 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1667 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1668
1669 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1670 lhs_ptr -= rhs_cols - 1;
1671 int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
1672
1673 res00 += lhs_low * rhs_low0;
1674 res01 += lhs_low * rhs_low1;
1675 spillover00 = lhs_high * rhs_high0;
1676 spillover01 = lhs_high * rhs_high1;
1677
1678 ++packed_rhs_ptr;
1679 ++lhs_ptr;
1680 }
1681 else
1682 {
1683 lhs_ptr -= rhs_cols;
1684 }
1685
1686 // Quantize down
1687 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
1688 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
1689
1690 // Add offset
1691 res00 += dst_offset;
1692 res01 += dst_offset;
1693
1694 // Clamp the result
1695 res00 = MAX(res00, activation_min);
1696 res00 = MIN(res00, activation_max);
1697 res01 = MAX(res01, activation_min);
1698 res01 = MIN(res01, activation_max);
1699
1700 dst_ptr[0] = (int8_t)res00;
1701 dst_ptr[2] = (int8_t)res01;
1702
1703 res00 = spillover00;
1704 res01 = spillover01;
1705
1706 if (bias)
1707 {
1708 res00 += bias[rhs_rows_idx + 1];
1709 res01 += bias[rhs_rows_idx + 3];
1710 }
1711
1712 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1713 {
1714 int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1715 int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
1716
1717 int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
1718 int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;
1719
1720 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1721 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1722
1723 res00 += lhs_low * rhs_low0;
1724 res00 += lhs_high * rhs_high0;
1725
1726 res01 += lhs_low * rhs_low1;
1727 res01 += lhs_high * rhs_high1;
1728
1729 ++packed_rhs_ptr;
1730 lhs_ptr += 2;
1731 }
1732
1733 // Quantize down
1734 res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
1735 res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
1736
1737 // Add offset
1738 res00 += dst_offset;
1739 res01 += dst_offset;
1740
1741 // Clamp the result
1742 res00 = MAX(res00, activation_min);
1743 res00 = MIN(res00, activation_max);
1744 res01 = MAX(res01, activation_min);
1745 res01 = MIN(res01, activation_max);
1746
1747 dst_ptr[1] = (int8_t)res00;
1748 dst_ptr[3] = (int8_t)res01;
1749 }
1750
1751 packed_rhs += 2 * rhs_cols;
1752 dst += 4;
1753 }
1754
1755 int32_t spillover00 = 0;
1756 const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
1757     // Left-over rhs rows will be in the range 0 to 3
1758 for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
1759 {
1760 const int8_t *lhs_ptr = &lhs[0];
1761 int8_t *dst_ptr = &dst[0];
1762
1763 for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
1764 {
1765 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1766 int32_t res00 = 0;
1767 int32_t res10 = 0;
1768 if (bias)
1769 {
1770 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1771 res10 = bias[rhs_rows_finished + rhs_rows_idx];
1772 }
1773             // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1774 // differently by dealing with the leftover column from rhs_row_idx[0]
1775 if (rhs_cols % 2 && rhs_rows_idx == 1)
1776 {
1777 int32_t lhs_low = lhs_ptr[0] + lhs_offset;
1778 res00 += lhs_low * spillover00;
1779
1780 lhs_low = lhs_ptr[lhs_cols_offset] + lhs_offset;
1781 res10 += lhs_low * spillover00;
1782
1783 ++lhs_ptr;
1784 }
1785 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1786 {
1787 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1788 int8_t rhs_high = packed_rhs_ptr[0] >> 4;
1789
1790 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1791 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1792
1793 res00 += lhs_low * rhs_low;
1794 res00 += lhs_high * rhs_high;
1795
1796 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1797 lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
1798
1799 res10 += lhs_low * rhs_low;
1800 res10 += lhs_high * rhs_high;
1801
1802 ++packed_rhs_ptr;
1803 lhs_ptr += 2;
1804 }
1805 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1806 {
1807 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1808
1809 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1810
1811 res00 += lhs_low * rhs_low;
1812
1813 lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
1814
1815 res10 += lhs_low * rhs_low;
1816
1817 ++lhs_ptr;
1818 }
1819
1820 lhs_ptr -= rhs_cols;
1821 lhs_ptr += 2 * lhs_cols_offset;
1822
1823 // Quantize down
1824 res00 = arm_nn_requantize(
1825 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1826 res10 = arm_nn_requantize(
1827 res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1828
1829 // Add offset
1830 res00 += dst_offset;
1831 res10 += dst_offset;
1832
1833 // Clamp the result
1834 res00 = MAX(res00, activation_min);
1835 res00 = MIN(res00, activation_max);
1836 res10 = MAX(res10, activation_min);
1837 res10 = MIN(res10, activation_max);
1838
1839 dst_ptr[0] = (int8_t)res00;
1840 dst_ptr += rhs_rows;
1841 dst_ptr[0] = (int8_t)res10;
1842 dst_ptr += rhs_rows;
1843 }
1844 if (lhs_rows % 2)
1845 {
1846 const int8_t *packed_rhs_ptr = &packed_rhs[0];
1847 int32_t res00 = 0;
1848 if (bias)
1849 {
1850 res00 = bias[rhs_rows_finished + rhs_rows_idx];
1851 }
1852             // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
1853 // differently by dealing with the leftover column from rhs_row_idx[0]
1854 if (rhs_cols % 2 && rhs_rows_idx == 1)
1855 {
1856 int32_t lhs_low = lhs_ptr[0] + lhs_offset;
1857 res00 += lhs_low * spillover00;
1858
1859 ++lhs_ptr;
1860 }
1861 for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
1862 {
1863 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1864 int8_t rhs_high = packed_rhs_ptr[0] >> 4;
1865
1866 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1867 int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
1868
1869 res00 += lhs_low * rhs_low;
1870 res00 += lhs_high * rhs_high;
1871
1872 ++packed_rhs_ptr;
1873 lhs_ptr += 2;
1874 }
1875 if (rhs_cols % 2 && (rhs_rows_idx != 1))
1876 {
1877 int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
1878
1879 int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
1880 res00 += lhs_low * rhs_low;
1881
1882 ++lhs_ptr;
1883 }
1884
1885 // Quantize down
1886 res00 = arm_nn_requantize(
1887 res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
1888
1889 // Add offset
1890 res00 += dst_offset;
1891
1892 // Clamp the result
1893 res00 = MAX(res00, activation_min);
1894 res00 = MIN(res00, activation_max);
1895
1896 dst_ptr[0] = (int8_t)res00;
1897 }
1898 if (rhs_cols % 2 && !(rhs_rows_idx % 2))
1899 {
1900 spillover00 = packed_rhs[rhs_cols_int4] >> 4;
1901 packed_rhs += rhs_cols_int4 + (rhs_cols & 0x1);
1902 }
1903 else
1904 {
1905 spillover00 = 0;
1906 packed_rhs += rhs_cols_int4;
1907 }
1908
1909 ++dst;
1910 }
1911
1912 #endif
1913 return ARM_CMSIS_NN_SUCCESS;
1914 }
1915
1916 /**
1917 * @} end of Doxygen group
1918 */