/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s4
 * Description:  s4 vector by matrix (transposed) multiplication
 *
 * $Date:        26 April 2024
 * $Revision:    V.2.0.0
 *
 * Target : Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @defgroup supportFC Fully Connected
 *
 * Support functions for Fully Connected
 *
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s4 vector(lhs) by matrix (transposed) multiplication
 *
 * Refer header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
                                              const int8_t *packed_rhs,
                                              const int32_t *bias,
                                              int8_t *dst,
                                              const int32_t lhs_offset,
                                              const int32_t dst_offset,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max)
{
    const int32_t row_loop_cnt = rhs_rows / 4;
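    /* Note on the layout this kernel assumes (derived from how the pointers are used below):
       each byte of packed_rhs holds two s4 weights, the even column in the low nibble and the
       odd column in the high nibble. Four output rows are produced per row-loop iteration: two
       consecutive rows read through rhs_ptr and the two corresponding rows located rhs_offset
       bytes further on, whose results are written 2 * row_loop_cnt positions apart in dst. */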
    const int rhs_offset = rhs_cols * row_loop_cnt;
    const int8_t *rhs_ptr = &packed_rhs[0];

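    /* Columns left for the predicated tail handling when MVE is available (the vector loops
       consume the columns in blocks of 32 and 16); without MVE this is simply the full row
       length. */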
#if defined(ARM_MATH_MVEI)
    const int rhs_cols_offset = rhs_cols % 16;
#else
    const int rhs_cols_offset = rhs_cols;
#endif

#if defined(ARM_MATH_DSP)
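    /* Duplicate the lhs offset into both 16-bit halves so SXTAB16 can add it to two
       sign-extended lhs bytes at once. */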
    const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
    const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
#endif

    int32_t spillover0, spillover1;

#if defined(ARM_MATH_MVEI)
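    /* lower_nibble_mask (0x5555) selects the even byte lanes, which after the gather below hold
       the low nibble of each packed byte. gather_offset loads every packed byte into two
       adjacent lanes so one copy can be turned into the low-nibble weight and the other into
       the high-nibble weight. */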
    const mve_pred16_t lower_nibble_mask = 21845; // 0101010101010101
    const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
    const int32_t col_loop_cnt = rhs_cols >> 5;
    const int I6_elements_spill = rhs_cols & 0x10;

    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const uint32x4_t scatter_offset = {0, 1, 2 * row_loop_cnt, 2 * row_loop_cnt + 1};
        const int8_t *lhs_ptr = &lhs[0];

        mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);

        int32_t acc0 = 0;
        int32_t acc2 = 0;

        int32_t rhs_sum_0 = 0;
        int32_t rhs_sum_2 = 0;

        if (bias)
        {
            acc0 += *bias;
            acc2 += bias[2 * row_loop_cnt];
            ++bias;
        }

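        /* Main loop: 32 columns per iteration. vld2q de-interleaves the lhs into even/odd
           elements, matching the low/high nibbles of the 16 packed rhs bytes. Low nibbles are
           sign-extended by shifting left then arithmetically right by 4; high nibbles by the
           arithmetic right shift alone. The running nibble sums feed the lhs_offset * sum(weights)
           correction applied after the loops. */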
        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            const int8x16_t ker_0 = vldrbq_s8(rhs_ptr);

            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_low_0);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_high_0);
            acc0 = vmladavaq_s8(acc0, ker_low_0, inputx2.val[0]);
            acc0 = vmladavaq_s8(acc0, ker_high_0, inputx2.val[1]);

            const int8x16_t ker_1 = vldrbq_s8(&rhs_ptr[rhs_offset]);
            int8x16_t ker_low_1 = vrshlq_n_s8(ker_1, 4);
            ker_low_1 = vshrq_n_s8(ker_low_1, 4);
            int8x16_t ker_high_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_low_1);
            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_high_1);
            acc2 = vmladavaq_s8(acc2, ker_low_1, inputx2.val[0]);
            acc2 = vmladavaq_s8(acc2, ker_high_1, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

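        /* Optional 16-column block when 16 <= rhs_cols % 32 < 32: 8 packed bytes are expanded
           into 16 lanes via the gather, then split into low/high nibbles with the predicated
           shift. */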
        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(rhs_ptr, gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc0 = vmladavaq_s8(acc0, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_1);
            acc2 = vmladavaq_s8(acc2, ker_1, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

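        /* Predicated tail for the last rhs_cols % 16 columns; inactive lanes are zeroed by the
           z-variants of the loads. */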
        if (rmdr_mask)
        {
            const int8x16_t input = vldrbq_z_s8(lhs_ptr, rmdr_mask);

            int8x16_t ker_0 = vldrbq_gather_offset_z_s8(rhs_ptr, gather_offset, rmdr_mask);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            acc0 = vmladavaq_s8(acc0, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_z_s8(&rhs_ptr[rhs_offset], gather_offset, rmdr_mask);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_2 = vaddvaq_s8(rhs_sum_2, ker_1);
            acc2 = vmladavaq_s8(acc2, ker_1, input);

            rhs_ptr += rhs_cols_offset >> 1;
        }

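        /* With an odd column count the last packed byte straddles two rows: its low nibble
           (already accumulated above) ends the current row, and its high nibble is the first
           weight of the next row. That product is computed here against lhs[0] and carried over
           as spillover into the next accumulator, and the tail predicate is shortened by one
           lane for the second pass. */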
        if (rhs_cols & 1)
        {
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

            spillover0 = lhs_high * rhs_high0;
            spillover1 = lhs_high * rhs_high1;

            rmdr_mask >>= 1;

            ++lhs_ptr;
            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        int32_t acc1 = spillover0;
        int32_t acc3 = spillover1;

        if (bias)
        {
            acc1 += *bias;
            acc3 += bias[2 * row_loop_cnt];
            ++bias;
        }

        int32_t rhs_sum_1 = 0;
        int32_t rhs_sum_3 = 0;

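        /* Second pass over the whole lhs vector for the next pair of rows, which continue at
           the current rhs_ptr position (any mid-byte spillover was added above). */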
        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            int8x16_t ker_0 = vldrbq_s8(rhs_ptr);

            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_low_0);
            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_high_0);
            acc1 = vmladavaq_s8(acc1, ker_low_0, inputx2.val[0]);
            acc1 = vmladavaq_s8(acc1, ker_high_0, inputx2.val[1]);

            int8x16_t ker_1 = vldrbq_s8(&rhs_ptr[rhs_offset]);
            int8x16_t ker_low_1 = vrshlq_n_s8(ker_1, 4);
            ker_low_1 = vshrq_n_s8(ker_low_1, 4);
            int8x16_t ker_high_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_low_1);
            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_high_1);
            acc3 = vmladavaq_s8(acc3, ker_low_1, inputx2.val[0]);
            acc3 = vmladavaq_s8(acc3, ker_high_1, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(rhs_ptr, gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_0);
            acc1 = vmladavaq_s8(acc1, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_1);
            acc3 = vmladavaq_s8(acc3, ker_1, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

        if (rmdr_mask)
        {
            const int8x16_t input = vldrbq_z_s8(lhs_ptr, rmdr_mask);

            int8x16_t ker_0 = vldrbq_gather_offset_z_s8(rhs_ptr, gather_offset, rmdr_mask);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_1 = vaddvaq_s8(rhs_sum_1, ker_0);
            acc1 = vmladavaq_s8(acc1, ker_0, input);

            int8x16_t ker_1 = vldrbq_gather_offset_z_s8(&rhs_ptr[rhs_offset], gather_offset, rmdr_mask);
            ker_1 = vrshlq_m_n_s8(ker_1, 4, lower_nibble_mask);
            ker_1 = vshrq_n_s8(ker_1, 4);

            rhs_sum_3 = vaddvaq_s8(rhs_sum_3, ker_1);
            acc3 = vmladavaq_s8(acc3, ker_1, input);

            rhs_ptr += rhs_cols_offset >> 1;
        }

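        /* Combine the four accumulators, add the lhs_offset correction using the accumulated
           weight sums, requantize, add the output offset, clamp to the activation range and
           scatter the four results to their row positions in dst. */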
        int32x4_t acc = {acc0, acc1, acc2, acc3};

        const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, rhs_sum_3};
        acc += vdupq_n_s32(lhs_offset) * rhs_sum;

        acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
        acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
        acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
        acc = vminq_s32(acc, vdupq_n_s32(activation_max));

        vstrbq_scatter_offset_s32(dst, scatter_offset, acc);

        dst += 2;
    }

#elif defined(ARM_MATH_DSP)

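    /* DSP path: read_and_pad_s4 expands four packed s4 weights into two 16-bit pairs, SXTAB16
       sign-extends the lhs bytes while adding lhs_offset, and SMLAD accumulates two products per
       instruction. As in the MVE path, two rows from each half of the matrix are produced per
       iteration. */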
    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int32_t res0 = 0;
        int32_t res1 = 0;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int rhs_cols_idx = 0; rhs_cols_idx < (rhs_cols / 4); ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;

            read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
            rhs_ptr += 2;

            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_low0, res0);
            res0 = SMLAD(lhs_high, rhs_high0, res0);
            res1 = SMLAD(lhs_low, rhs_low1, res1);
            res1 = SMLAD(lhs_high, rhs_high1, res1);
        }

        if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[0];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t rhs_value1 = rhs_ptr[rhs_offset];
            const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
            const int32_t higher1 = rhs_value1 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;
            res1 += lhs_value_0 * lower1;
            res1 += lhs_value_1 * higher1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        if (rhs_cols % 2 == 1)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            spillover1 = lhs_high * rhs_high1;

            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;

        res0 = spillover0;
        res1 = spillover1;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 4; ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low, rhs_high1, rhs_low1;

            read_and_pad_s4(rhs_ptr, &rhs_low0, &rhs_high0);
            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_low1, &rhs_high1);
            rhs_ptr += 2;

            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_low0, res0);
            res0 = SMLAD(lhs_high, rhs_high0, res0);
            res1 = SMLAD(lhs_low, rhs_low1, res1);
            res1 = SMLAD(lhs_high, rhs_high1, res1);
        }

        if (((rhs_cols % 4) == 2) || ((rhs_cols % 4) == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[0];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t rhs_value1 = rhs_ptr[rhs_offset];
            const int32_t lower1 = (int8_t)(rhs_value1 << 4) >> 4;
            const int32_t higher1 = rhs_value1 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;
            res1 += lhs_value_0 * lower1;
            res1 += lhs_value_1 * higher1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;

        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;
    }

#else

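    /* Pure C fallback: two columns per iteration, extracting the low nibble with a shift-left /
       arithmetic-shift-right pair and the high nibble with an arithmetic right shift. Row
       pairing and odd-column spillover handling mirror the vectorized paths. */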
    for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; ++i_row_loop_cnt)
    {
        const int8_t *lhs_ptr = &lhs[0];

        int32_t res0 = 0;
        int32_t res1 = 0;

        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            res1 += lhs_high * rhs_high1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        if (rhs_cols % 2 == 1)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            spillover1 = lhs_high * rhs_high1;

            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            spillover1 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;

        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;

        res0 = spillover0;
        res1 = spillover1;
        if (bias)
        {
            res0 += *bias;
            res1 += bias[2 * row_loop_cnt];
            ++bias;
        }

        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[0] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[0] >> 4;
            const int32_t rhs_low1 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high1 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;
            res1 += lhs_low * rhs_low1;
            res1 += lhs_high * rhs_high1;

            ++rhs_ptr;
            lhs_ptr += 2;
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);
        res1 = arm_nn_requantize(res1, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;
        res1 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);
        res1 = MAX(res1, activation_min);
        res1 = MIN(res1, activation_max);

        *dst = (int8_t)res0;
        *(dst + 2 * row_loop_cnt) = (int8_t)res1;
        dst++;
    }

#endif

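    /* Handle the 1-3 rows left over when rhs_rows is not a multiple of 4; they are addressed
       through rhs_ptr[rhs_offset]. When rhs_cols is odd, every second of these rows starts with
       the spillover weight produced by the previous row's straddling byte. */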
    const int8_t *lhs_ptr = &lhs[0];
    spillover0 = 0;

    for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows % 4; ++i_row_loop_cnt)
    {
        int32_t res0 = spillover0;
        if (bias)
        {
            res0 += bias[2 * row_loop_cnt];
            ++bias;
        }

#if defined(ARM_MATH_MVEI)
        int32_t rhs_sum_0 = 0;

        for (int i = 0; i < col_loop_cnt; i++)
        {
            const int8x16x2_t inputx2 = vld2q_s8(lhs_ptr);
            const int8x16_t ker_0 = vldrbq_s8(&rhs_ptr[rhs_offset]);

            int8x16_t ker_low_0 = vrshlq_n_s8(ker_0, 4);
            ker_low_0 = vshrq_n_s8(ker_low_0, 4);
            int8x16_t ker_high_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_low_0);
            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_high_0);
            res0 = vmladavaq_s8(res0, ker_low_0, inputx2.val[0]);
            res0 = vmladavaq_s8(res0, ker_high_0, inputx2.val[1]);

            lhs_ptr += 32;
            rhs_ptr += 16;
        }

        if (I6_elements_spill)
        {
            const int8x16_t input = vldrbq_s8(lhs_ptr);

            int8x16_t ker_0 = vldrbq_gather_offset_s8(&rhs_ptr[rhs_offset], gather_offset);
            ker_0 = vrshlq_m_n_s8(ker_0, 4, lower_nibble_mask);
            ker_0 = vshrq_n_s8(ker_0, 4);

            rhs_sum_0 = vaddvaq_s8(rhs_sum_0, ker_0);
            res0 = vmladavaq_s8(res0, ker_0, input);

            lhs_ptr += 16;
            rhs_ptr += 8;
        }

        res0 += lhs_offset * rhs_sum_0;
#endif

#if defined(ARM_MATH_DSP)
        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols_offset / 4; ++rhs_cols_idx)
        {
            int32_t lhs_high, rhs_high0, rhs_low0, lhs_low;

            read_and_pad_s4((const int8_t *)&rhs_ptr[rhs_offset], &rhs_high0, &rhs_low0);
            rhs_ptr += 2;

            lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
            lhs_low = SXTAB16(lhs_offset_s16x2, lhs_high);
            lhs_high = SXTAB16_RORn(lhs_offset_s16x2, lhs_high, 8);

            res0 = SMLAD(lhs_low, rhs_high0, res0);
            res0 = SMLAD(lhs_high, rhs_low0, res0);
        }

        if ((rhs_cols % 4) == 2 || (rhs_cols % 4 == 3))
        {
            const int32_t rhs_value0 = rhs_ptr[rhs_offset];
            const int32_t lower0 = (int8_t)(rhs_value0 << 4) >> 4;
            const int32_t higher0 = rhs_value0 >> 4;

            const int32_t lhs_value_0 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value_1 = lhs_ptr[1] + lhs_offset;

            res0 += lhs_value_0 * lower0;
            res0 += lhs_value_1 * higher0;

            ++rhs_ptr;
            lhs_ptr += 2;
        }
#else
        for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols_offset / 2; ++rhs_cols_idx)
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            const int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;

            res0 += lhs_low * rhs_low0;
            res0 += lhs_high * rhs_high0;

            ++rhs_ptr;
            lhs_ptr += 2;
        }
#endif

        if ((rhs_cols % 2 == 1) && (i_row_loop_cnt % 2 == 0))
        {
            const int32_t rhs_low0 = (int8_t)(rhs_ptr[rhs_offset] << 4) >> 4;
            const int32_t rhs_high0 = rhs_ptr[rhs_offset] >> 4;

            const int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
            lhs_ptr = &lhs[0];
            const int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;
            ++lhs_ptr;

            res0 += lhs_low * rhs_low0;
            spillover0 = lhs_high * rhs_high0;
            ++rhs_ptr;
        }
        else
        {
            spillover0 = 0;
            lhs_ptr = &lhs[0];
        }

        // Quantize down
        res0 = arm_nn_requantize(res0, dst_multiplier, dst_shift);

        // Add offset
        res0 += dst_offset;

        // Clamp the result
        res0 = MAX(res0, activation_min);
        res0 = MIN(res0, activation_max);

        *(dst + 2 * row_loop_cnt) = (int8_t)res0;
        dst++;
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */