/*
 * SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        22 March 2023
 * $Revision:    V.2.1.2
 *
 * Target :      Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */
/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
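/*
 * The implementations below all compute, for row m of lhs and row n of rhs:
 *   dst[m * rhs_rows + n] = clamp(requantize(dot(lhs_row_m, rhs_row_n)
 *                                            + lhs_offset * sum(rhs_row_n)
 *                                            + bias[n]) + dst_offset)
 * with requantization using dst_multipliers[n] / dst_shifts[n] and clamping to
 * [activation_min, activation_max]. lhs_cols_offset is the stride between
 * consecutive lhs rows; each dot product runs over rhs_cols elements.
 */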
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset)
{

#if defined(ARM_MATH_MVEI)
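    /* MVE path: process four lhs rows per outer iteration. For each rhs row, four dot products and the
     * rhs-row sum (needed for the lhs_offset correction) are accumulated, then the four results are
     * requantized together and scatter-stored into four consecutive dst rows. */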
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int8_t *lhs_vec = lhs;
            const int8_t *ip_row_1 = lhs + lhs_cols_offset;
            const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
            const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
#else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
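            // Tail-predicated (wlstp/letp) loop over rhs_cols int8 elements: q0 holds the current rhs
            // data, vaddva.s8 accumulates its sum, and four vmladava.s8 instructions accumulate the dot
            // products against the four lhs rows.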
            __ASM volatile(" .p2align 2 \n"
                           " wlstp.8 lr, %[cnt], 1f \n"
                           " mov %[sum], 0 \n"
                           " mov %[out0], 0 \n"
                           " mov %[out1], 0 \n"
                           " mov %[out2], 0 \n"
                           " mov %[out3], 0 \n"
                           " vldrb.8 q0, [%[col]], #16 \n"
                           "2: \n"
                           " vaddva.s8 %[sum], q0 \n"
                           " vldrb.8 q1, [%[row0]], #16 \n"
                           " vmladava.s8 %[out0], q0, q1 \n"
                           " vldrb.8 q2, [%[row1]], #16 \n"
                           " vmladava.s8 %[out1], q0, q2 \n"
                           " vldrb.8 q3, [%[row2]], #16 \n"
                           " vmladava.s8 %[out2], q0, q3 \n"
                           " vldrb.8 q4, [%[row3]], #16 \n"
                           " vmladava.s8 %[out3], q0, q4 \n"
                           " vldrb.8 q0, [%[col]], #16 \n"
                           " letp lr, 2b \n"
                           "1: \n"
                           : [col] "+r"(col_base),
                             [sum] "=Te"(sum_tmp),
                             [row0] "+r"(lhs_vec),
                             [row1] "+r"(ip_row_1),
                             [row2] "+r"(ip_row_2),
                             [row3] "+r"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
#endif
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};
            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += (3 * rhs_rows);
    }

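    /* Handle the remaining lhs rows (1..3) one at a time. Requantization stays vectorized: results are
     * collected into acc[] and processed four outputs at a time, with a scalar tail for rhs_rows % 4. */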
    for (; i_items < lhs_rows; i_items++)
    {
        int32_t acc[4];
        const int32_t *multipliers = dst_multipliers;
        const int32_t *shifts = dst_shifts;
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            const int8_t *lhs_vec = lhs;
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
            }
#else
            __ASM volatile(" .p2align 2 \n"
                           " wlstp.8 lr, %[cnt], 1f \n"
                           " mov %[sum], 0 \n"
                           " mov %[out0], 0 \n"
                           " vldrb.8 q0, [%[col]], #16 \n"
                           "2: \n"
                           " vaddva.s8 %[sum], q0 \n"
                           " vldrb.8 q1, [%[row0]], #16 \n"
                           " vmladava.s8 %[out0], q0, q1 \n"
                           " vldrb.8 q0, [%[col]], #16 \n"
                           " letp lr, 2b \n"
                           "1: \n"
                           : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "memory", "r14");
#endif
            sum_tmp *= lhs_offset;
            sum_tmp += acc_n0;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            const int32_t index = i & 0x3;
            acc[index] = sum_tmp;

            if (index == 3)
            {
                int32x4_t res = vldrwq_s32(acc);
                res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                multipliers += 4;
                shifts += 4;
                res = vaddq_n_s32(res, dst_offset);
                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));
                vstrbq_s32(dst, res);
                dst += 4;
            }
        }
        lhs += lhs_cols_offset;
        const int32_t tail_rows = rhs_rows & 0x3;
        for (int i = 0; i < tail_rows; i++)
        {
            int32_t acc_n0 = acc[i];
            acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
            acc_n0 += dst_offset;
            acc_n0 = MAX(acc_n0, activation_min);
            acc_n0 = MIN(acc_n0, activation_max);
            *dst++ = (int8_t)acc_n0;
        }
    }

#elif defined(ARM_MATH_DSP)
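    /* DSP path: process two rhs rows and two lhs rows per iteration using SMLAD dual 16-bit MACs on
     * SXTB16 sign-extended byte pairs. The lhs_offset contribution (lhs_offset * sum of each rhs row,
     * plus bias) is precomputed once per rhs row pair and used to seed the accumulators. */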
    const int32_t rhs_off0 = rhs_cols - 4;
    const int32_t lhs_off0 = lhs_cols_offset - 4;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;

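            // Main loop unrolled to 16 columns per iteration. Each 4-byte word is split into its even
            // and odd bytes with SXTB16/SXTB16_RORn so that SMLAD performs two MACs at a time for each
            // of the four accumulators (2 rhs rows x 2 lhs rows).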
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

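            // Process the remaining columns four at a time with the same dual-MAC pattern.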
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

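            // Left-over columns, handled one at a time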
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

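    // When rhs_rows is odd, compute the final rhs row against every lhs row with a scalar loop.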
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#else
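    /* Pure C reference path (no DSP or MVE extensions): same 2 x 2 blocking over rhs and lhs rows as the
     * DSP path, but with plain scalar multiply-accumulates. */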
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */