/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mult_t_s8
 * Description:  s8 vector by matrix (transposed) multiplication
 *
 * $Date:        5 Sep 2024
 * $Revision:    V.6.2.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @defgroup supportFC Fully Connected
 *
 * Support functions for Fully Connected
 *
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s8 vector(lhs) by matrix (transposed) multiplication
 *
 * Refer header file for details.
 *
 */
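
/*
 * Minimal usage sketch (illustrative only; the dimensions, buffers and
 * quantization parameters below are placeholder assumptions, not values taken
 * from this file). kernel_sum carries precomputed per-row terms and is
 * consumed on MVE builds; its contents here are a placeholder.
 *
 *   #define COLS 8
 *   #define ROWS 4
 *   static const int8_t lhs[COLS] = {0};        // input vector
 *   static const int8_t rhs[ROWS * COLS] = {0}; // weights, one row per output
 *   static const int32_t bias[ROWS] = {0};
 *   static int32_t kernel_sum[ROWS] = {0};
 *   static int8_t dst[ROWS];
 *
 *   arm_cmsis_nn_status status = arm_nn_vec_mat_mult_t_s8(lhs, rhs, kernel_sum, bias, dst,
 *                                                          128,        // lhs_offset
 *                                                          0,          // dst_offset
 *                                                          1073741824, // dst_multiplier (example)
 *                                                          -1,         // dst_shift (example)
 *                                                          COLS,       // rhs_cols
 *                                                          ROWS,       // rhs_rows
 *                                                          -128, 127,  // activation min/max
 *                                                          1,          // address_offset (contiguous dst)
 *                                                          0);         // rhs_offset
 */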
#if !defined(ARM_MATH_MVEI) && defined(ARM_MATH_DSP) && !defined(__ARMCC_VERSION) && !defined(__ICCARM__)
    #pragma GCC optimize("unroll-loops")
#endif
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                             const int8_t *rhs,
                                             const int32_t *kernel_sum,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max,
                                             const int32_t address_offset,
                                             const int32_t rhs_offset)
{
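    /* The two top-level branches below are identical apart from the handling of
     * rhs_offset; the common case rhs_offset == 0 skips the extra offset
     * correction work on every path. */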
    if (rhs_offset)
    {
#if defined(ARM_MATH_MVEI)
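        /* MVE path: four rows per outer iteration, 16 int8 MACs per predicated
         * vector load. Accumulators start from the precomputed kernel_sum values;
         * lhs_sum collects the running sum of the input vector so the rhs_offset
         * contribution can be added once per row after the column loop. */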
        (void)bias;
        (void)lhs_offset;
        const int32_t row_loop_cnt = rhs_rows / 4;
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = *kernel_sum++;
            int32_t acc_1 = *kernel_sum++;
            int32_t acc_2 = *kernel_sum++;
            int32_t acc_3 = *kernel_sum++;

            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            int32_t lhs_sum = 0;

            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            acc += vdupq_n_s32(rhs_offset) * vdupq_n_s32(lhs_sum);

            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = *kernel_sum++;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            int32_t lhs_sum = 0;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
                lhs_sum = vaddvaq_s8(lhs_sum, input);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            acc_0 += lhs_sum * rhs_offset;

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
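        /* DSP path: two rows per outer iteration. SXTAB16/SXTAB16_RORn expand the
         * packed s8 data to s16 pairs while adding the lhs/rhs offsets, and SMLAD
         * performs the dual 16-bit multiply-accumulate. */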
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 2;
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        const int16_t rhs_offset_s16 = (int16_t)rhs_offset;
        const uint32_t rhs_offset_s16x2 = PKHBT(rhs_offset_s16, rhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr + rhs_offset);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr + rhs_offset);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTAB16_RORn(rhs_offset_s16x2, (uint32_t)ker_0, 8);
                ker_0 = SXTAB16(rhs_offset_s16x2, ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr + rhs_offset);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
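        /* Pure C fallback: three rows per outer iteration, plain multiply-accumulate
         * with the lhs/rhs offsets applied per element. */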
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0 + rhs_offset;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1 + rhs_offset;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2 + rhs_offset;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }

    else
    {
#if defined(ARM_MATH_MVEI)
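        /* MVE path, rhs_offset == 0: same four-row structure as above, but no
         * running lhs sum or rhs_offset correction is needed. */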
        const int32_t row_loop_cnt = rhs_rows / 4;
        const uint32x4_t address_offset_array = {0, address_offset, address_offset * 2, address_offset * 3};

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = *kernel_sum++;
            int32_t acc_1 = *kernel_sum++;
            int32_t acc_2 = *kernel_sum++;
            int32_t acc_3 = *kernel_sum++;

            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            const int8_t *rhs_2_ptr = rhs + 2 * rhs_cols;
            const int8_t *rhs_3_ptr = rhs + 3 * rhs_cols;

            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;

                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_0_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                const int8x16_t ker_1 = vldrbq_z_s8(rhs_1_ptr, p);
                acc_1 = vmladavaq_s8(acc_1, ker_1, input);

                const int8x16_t ker_2 = vldrbq_z_s8(rhs_2_ptr, p);
                acc_2 = vmladavaq_s8(acc_2, ker_2, input);

                const int8x16_t ker_3 = vldrbq_z_s8(rhs_3_ptr, p);
                acc_3 = vmladavaq_s8(acc_3, ker_3, input);

                lhs_vec += 16;
                rhs_0_ptr += 16;
                rhs_1_ptr += 16;
                rhs_2_ptr += 16;
                rhs_3_ptr += 16;
            }
            rhs += 4 * rhs_cols;

            int32x4_t acc = {acc_0, acc_1, acc_2, acc_3};

            acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
            acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
            acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
            acc = vminq_s32(acc, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, address_offset_array, acc);

            dst += 4 * address_offset;
        }

        const int loop_cnt = rhs_rows % 4;
        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
        {
            int32_t acc_0 = *kernel_sum++;
            const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;
            uint32_t col_cnt = (uint32_t)rhs_cols;

            for (int32_t i = 0; i < col_loop_cnt; i++)
            {
                mve_pred16_t p = vctp8q(col_cnt);
                col_cnt -= 16;
                const int8x16_t input = vldrbq_z_s8(lhs_vec, p);

                const int8x16_t ker_0 = vldrbq_z_s8(rhs_ptr, p);
                acc_0 = vmladavaq_s8(acc_0, ker_0, input);

                lhs_vec += 16;
                rhs_ptr += 16;
            }
            rhs += rhs_cols;

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_0 += dst_offset;

            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            *dst = MIN(acc_0, activation_max);
            dst += address_offset;
        }

#elif defined(ARM_MATH_DSP)
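        /* DSP path, rhs_offset == 0: only the lhs data needs an offset applied
         * during the s8 to s16 expansion; rhs bytes are sign-extended directly
         * with SXTB16/SXTB16_RORn. */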
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 2;
        const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
        const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;
            if (bias)
            {
                acc_0 = *bias++;
                acc_1 = *bias++;
            }

            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_0_ptr = rhs;
            const int8_t *rhs_1_ptr = rhs + rhs_cols;
            rhs += 2 * rhs_cols;

            for (int32_t j = col_loop_cnt; j != 0; j--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);

                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_0_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);

                ker_0 = arm_nn_read_s8x4_ia(&rhs_1_ptr);
                ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_1 = SMLAD(ker_1, vec_1, acc_1);
                acc_1 = SMLAD(ker_0, vec_0, acc_1);
            }

            for (int32_t k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_0_ptr);
                rhs_0_ptr++;
                acc_1 += lhs_temp * (*rhs_1_ptr);
                rhs_1_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
            acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            acc_1 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            acc_1 = MAX(acc_1, activation_min);
            acc_1 = MIN(acc_1, activation_max);
            *dst = (int8_t)acc_0;
            *(dst + address_offset) = (int8_t)acc_1;
            dst += 2 * address_offset;
        }

        if (rhs_rows & 0x1)
        {
            int32_t acc_0 = 0;
            if (bias)
            {
                acc_0 = *bias++;
            }
            const int32_t col_loop_cnt = rhs_cols / 4;

            const int8_t *lhs_vec = lhs;
            const int8_t *rhs_ptr = rhs;

            for (int32_t i = col_loop_cnt; i != 0; i--)
            {
                int32_t vec_0 = arm_nn_read_s8x4_ia(&lhs_vec);
                int32_t vec_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
                vec_0 = SXTAB16(lhs_offset_s16x2, vec_0);

                int32_t ker_0 = arm_nn_read_s8x4_ia(&rhs_ptr);
                int32_t ker_1 = SXTB16_RORn((uint32_t)ker_0, 8);
                ker_0 = SXTB16(ker_0);

                acc_0 = SMLAD(ker_1, vec_1, acc_0);
                acc_0 = SMLAD(ker_0, vec_0, acc_0);
            }

            for (int32_t j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec + lhs_offset);
                lhs_vec++;
                acc_0 += lhs_temp * (*rhs_ptr);
                rhs_ptr++;
            }

            acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);

            // Add offset
            acc_0 += dst_offset;
            // Clamp the result
            acc_0 = MAX(acc_0, activation_min);
            acc_0 = MIN(acc_0, activation_max);
            *dst = (int8_t)acc_0;
            dst += address_offset;
        }

#else
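        /* Pure C fallback, rhs_offset == 0: three rows per outer iteration, with
         * only the lhs offset applied per element. */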
        (void)kernel_sum;

        const int32_t row_loop_cnt = rhs_rows / 3;

        for (int32_t i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
        {
            const int8_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = &rhs[0];
            const int8_t *rhs_ptr_1 = &rhs[rhs_cols];
            const int8_t *rhs_ptr_2 = &rhs[rhs_cols * 2];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res02 = 0;
            if (bias)
            {
                res00 = *bias++;
                res01 = *bias++;
                res02 = *bias++;
            }
            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int32_t rhs_value0 = (int8_t)*rhs_ptr_0;
                const int32_t rhs_value1 = (int8_t)*rhs_ptr_1;
                const int32_t rhs_value2 = (int8_t)*rhs_ptr_2;
                const int32_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;
                res02 += lhs_value * rhs_value2;

                ++rhs_ptr_0;
                ++rhs_ptr_1;
                ++rhs_ptr_2;
                ++lhs_ptr;
            }
            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
            res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
            res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res02 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res02 = MAX(res02, activation_min);
            res02 = MIN(res02, activation_max);

            *dst = (int8_t)res00;
            *(dst + address_offset) = (int8_t)res01;
            *(dst + 2 * address_offset) = (int8_t)res02;
            dst += 3 * address_offset;

            rhs += 3 * rhs_cols;
        }

        const int loop_cnt = rhs_rows % 3;

        for (int32_t i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
        {
            const int8_t *lhs_ptr = &lhs[0];
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = 0;
            if (bias)
            {
                res00 = *bias++;
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value0 = (int8_t)rhs_ptr[0];
                int32_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value0;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            *dst = (int8_t)res00;
            dst += address_offset;
            rhs += rhs_cols;
        }
#endif
    }
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */