/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:     CMSIS NN Library
 * Title:       arm_nn_mat_mult_s8_nt_t_s8
 * Description: Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:       04 January 2024
 * $Revision:   V.3.0.0
 *
 * Target :     Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
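/*
 * Illustrative usage sketch (not part of the library): multiply a 4x8 s8 LHS with the transpose
 * of a 2x8 s8 RHS, producing a 4x2 s8 result. All buffer contents and quantization parameters
 * below are hypothetical placeholders; in CMSIS-NN this helper is normally invoked by the
 * convolution and fully-connected kernels rather than called directly.
 *
 *     int8_t lhs[4 * 8];
 *     int8_t rhs[2 * 8];
 *     int8_t dst[4 * 2];
 *     int32_t bias[2] = {0, 0};
 *     int32_t mult[2] = {1073741824, 1073741824}; // placeholder per-channel multipliers
 *     int32_t shift[2] = {0, 0};                  // placeholder per-channel shifts
 *
 *     arm_cmsis_nn_status status = arm_nn_mat_mult_nt_t_s8(lhs, rhs, bias, dst, mult, shift,
 *                                                          4,    // lhs_rows
 *                                                          2,    // rhs_rows
 *                                                          8,    // rhs_cols
 *                                                          0,    // lhs_offset
 *                                                          0,    // dst_offset
 *                                                          -128, // activation_min
 *                                                          127,  // activation_max
 *                                                          2,    // row_address_offset, typically == rhs_rows
 *                                                          8);   // lhs_cols_offset, typically == rhs_cols
 */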
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t row_address_offset,
                                            const int32_t lhs_cols_offset)
{

#if defined(ARM_MATH_MVEI)
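    /* MVE path: process the LHS four rows at a time. For each RHS row the inner loop also
     * accumulates the plain sum of the RHS values, so the lhs_offset contribution can be folded
     * in once per RHS row instead of being added to every LHS element. */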
    int i_items = 0;
    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            const int8_t *lhs_vec = lhs;
            const int8_t *ip_row_1 = lhs + lhs_cols_offset;
            const int8_t *ip_row_2 = lhs + (2 * lhs_cols_offset);
            const int8_t *ip_row_3 = lhs + (3 * lhs_cols_offset);
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
                acc_n1 += ip_row_1[j] * col;
                acc_n2 += ip_row_2[j] * col;
                acc_n3 += ip_row_3[j] * col;
            }
#else
            // Note: If operand initialization is moved around, use '&' constraint to
            // specify earlyclobber operands.
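            /* Tail-predicated MVE loop over rhs_cols bytes: each iteration loads 16 RHS values,
             * adds them into sum_tmp (vaddva) and multiply-accumulates them against 16 values
             * from each of the four LHS rows (vmladava). */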
            __ASM volatile(" .p2align 2                        \n"
                           "   wlstp.8 lr, %[cnt], 1f          \n"
                           "   mov %[sum], 0                   \n"
                           "   mov %[out0], 0                  \n"
                           "   mov %[out1], 0                  \n"
                           "   mov %[out2], 0                  \n"
                           "   mov %[out3], 0                  \n"
                           "   vldrb.8 q0, [%[col]], #16       \n"
                           "2:                                 \n"
                           "   vaddva.s8 %[sum], q0            \n"
                           "   vldrb.8 q1, [%[row0]], #16      \n"
                           "   vmladava.s8 %[out0], q0, q1     \n"
                           "   vldrb.8 q2, [%[row1]], #16      \n"
                           "   vmladava.s8 %[out1], q0, q2     \n"
                           "   vldrb.8 q3, [%[row2]], #16      \n"
                           "   vmladava.s8 %[out2], q0, q3     \n"
                           "   vldrb.8 q4, [%[row3]], #16      \n"
                           "   vmladava.s8 %[out3], q0, q4     \n"
                           "   vldrb.8 q0, [%[col]], #16       \n"
                           "   letp lr, 2b                     \n"
                           "1:                                 \n"
                           : [col] "+r"(col_base),
                             [sum] "=Te"(sum_tmp),
                             [row0] "+r"(lhs_vec),
                             [row1] "+r"(ip_row_1),
                             [row2] "+r"(ip_row_2),
                             [row3] "+r"(ip_row_3),
                             [out0] "=Te"(acc_n0),
                             [out1] "=Te"(acc_n1),
                             [out2] "=Te"(acc_n2),
                             [out3] "=Te"(acc_n3)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "q2", "q3", "q4", "memory", "r14");
#endif
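            /* Fold in the lhs_offset contribution (sum of this RHS row times lhs_offset) and the
             * bias, requantize the four accumulators, add the output offset, clamp to the
             * activation range, and scatter-store one byte into each of the four output rows. */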
            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            const uint32x4_t scatter_offset = {
                0, (uint32_t)row_address_offset, (uint32_t)row_address_offset * 2, (uint32_t)row_address_offset * 3};
            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += 4 * row_address_offset - rhs_rows;
    }

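    /* Handle any remaining LHS rows one at a time. Results are buffered in groups of four so
     * that requantization, offset, clamping and the store can still use vector operations; a
     * scalar tail deals with the last (rhs_rows % 4) outputs of each row. */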
    for (; i_items < lhs_rows; i_items++)
    {
        int32_t acc[4];
        const int32_t *multipliers = dst_multipliers;
        const int32_t *shifts = dst_shifts;
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            const int8_t *lhs_vec = lhs;
            const int8_t *col_base = rhs + i * rhs_cols;
            int32_t sum_tmp = 0;

#if defined(ARM_MATH_AUTOVECTORIZE)
            for (int j = 0; j < rhs_cols; j++)
            {
                int32_t col = col_base[j];
                sum_tmp += col;
                acc_n0 += lhs_vec[j] * col;
            }
#else
            __ASM volatile(" .p2align 2                        \n"
                           "   wlstp.8 lr, %[cnt], 1f          \n"
                           "   mov %[sum], 0                   \n"
                           "   mov %[out0], 0                  \n"
                           "   vldrb.8 q0, [%[col]], #16       \n"
                           "2:                                 \n"
                           "   vaddva.s8 %[sum], q0            \n"
                           "   vldrb.8 q1, [%[row0]], #16      \n"
                           "   vmladava.s8 %[out0], q0, q1     \n"
                           "   vldrb.8 q0, [%[col]], #16       \n"
                           "   letp lr, 2b                     \n"
                           "1:                                 \n"
                           : [col] "+r"(col_base), [sum] "=Te"(sum_tmp), [row0] "+r"(lhs_vec), [out0] "=Te"(acc_n0)
                           : [cnt] "r"(rhs_cols)
                           : "q0", "q1", "memory", "r14");
#endif
            sum_tmp *= lhs_offset;
            sum_tmp += acc_n0;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            const int32_t index = i & 0x3;
            acc[index] = sum_tmp;

            if (index == 3)
            {
                int32x4_t res = vldrwq_s32(acc);
                res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                multipliers += 4;
                shifts += 4;
                res = vaddq_n_s32(res, dst_offset);
                res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                res = vminq_s32(res, vdupq_n_s32(activation_max));
                vstrbq_s32(dst, res);
                dst += 4;
            }
        }
        lhs += lhs_cols_offset;
        const int32_t tail_rows = rhs_rows & 0x3;
        for (int i = 0; i < tail_rows; i++)
        {
            int32_t acc_n0 = acc[i];
            acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
            acc_n0 += dst_offset;
            acc_n0 = MAX(acc_n0, activation_min);
            acc_n0 = MIN(acc_n0, activation_max);
            *dst++ = (int8_t)acc_n0;
        }
        dst += row_address_offset - rhs_rows;
    }

#elif defined(ARM_MATH_DSP)
    (void)row_address_offset;
    const int32_t rhs_off0 = rhs_cols - 4;
    const int32_t lhs_off0 = lhs_cols_offset - 4;

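    /* DSP path: process the RHS two rows at a time; within each pair, the LHS is also processed
     * two rows at a time so that the packed SXTB16/SMLAD sequences below can reuse each loaded
     * RHS word for two output rows. */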
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

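        /* Pre-compute the lhs_offset contribution for the two RHS rows handled in this iteration:
         * sum_x (lhs[x] + lhs_offset) * rhs[x] = sum_x lhs[x] * rhs[x] + lhs_offset * sum_x rhs[x],
         * so the constant term (plus the bias) can seed the accumulators and the inner loops can
         * work on the raw int8 values. */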
        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;

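            /* Unrolled by 16 columns: each 32-bit load packs four int8 values. SXTB16
             * sign-extends bytes 0 and 2 into two int16 lanes, SXTB16_RORn(x, 8) does the same
             * for bytes 1 and 3, and each SMLAD performs two 16-bit multiply-accumulates, so a
             * pair of SMLADs completes four int8 MACs per accumulator. */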
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                res11 = SMLAD(val0, val4, res11);

                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val1 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val2 = SXTB16(val1);
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val4 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val1 = SXTB16_RORn(val1, 8);
                val0 = SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val3, val2, res00);
                val5 = SXTB16(val4);
                res00 = SMLAD(val0, val1, res00);
                val4 = SXTB16_RORn(val4, 8);
                res01 = SMLAD(val3, val5, res01);
                res01 = SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_off0]);
                val3 = SXTB16(val0);
                val0 = SXTB16_RORn(val0, 8);
                res10 = SMLAD(val3, val2, res10);
                res11 = SMLAD(val3, val5, res11);
                res10 = SMLAD(val0, val1, res10);
                res11 = SMLAD(val0, val4, res11);
            }

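            // Left-over accumulations for the remaining (rhs_cols % 4) columns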
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            int32_t val0, val1, val2, val3, val4, val5;
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);

                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                val0 = arm_nn_read_s8x4_ia((const int8_t **)&rhs_ptr);
                val1 = arm_nn_read_s8x4((const int8_t *)&rhs_ptr[rhs_off0]);
                val2 = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                val3 = SXTB16(val0);
                val5 = SXTB16(val2);
                val4 = SXTB16(val1);
                val0 = SXTB16_RORn(val0, 8);
                val2 = SXTB16_RORn(val2, 8);
                val1 = SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(val5, val3, res00);
                res00 = SMLAD(val2, val0, res00);
                res01 = SMLAD(val5, val4, res01);
                res01 = SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

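    /* If rhs_rows is odd, the last RHS row is handled here: one scalar dot product per LHS row,
     * written to the final column of the output. */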
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#else
    (void)row_address_offset;
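    /* Pure C fallback: same 2 RHS rows x 2 LHS rows blocking and lhs_offset pre-computation as
     * the DSP path above, but with plain scalar multiply-accumulates. */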
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_offset_contribution0 = 0;
        int32_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;
            int32_t res10 = lhs_offset_contribution0;
            int32_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[lhs_cols_offset];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[1] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t res00 = lhs_offset_contribution0;
            int32_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8_t rhs_value0 = rhs_ptr[0];
                int8_t rhs_value1 = rhs_ptr[rhs_cols];
                int8_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[1] = (int8_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

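    /* Handle an odd final RHS row: one scalar dot product per LHS row. */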
    if (rhs_rows % 2)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_value = rhs_ptr[0];
                int32_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */