/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s8
 * Description:  s8 vector by matrix (transposed) multiplication
 *
 * $Date:        14 Feb 2023
 * $Revision:    V.6.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @defgroup supportFC Fully Connected
 *
 * Support functions for Fully Connected
 *
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s8 vector(lhs) by matrix (transposed) multiplication
 *
 * Refer header file for details.
 *
 */
#if defined(ARM_MATH_DSP) && !defined(__ARMCC_VERSION) && !defined(__ICCARM__)
    #pragma GCC optimize("unroll-loops")
#endif
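/*
 * Background (directly verifiable from the corrections applied below): with
 * input and weight zero points applied, each output element is
 *
 *   acc = sum_i (lhs[i] + lhs_offset) * (rhs_row[i] + rhs_offset)
 *       = sum_i lhs[i] * rhs_row[i]           (raw dot product, inner loops)
 *       + lhs_offset * sum_i rhs_row[i]       (kernel_sum, precomputed per row)
 *       + rhs_offset * sum_i lhs[i]           (lhs_sum, accumulated on the fly)
 *       + rhs_cols * lhs_offset * rhs_offset  (constant per call)
 *
 * The implementations below accumulate the raw dot product and then add the
 * three correction terms, which is cheaper than offsetting every element.
 * When rhs_offset == 0 (the else-branch), the last two terms vanish.
 */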
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                             const int8_t *rhs,
                                             const int32_t *kernel_sum,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max,
                                             const int32_t address_offset,
                                             const int32_t rhs_offset)
{
    if (rhs_offset)
    {
#if defined(ARM_MATH_MVEI)
        const int32_t row_loop_cnt = rhs_rows / 4;
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;

            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            int32_t lhs_sum = 0;

            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
                acc_2 = *bias++;
                acc_3 = *bias++;
            }

            uint32_t col_cnt = (uint32_t)rhs_cols;

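            /*
             * Predicated MVE column loop: vctp8q(col_cnt) builds a byte
             * predicate that is all-true for full 16-element chunks and
             * masks off the excess lanes on the final iteration, so no
             * scalar tail loop is needed. vldrbq_z_s8 zeroes the inactive
             * lanes, which then contribute nothing to the dot product.
             * lhs_sum gathers sum(lhs) for the rhs_offset correction
             * applied after the loop.
             */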
            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

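            /*
             * Apply the offset corrections from the expansion above:
             * lhs_offset * kernel_sum per row, rhs_offset * lhs_sum
             * (identical for all four rows), and the constant
             * rhs_cols * lhs_offset * rhs_offset term.
             */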
            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
            kernel_sum += 4;

            acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);
            acc += vdupq_n_s32(rhs_offset * lhs_offset) * vdupq_n_s32(rhs_cols);

            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

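        /*
         * Tail rows: the remaining rhs_rows % 4 rows reuse the same
         * predicated column loop but accumulate a single row at a time,
         * with scalar equivalents of the vector corrections.
         */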
        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            int32_t lhs_sum = 0;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            if (bias)
            {
                acc_0 += *bias;
                bias++;
            }
            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
            acc_0 += rhs_sum * lhs_offset;
            acc_0 += lhs_sum * rhs_offset;
            acc_0 += rhs_cols * lhs_offset * rhs_offset;

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
        (void)kernel_sum;

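        /*
         * DSP extension path: the 16-bit offsets are duplicated into both
         * halfwords of a register with PKHBT so that SXTAB16/SXTAB16_RORn
         * can sign-extend two s8 values to s16 and add the offset in a
         * single instruction. SMLAD then performs two 16 x 16
         * multiply-accumulates per instruction. Two rows are processed
         * per pass to reuse each extended lhs pair.
         */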
        const int32_t row_loop_cnt = rhs_rows / 2;
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
        const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

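            // Process the 0-3 leftover columns that the 4-wide loop above
            // could not cover, adding both offsets per element.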
            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

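        // Handle the last row when rhs_rows is odd, using the same
        // SIMD column loop for a single row.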
        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
        (void)kernel_sum;

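        /*
         * Plain C fallback: three rows are processed per pass so that each
         * loaded lhs element is reused three times before moving on, with
         * both offsets added per element.
         */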
        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }

    else
    {

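        /*
         * rhs_offset == 0 (symmetric weight quantization, the common case
         * for s8 kernels): the structure mirrors the branch above, but the
         * lhs_sum accumulation and both rhs_offset correction terms are
         * dropped; only the lhs_offset * kernel_sum correction remains.
         */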
#if defined(ARM_MATH_MVEI)
        const int32_t row_loop_cnt = rhs_rows / 4;
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            int32_t acc_2 = 0;
            int32_t acc_3 = 0;

            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
                acc_2 = *bias++;
                acc_3 = *bias++;
            }

            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

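            // With rhs_offset == 0 the only correction needed is
            // lhs_offset * sum(rhs row), taken from the precomputed
            // kernel_sum array.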
            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            const int32x4_t rhs_sum = {kernel_sum[0], kernel_sum[1], kernel_sum[2], kernel_sum[3]};
            acc += vdupq_n_s32(lhs_offset) * rhs_sum;
            kernel_sum += 4;

            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            if (bias)
            {
                acc_0 += *bias;
                bias++;
            }
            const int32_t rhs_sum = kernel_sum[i_row_loop_cnt];
            const int32_t offsets = rhs_sum * lhs_offset;
            acc_0 += offsets;
            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 2;
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

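            // Note: rhs bytes are widened with plain SXTB16/SXTB16_RORn
            // here (no accumulate variant) since there is no weight offset
            // to fold in; the lhs side still uses SXTAB16 to add
            // lhs_offset during extension.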
            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }
            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0];
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }
    return ARM_CMSIS_NN_SUCCESS;
}

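/*
 * Usage sketch (illustrative only, not part of the library): a 1x4 lhs
 * vector multiplied by a 2x4 transposed rhs matrix, with all offsets zero
 * and a requantization multiplier/shift pair giving an effective scale of
 * 1.0. kernel_sum holds the precomputed per-row sums of rhs (see
 * arm_vector_sum_s8). The values below are hypothetical; in practice they
 * come from the layer's quantization parameters.
 *
 *   const int8_t lhs[4] = {10, 20, 30, 40};
 *   const int8_t rhs[8] = {1, 0, 0, 0,   // row 0
 *                          0, 1, 0, 0};  // row 1
 *   const int32_t kernel_sum[2] = {1, 1};
 *   int8_t dst[2];
 *   arm_nn_vec_mat_mult_t_s8(lhs, rhs, kernel_sum, NULL, dst,
 *                            0,           // lhs_offset
 *                            0,           // dst_offset
 *                            1073741824,  // dst_multiplier (0.5 in Q31)
 *                            1,           // dst_shift (x2 -> net scale 1.0)
 *                            4, 2,        // rhs_cols, rhs_rows
 *                            -128, 127,   // activation_min/max
 *                            1,           // address_offset (contiguous dst)
 *                            0);          // rhs_offset
 *   // dst is now {10, 20}.
 */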
/**
 * @} end of Doxygen group
 */