/*
 * Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_s8_nt_t_s8
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        09. October 2020
 * $Revision:    V.1.0.3
 *
 * Target Processor:  Cortex-M
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup NNBasicMath
 * @{
 */
/*
 * s8 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
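/*
 * Illustrative call (a minimal sketch; the shapes, offsets and quantization
 * parameters below are hypothetical, not taken from this file). lhs is
 * [lhs_rows x rhs_cols], rhs is [rhs_rows x rhs_cols] (i.e. already
 * transposed), dst is [lhs_rows x rhs_rows]:
 *
 *   const q7_t lhs[2 * 4] = {1, 2, 3, 4, 5, 6, 7, 8};  // 2 x 4 input
 *   const q7_t rhs[3 * 4] = {0};                       // 3 x 4 transposed weights
 *   const q31_t bias[3] = {0, 0, 0};                   // one bias per rhs row
 *   const int32_t mult[3] = {1073741824, ...};         // per-row requantization multipliers
 *   const int32_t shift[3] = {-1, ...};                // per-row requantization shifts
 *   q7_t dst[2 * 3];                                   // 2 x 3 output
 *
 *   arm_status status = arm_nn_mat_mult_nt_t_s8(lhs, rhs, bias, dst, mult, shift,
 *                                               2, 3, 4,    // lhs_rows, rhs_rows, rhs_cols
 *                                               128, -10,   // lhs_offset, dst_offset
 *                                               -128, 127); // activation_min, activation_max
 */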
arm_status arm_nn_mat_mult_nt_t_s8(const q7_t *lhs,
                                   const q7_t *rhs,
                                   const q31_t *bias,
                                   q7_t *dst,
                                   const int32_t *dst_multipliers,
                                   const int32_t *dst_shifts,
                                   const int32_t lhs_rows,
                                   const int32_t rhs_rows,
                                   const int32_t rhs_cols,
                                   const int32_t lhs_offset,
                                   const int32_t dst_offset,
                                   const int32_t activation_min,
                                   const int32_t activation_max)
{
#if defined(ARM_MATH_DSP)
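    /* The DSP path computes a 2 x 2 tile of dst per pass: two rhs rows are
       paired with two lhs rows, and the inner loops consume 16 columns at a
       time using SXTB16 unpacking and SMLAD dual multiply-accumulate. */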
    const int32_t off0 = rhs_cols - 4;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const q7_t *lhs_ptr = &lhs[0];
        q7_t *dst_ptr = &dst[0];

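        // lhs_offset is constant across a row's dot products, so its
        // contribution reduces to lhs_offset * sum(rhs row); fold it, along
        // with the optional bias, into the accumulators' initial values.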
        q31_t lhs_offset_contribution0 = 0;
        q31_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const q7_t *rhs_ptr = &rhs[0];

            q31_t res00 = lhs_offset_contribution0;
            q31_t res01 = lhs_offset_contribution1;
            q31_t res10 = lhs_offset_contribution0;
            q31_t res11 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            q31_t val0, val1, val2, val3, val4, val5;

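            // Each pass handles 4 x 4 = 16 columns: every q7x4 word is split
            // into two sign-extended q15x2 halves (__SXTB16 / __SXTB16_RORn)
            // so that each __SMLAD performs two MACs.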
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val1 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                val2 = __SXTB16(val1);
                val0 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val4 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val1 = __SXTB16_RORn(val1, 8);
                val0 = __SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val3, val2, res00);
                val5 = __SXTB16(val4);
                res00 = __SMLAD(val0, val1, res00);
                val4 = __SXTB16_RORn(val4, 8);
                res01 = __SMLAD(val3, val5, res01);
                res01 = __SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_q7x4((const q7_t *)&lhs_ptr[off0]);
                val3 = __SXTB16(val0);
                val0 = __SXTB16_RORn(val0, 8);
                res10 = __SMLAD(val3, val2, res10);
                res11 = __SMLAD(val3, val5, res11);
                res10 = __SMLAD(val0, val1, res10);
                val1 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                res11 = __SMLAD(val0, val4, res11);

                val4 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = __SXTB16(val1);
                val0 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val1 = __SXTB16_RORn(val1, 8);
                val0 = __SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val3, val2, res00);
                val5 = __SXTB16(val4);
                res00 = __SMLAD(val0, val1, res00);
                val4 = __SXTB16_RORn(val4, 8);
                res01 = __SMLAD(val3, val5, res01);
                res01 = __SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_q7x4((const q7_t *)&lhs_ptr[off0]);
                val3 = __SXTB16(val0);
                val0 = __SXTB16_RORn(val0, 8);
                res10 = __SMLAD(val3, val2, res10);
                res11 = __SMLAD(val3, val5, res11);
                res10 = __SMLAD(val0, val1, res10);
                val1 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                res11 = __SMLAD(val0, val4, res11);

                val4 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = __SXTB16(val1);
                val0 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val1 = __SXTB16_RORn(val1, 8);
                val0 = __SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val3, val2, res00);
                val5 = __SXTB16(val4);
                res00 = __SMLAD(val0, val1, res00);
                val4 = __SXTB16_RORn(val4, 8);
                res01 = __SMLAD(val3, val5, res01);
                res01 = __SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_q7x4((const q7_t *)&lhs_ptr[off0]);
                val3 = __SXTB16(val0);
                val0 = __SXTB16_RORn(val0, 8);
                res10 = __SMLAD(val3, val2, res10);
                res11 = __SMLAD(val3, val5, res11);
                res10 = __SMLAD(val0, val1, res10);
                val1 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                res11 = __SMLAD(val0, val4, res11);

                val4 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = __SXTB16(val1);
                val0 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val1 = __SXTB16_RORn(val1, 8);
                val0 = __SXTB16_RORn(val0, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val3, val2, res00);
                val5 = __SXTB16(val4);
                res00 = __SMLAD(val0, val1, res00);
                val4 = __SXTB16_RORn(val4, 8);
                res01 = __SMLAD(val3, val5, res01);
                res01 = __SMLAD(val0, val4, res01);

                // 4 x MAC res10, res11
                val0 = arm_nn_read_q7x4((const q7_t *)&lhs_ptr[off0]);
                val3 = __SXTB16(val0);
                val0 = __SXTB16_RORn(val0, 8);
                res10 = __SMLAD(val3, val2, res10);
                res11 = __SMLAD(val3, val5, res11);
                res10 = __SMLAD(val0, val1, res10);
                res11 = __SMLAD(val0, val4, res11);
            }

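            // Left-over accumulations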
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                q7_t rhs_value0 = rhs_ptr[0];
                q7_t rhs_value1 = rhs_ptr[rhs_cols];
                q7_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[rhs_cols];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr[1] = (q7_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (q7_t)res10;
            dst_ptr[1] = (q7_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr += rhs_cols;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const q7_t *rhs_ptr = &rhs[0];

            q31_t res00 = lhs_offset_contribution0;
            q31_t res01 = lhs_offset_contribution1;

            int32_t rhs_cols_idx = 0;

            q31_t val0, val1, val2, val3, val4, val5;
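            // 16 columns per iteration; same SXTB16/SMLAD scheme as the
            // two-row loop above, but accumulating a single output row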
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                val0 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                val1 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val5 = __SXTB16(val2);
                val4 = __SXTB16(val1);
                val0 = __SXTB16_RORn(val0, 8);
                val2 = __SXTB16_RORn(val2, 8);
                val1 = __SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val5, val3, res00);
                res00 = __SMLAD(val2, val0, res00);
                res01 = __SMLAD(val5, val4, res01);
                res01 = __SMLAD(val2, val1, res01);

                val0 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                val1 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val5 = __SXTB16(val2);
                val4 = __SXTB16(val1);
                val0 = __SXTB16_RORn(val0, 8);
                val2 = __SXTB16_RORn(val2, 8);
                val1 = __SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val5, val3, res00);
                res00 = __SMLAD(val2, val0, res00);
                res01 = __SMLAD(val5, val4, res01);
                res01 = __SMLAD(val2, val1, res01);

                val0 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                val1 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val5 = __SXTB16(val2);
                val4 = __SXTB16(val1);
                val0 = __SXTB16_RORn(val0, 8);
                val2 = __SXTB16_RORn(val2, 8);
                val1 = __SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val5, val3, res00);
                res00 = __SMLAD(val2, val0, res00);
                res01 = __SMLAD(val5, val4, res01);
                res01 = __SMLAD(val2, val1, res01);

                val0 = arm_nn_read_q7x4_ia((const q7_t **)&rhs_ptr);
                val1 = arm_nn_read_q7x4((const q7_t *)&rhs_ptr[off0]);
                val2 = arm_nn_read_q7x4_ia((const q7_t **)&lhs_ptr);
                val3 = __SXTB16(val0);
                val5 = __SXTB16(val2);
                val4 = __SXTB16(val1);
                val0 = __SXTB16_RORn(val0, 8);
                val2 = __SXTB16_RORn(val2, 8);
                val1 = __SXTB16_RORn(val1, 8);

                // 4 x MAC res00, res01
                res00 = __SMLAD(val5, val3, res00);
                res00 = __SMLAD(val2, val0, res00);
                res01 = __SMLAD(val5, val4, res01);
                res01 = __SMLAD(val2, val1, res01);
            }

            // Left-over accumulations
            for (; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                q7_t rhs_value0 = rhs_ptr[0];
                q7_t rhs_value1 = rhs_ptr[rhs_cols];
                q7_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr[1] = (q7_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

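    // Handle the last rhs row separately when rhs_rows is odd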
    if (rhs_rows % 2)
    {
        const q7_t *lhs_ptr = &lhs[0];
        q7_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const q7_t *rhs_ptr = &rhs[0];
            q31_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                q31_t rhs_value = rhs_ptr[0];
                q31_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#else
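    /* Reference C path for targets without the DSP extension: same loop
       structure and 2 x 2 output tiling, scalar multiply-accumulate. */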
    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 2); rhs_rows_idx += 2)
    {
        const q7_t *lhs_ptr = &lhs[0];
        q7_t *dst_ptr = &dst[0];

        q31_t lhs_offset_contribution0 = 0;
        q31_t lhs_offset_contribution1 = 0;

        for (int32_t x = 0; x < rhs_cols; ++x)
        {
            lhs_offset_contribution0 += rhs[x];
            lhs_offset_contribution1 += rhs[x + rhs_cols];
        }

        lhs_offset_contribution0 *= lhs_offset;
        lhs_offset_contribution1 *= lhs_offset;
        if (bias)
        {
            lhs_offset_contribution0 += bias[rhs_rows_idx];
            lhs_offset_contribution1 += bias[rhs_rows_idx + 1];
        }

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const q7_t *rhs_ptr = &rhs[0];

            q31_t res00 = lhs_offset_contribution0;
            q31_t res01 = lhs_offset_contribution1;
            q31_t res10 = lhs_offset_contribution0;
            q31_t res11 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                q7_t rhs_value0 = rhs_ptr[0];
                q7_t rhs_value1 = rhs_ptr[rhs_cols];
                q7_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                lhs_value = lhs_ptr[rhs_cols];
                res10 += lhs_value * rhs_value0;
                res11 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr[1] = (q7_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (q7_t)res10;
            dst_ptr[1] = (q7_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr += rhs_cols;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const q7_t *rhs_ptr = &rhs[0];

            q31_t res00 = lhs_offset_contribution0;
            q31_t res01 = lhs_offset_contribution1;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                q7_t rhs_value0 = rhs_ptr[0];
                q7_t rhs_value1 = rhs_ptr[rhs_cols];
                q7_t lhs_value = lhs_ptr[0];

                res00 += lhs_value * rhs_value0;
                res01 += lhs_value * rhs_value1;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr[1] = (q7_t)res01;
        }

        rhs += 2 * rhs_cols;
        dst += 2;
    }

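    // Handle the last rhs row separately when rhs_rows is odd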
    if (rhs_rows % 2)
    {
        const q7_t *lhs_ptr = &lhs[0];
        q7_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const q7_t *rhs_ptr = &rhs[0];
            q31_t res00 = 0;
            if (bias)
            {
                res00 = bias[rhs_rows - 1];
            }

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                q31_t rhs_value = rhs_ptr[0];
                q31_t lhs_value = lhs_ptr[0] + lhs_offset;

                res00 += lhs_value * rhs_value;

                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows - 1], dst_shifts[rhs_rows - 1]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (q7_t)res00;
            dst_ptr += rhs_rows;
        }
    }
#endif
    return ARM_MATH_SUCCESS;
}

/**
 * @} end of NNBasicMath group
 */