/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s4
 * Description:  s4 vector by matrix (transposed) multiplication
 *
 * $Date:        26 April 2024
 * $Revision:    V.2.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"
/**
 * @ingroup groupSupport
 */

/**
 * @defgroup supportFC Fully Connected
 *
 * Support functions for Fully Connected
 *
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s4 vector (lhs) by matrix (transposed) multiplication
 *
 * Refer to the header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
                                             const int8_t *packed_rhs,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max)
{
    const int32_t row_loop_cnt = rhs_rows / 4;
    const int rhs_offset = rhs_cols * row_loop_cnt;
    const int8_t *rhs_ptr = &packed_rhs[0];
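
    /* Two s4 weights are packed into each byte of packed_rhs. Each outer loop
     * iteration processes two adjacent rows together with the two rows located
     * 2 * row_loop_cnt rows (rhs_offset bytes) further on, producing four
     * outputs per iteration. */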

#if defined(ARM_MATH_MVEI)
    const int rhs_cols_offset = rhs_cols % 16;
#else
    const int rhs_cols_offset = rhs_cols;
#endif

#if defined(ARM_MATH_DSP)
    const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
    const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
#endif

    int32_t spillover0, spillover1;

#if defined(ARM_MATH_MVEI)
    const mve_pred16_t lower_nibble_mask = 21845; // 0x5555: predicate selecting the even byte lanes
    const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
    const int32_t col_loop_cnt = rhs_cols >> 5;
    const int I6_elements_spill = rhs_cols & 0x10; // Non-zero if a block of 16 columns remains
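
    /* Each column-loop iteration consumes 16 packed bytes (32 s4 weights) and
     * 32 lhs values de-interleaved into even/odd lanes. A 16-column spill
     * block and a predicated tail handle the remaining columns; gather_offset
     * loads each packed byte into two adjacent lanes so the predicated shift
     * below can split it into its two nibbles. */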

    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const uint32x4_t scatter_offset = {0, 1, 2 * row_loop_cnt, 2 * row_loop_cnt + 1};
        const int8_t *lhs_ptr = &lhs[0];

        mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);

        int32_t acc0 = 0;
        int32_t acc2 = 0;

        int32_t rhs_sum_0 = 0;
        int32_t rhs_sum_2 = 0;

        if (bias)
        {
            acc0 += *bias;
            acc2 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            const int8x16_t ker_0 = vldrbq_s8(rhs_ptr);

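            /* Shift left by 4 then arithmetic right by 4 to sign-extend the
             * low nibble (even columns, inputx2.val[0]); an arithmetic right
             * shift alone recovers the high nibble (odd columns). */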
            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_low_0);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_high_0);
            acc0 = vmladavaq_s8(acc0, ker_low_0, inputx2.val[0]);
            acc0 = vmladavaq_s8(acc0, ker_high_0, inputx2.val[1]);

            const int8x16_t ker_1 = vldrbq_s8(&rhs_ptr[rhs_offset]);
            int8x16_t ker_low_1 = vrshlq_n_s8(ker_1, 4);
            ker_low_1 = vshrq_n_s8(ker_low_1, 4);
            int8x16_t ker_high_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_low_1);
            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_high_1);
            acc2 = vmladavaq_s8(acc2, ker_low_1, inputx2.val[0]);
            acc2 = vmladavaq_s8(acc2, ker_high_1, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(rhs_ptr, gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc0 = vmladavaq_s8(acc0, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_1);
            acc2 = vmladavaq_s8(acc2, ker_1, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

        if (rmdr_mask)
        {
            const int8x16_t input = vldrbq_z_s8(lhs_ptr, rmdr_mask);

            int8x16_t ker_0 = vldrbq_gather_offset_z_s8(rhs_ptr, gather_offset, rmdr_mask);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc0 = vmladavaq_s8(acc0, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_z_s8(&rhs_ptr[rhs_offset], gather_offset, rmdr_mask);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_1);
            acc2 = vmladavaq_s8(acc2, ker_1, input);

            rhs_ptr += rhs_cols_offset >> 1;
        }

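        /* With an odd rhs_cols the rows are not byte aligned: the high nibble
         * of the boundary byte is the first weight of the next row. Multiply
         * it by the first lhs value now and carry the product over as the
         * seed (spillover) for the next row's accumulator. */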
        if (rhs_cols & 1)
        {
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

            spillover0 = lhs_high * rhs_high0;
            spillover1 = lhs_high * rhs_high1;

            rmdr_mask >>= 1;

            ++lhs_ptr;
            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        int32_t acc1 = spillover0;
        int32_t acc3 = spillover1;

        if (bias)
        {
            acc1 += *bias;
            acc3 += bias[2 * row_loop_cnt];
            ++bias;
        }

        int32_t rhs_sum_1 = 0;
        int32_t rhs_sum_3 = 0;

        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            int8x16_t ker_0 = vldrbq_s8(rhs_ptr);

            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_low_0);
            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_high_0);
            acc1 = vmladavaq_s8(acc1, ker_low_0, inputx2.val[0]);
            acc1 = vmladavaq_s8(acc1, ker_high_0, inputx2.val[1]);

            int8x16_t ker_1 = vldrbq_s8(&rhs_ptr[rhs_offset]);
            int8x16_t ker_low_1 = vrshlq_n_s8(ker_1, 4);
            ker_low_1 = vshrq_n_s8(ker_low_1, 4);
            int8x16_t ker_high_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_low_1);
            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_high_1);
            acc3 = vmladavaq_s8(acc3, ker_low_1, inputx2.val[0]);
            acc3 = vmladavaq_s8(acc3, ker_high_1, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(rhs_ptr, gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_0);
            acc1 = vmladavaq_s8(acc1, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_1);
            acc3 = vmladavaq_s8(acc3, ker_1, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

        if (rmdr_mask)
        {
            const int8x16_t input = vldrbq_z_s8(lhs_ptr, rmdr_mask);

            int8x16_t ker_0 = vldrbq_gather_offset_z_s8(rhs_ptr, gather_offset, rmdr_mask);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_0);
            acc1 = vmladavaq_s8(acc1, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_z_s8(&rhs_ptr[rhs_offset], gather_offset, rmdr_mask);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_1);
            acc3 = vmladavaq_s8(acc3, ker_1, input);

            rhs_ptr += rhs_cols_offset >> 1;
        }

        int32x4_t acc = {acc0, acc1, acc2, acc3};

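        /* Fold in the input zero point: lhs_offset was not added lane-wise to
         * the s8 inputs above, so add lhs_offset * sum(weights) to each
         * accumulator instead. */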
        const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, rhs_sum_3};
        acc += vdupq_n_s32(lhs_offset) * rhs_sum;

        acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
        acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
        acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
        acc = vminq_s32(acc, vdupq_n_s32(activation_max));

        vstrbq_scatter_offset_s32(dst, scatter_offset, acc);

        dst += 2;
    }

#elif defined(ARM_MATH_DSP)
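
    /* DSP extension path: read_and_pad_s4 expands four packed s4 weights into
     * two words, each holding two sign-extended 16-bit values, and SMLAD then
     * performs a dual 16 x 16 multiply-accumulate per call. */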

    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int32_t res0 = 0;
        int32_t res1 = 0;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int rhs_cols_idx = 0; rhs_cols_idx < (rhs_cols / 4); ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;

            read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
            rhs_ptr += 2;

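            /* SXTAB16 sign-extends lhs bytes 0 and 2 and adds the packed
             * lhs_offset; the ROR #8 variant does the same for bytes 1 and 3,
             * producing two 16-bit lanes per word ready for SMLAD. */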
            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_low0, res0);
            res0 = SMLAD(lhs_high, rhs_high0, res0);
            res1 = SMLAD(lhs_low, rhs_low1, res1);
            res1 = SMLAD(lhs_high, rhs_high1, res1);
        }

        if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[0];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t rhs_value1 = rhs_ptr[rhs_offset];
            const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
            const int32_t higher1 = rhs_value1 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;
            res1 += lhs_value_0 * lower1;
            res1 += lhs_value_1 * higher1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        if (rhs_cols % 2 == 1)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            spillover1 = lhs_high * rhs_high1;

            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;

        res0 = spillover0;
        res1 = spillover1;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 4; ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;

            read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
            rhs_ptr += 2;

            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_low0, res0);
            res0 = SMLAD(lhs_high, rhs_high0, res0);
            res1 = SMLAD(lhs_low, rhs_low1, res1);
            res1 = SMLAD(lhs_high, rhs_high1, res1);
        }

        if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[0];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t rhs_value1 = rhs_ptr[rhs_offset];
            const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
            const int32_t higher1 = rhs_value1 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;
            res1 += lhs_value_0 * lower1;
            res1 += lhs_value_1 * higher1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;
    }

#else
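
    /* Pure C path: each rhs byte is split into two sign-extended s4 weights.
     * (int8_t)(b << 4) >> 4 extracts the low nibble and b >> 4 the high one;
     * e.g. b = 0x3A yields low = -6 (0xA) and high = 3. */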

    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const int8_t *lhs_ptr = &lhs[0];

        int32_t res0 = 0;
        int32_t res1 = 0;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            res1 += lhs_high * rhs_high1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        if (rhs_cols % 2 == 1)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            spillover1 = lhs_high * rhs_high1;

            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;

        res0 = spillover0;
        res1 = spillover1;
        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            res1 += lhs_high * rhs_high1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;
    }

#endif

    const int8_t *lhs_ptr = &lhs[0];
    spillover0 = 0;

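    /* Process the rhs_rows % 4 rows left over after the paired loop. They sit
     * at the end of the packed matrix, rhs_offset bytes beyond the advancing
     * rhs_ptr, and their outputs land after the 4 * row_loop_cnt results
     * already written. */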
    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows % 4; ++i_row_loop_cnt)
    {
        int32_t res0 = spillover0;
        if (bias)
        {
            res0 += bias[2 * row_loop_cnt];
            ++bias;
        }

#if defined(ARM_MATH_MVEI)
        int32_t rhs_sum_0 = 0;

        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            const int8x16_t ker_0 = vldrbq_s8(&rhs_ptr[rhs_offset]);

            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_low_0);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_high_0);
            res0 = vmladavaq_s8(res0, ker_low_0, inputx2.val[0]);
            res0 = vmladavaq_s8(res0, ker_high_0, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            res0 = vmladavaq_s8(res0, ker_0, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

        res0 += lhs_offset * rhs_sum_0;
#endif

#if defined(ARM_MATH_DSP)
        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols_offset / 4; ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low;

            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low0, &rhs_high0);
            rhs_ptr += 2;

            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_low0, res0);
            res0 = SMLAD(lhs_high, rhs_high0, res0);
        }

        if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[rhs_offset];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;

            ++rhs_ptr;
            lhs_ptr += 2;
        }
#else
        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols_offset / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;

            ++rhs_ptr;
            lhs_ptr += 2;
        }
#endif

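        /* With an odd rhs_cols, consecutive leftover rows alternate between
         * byte-aligned and nibble-offset starts: on even iterations the tail
         * nibble pairs with the last lhs value, and the boundary byte's high
         * nibble is carried over as spillover for the next row. */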
        if ((rhs_cols % 2 == 1) && (i_row_loop_cnt % 2 == 0))
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);

        *(dst + 2 * row_loop_cnt) = (int8_t)res0;
        dst++;
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */