1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_mat_mult_nt_t_s8_s32
22 * Description: Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
23 *
24 * $Date: 31 January 2024
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup groupSupport
35 */
36
37 /**
38 * @addtogroup supportConvolution
39 * @{
40 */
41
42 /*
43 * s32 matrix multiplication with the right-hand-side matrix transposed
44 *
45 * Refer header file for details.
46 *
47 */
/*
 * s32 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer header file for details.
 *
 * Layout (as established by the pointer strides below):
 *   - lhs is lhs_rows x rhs_rows, row stride rhs_rows (lhs_ptr advances by rhs_rows per row).
 *   - rhs is stored transposed with column stride rhs_rows (rhs_ptr advances by rhs_rows per column),
 *     i.e. one dot product consumes rhs_rows consecutive bytes from each operand.
 *   - dst is accumulated in place ('+=' / accumulating MACs), so the caller must pre-initialize it.
 *     Consecutive output columns are dst_idx_offset apart; consecutive output rows are
 *     dst_idx_col_offset = dst_idx_offset * rhs_cols apart.
 *   - lhs_offset is applied to every lhs element; no offset is applied to rhs.
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs,
                                                const int8_t *rhs,
                                                int32_t *dst,
                                                const int32_t lhs_rows,
                                                const int32_t rhs_rows,
                                                const int32_t rhs_cols,
                                                const int32_t lhs_offset,
                                                const int32_t dst_idx_offset)
{
    /* Remaining inner-dimension (rhs_rows) elements still to be consumed. */
    int32_t rhs_rows_idx = rhs_rows;
    /* Distance in dst between two consecutive lhs rows' outputs. */
    const int32_t dst_idx_col_offset = dst_idx_offset * rhs_cols;
#if defined(ARM_MATH_MVEI)
    /* MVE path: consume 16 inner-dimension elements per pass (one int8x16_t vector). */
    for (; rhs_rows_idx >= 16; rhs_rows_idx -= 16)
    {
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];
        int32_t lhs_rows_idx = lhs_rows;

        /* Unroll by 4 lhs rows: load 4 lhs vectors once and reuse them for every rhs column. */
        for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs1 = vldrbq_s8(lhs_ptr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs2 = vldrbq_s8(lhs_ptr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs3 = vldrbq_s8(lhs_ptr);
            lhs_ptr += rhs_rows;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t *ip_dst = dst_ptr;

                int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr);
                /* sum(rhs) * lhs_offset accounts for the lhs offset contribution
                 * (lhs_offset is constant across the 16 lanes), added once per output. */
                int32_t rhs_sum = vaddvq_s8(v_rhs0);
                rhs_sum *= lhs_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_s8(*ip_dst, v_lhs0, v_rhs0);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_s8(*ip_dst, v_lhs1, v_rhs0);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_s8(*ip_dst, v_lhs2, v_rhs0);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_s8(*ip_dst, v_lhs3, v_rhs0);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }

            /* Inner loop advanced dst_ptr by one row (rhs_cols * dst_idx_offset);
             * skip the remaining 3 unrolled rows. */
            dst_ptr += 3 * dst_idx_col_offset;
        }
        /* Remaining 0-3 lhs rows, one at a time. */
        for (; lhs_rows_idx > 0; lhs_rows_idx--)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr);

                int32_t offset_sum = vaddvq_s8(v_rhs0);
                *dst_ptr += offset_sum * lhs_offset;

                *dst_ptr = vmladavaq_s8(*dst_ptr, v_lhs0, v_rhs0);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            lhs_ptr += rhs_rows;
        }

        /* Advance both operands past the 16 consumed inner-dimension elements. */
        rhs += 16;
        lhs += 16;
    }
    /* Tail of 1-15 inner-dimension elements: same structure as above, but with a
     * lane predicate so loads and reductions only touch the valid elements. */
    if (rhs_rows_idx)
    {
        mve_pred16_t rmdr = (1 << rhs_rows_idx) - 1;
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];
        int32_t lhs_rows_idx = lhs_rows;

        for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4)
        {
            const int8_t *rhs_ptr = &rhs[0];
            /* Zeroing predicated loads: inactive lanes read as 0 so they do not
             * contribute to the MAC below. */
            int8x16_t v_lhs0 = vldrbq_z_s8(lhs_ptr, rmdr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs1 = vldrbq_z_s8(lhs_ptr, rmdr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs2 = vldrbq_z_s8(lhs_ptr, rmdr);
            lhs_ptr += rhs_rows;
            int8x16_t v_lhs3 = vldrbq_z_s8(lhs_ptr, rmdr);
            lhs_ptr += rhs_rows;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t *ip_dst = dst_ptr;
                int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr);

                int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr);
                rhs_sum *= lhs_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs0, v_rhs0, rmdr);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs1, v_rhs0, rmdr);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs2, v_rhs0, rmdr);
                ip_dst += dst_idx_col_offset;

                *ip_dst += rhs_sum;
                *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs3, v_rhs0, rmdr);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }

            dst_ptr += 3 * dst_idx_col_offset;
        }
        for (; lhs_rows_idx > 0; lhs_rows_idx--)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int8x16_t v_lhs0 = vldrbq_z_s8(lhs_ptr, rmdr);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr);

                int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr);
                *dst_ptr += rhs_sum * lhs_offset;

                *dst_ptr = vmladavaq_p_s8(*dst_ptr, v_lhs0, v_rhs0, rmdr);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            lhs_ptr += rhs_rows;
        }
    }

#elif defined(ARM_MATH_DSP)
    /* DSP path: SMLAD performs two 16x16 MACs per instruction, so int8 data is
     * widened to packed int16 pairs; lhs_offset is pre-added to the widened lhs
     * (lhs_offset duplicated into both halfwords of lhs_offset_s16x2). */
    int16_t lhs_offset_s16 = (int16_t)lhs_offset;
    const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
    /* Consume 8 inner-dimension elements per pass, unrolled over 2 lhs rows. */
    for (; rhs_rows_idx >= 8; rhs_rows_idx -= 8)
    {
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];
        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            /* 8 widened lhs values per row, 2 rows: 8 packed-s16x2 words,
             * each with lhs_offset folded in (helper defined in arm_nnsupportfunctions.h). */
            int32_t lhs000, lhs001, lhs010, lhs011, lhs100, lhs101, lhs110, lhs111;
            read_pad_and_add_s8(lhs_ptr, &lhs000, &lhs001, lhs_offset_s16x2);
            read_pad_and_add_s8(&lhs_ptr[4], &lhs010, &lhs011, lhs_offset_s16x2);
            read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs100, &lhs101, lhs_offset_s16x2);
            read_pad_and_add_s8(&lhs_ptr[rhs_rows + 4], &lhs110, &lhs111, lhs_offset_s16x2);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_val00, rhs_val01;
                read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01);

                dst_ptr[0] = SMLAD(lhs000, rhs_val00, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs001, rhs_val01, dst_ptr[0]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs100, rhs_val00, dst_ptr[dst_idx_col_offset]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs101, rhs_val01, dst_ptr[dst_idx_col_offset]);

                read_and_pad(&rhs_ptr[4], &rhs_val00, &rhs_val01);

                dst_ptr[0] = SMLAD(lhs010, rhs_val00, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs011, rhs_val01, dst_ptr[0]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs110, rhs_val00, dst_ptr[dst_idx_col_offset]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs111, rhs_val01, dst_ptr[dst_idx_col_offset]);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            /* dst_ptr already advanced one row inside the loop; skip the second unrolled row. */
            dst_ptr += dst_idx_col_offset;

            lhs_ptr += rhs_rows << 1;

            lhs_rows_idx--;
        }
        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t lhs00, lhs01, lhs10, lhs11;
            read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);
            read_pad_and_add_s8(&lhs_ptr[4], &lhs10, &lhs11, lhs_offset_s16x2);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_val00, rhs_val01, rhs_val10, rhs_val11;
                read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01);
                read_and_pad(&rhs_ptr[4], &rhs_val10, &rhs_val11);

                dst_ptr[0] = SMLAD(lhs00, rhs_val00, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs01, rhs_val01, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs10, rhs_val10, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs11, rhs_val11, dst_ptr[0]);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
        }

        rhs += 8;
        lhs += 8;
    }
    /* Tail: 4 inner-dimension elements per pass. */
    for (; rhs_rows_idx >= 4; rhs_rows_idx -= 4)
    {
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t lhs00, lhs01, lhs10, lhs11;
            read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);
            read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs10, &lhs11, lhs_offset_s16x2);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_val0, rhs_val1;
                read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1);

                dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs10, rhs_val0, dst_ptr[dst_idx_col_offset]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs11, rhs_val1, dst_ptr[dst_idx_col_offset]);
                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            dst_ptr += dst_idx_col_offset;

            lhs_ptr += rhs_rows << 1;

            lhs_rows_idx--;
        }
        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];
            int32_t lhs00, lhs01;
            read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_val0, rhs_val1;
                read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1);

                dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]);
                dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
        }

        rhs += 4;
        lhs += 4;
    }
    /* Tail: 2 inner-dimension elements per pass (one SMLAD per operand pair). */
    for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2)
    {
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            int32_t lhs0, lhs1;
            read_pad_and_add_s8x2(lhs_ptr, &lhs0, lhs_offset_s16x2);
            read_pad_and_add_s8x2(&lhs_ptr[rhs_rows], &lhs1, lhs_offset_s16x2);

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                int32_t rhs_val;
                read_and_pad_s8x2(rhs_ptr, &rhs_val);

                dst_ptr[0] = SMLAD(lhs0, rhs_val, dst_ptr[0]);
                dst_ptr[dst_idx_col_offset] = SMLAD(lhs1, rhs_val, dst_ptr[dst_idx_col_offset]);

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            dst_ptr += dst_idx_col_offset;

            lhs_ptr += rhs_rows << 1;

            lhs_rows_idx--;
        }
        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];
            const int32_t lhs_value = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                const int32_t rhs_value0 = rhs_ptr[0];
                const int32_t rhs_value01 = rhs_ptr[1];

                dst_ptr[0] += lhs_value * rhs_value0;
                dst_ptr[0] += lhs_value01 * rhs_value01;
                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
        }

        rhs += 2;
        lhs += 2;
    }
#else
    /* Pure-C path: 2 inner-dimension elements per pass, unrolled over 2 lhs rows. */
    for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2)
    {
        int32_t *dst_ptr = &dst[0];
        const int8_t *lhs_ptr = &lhs[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;

        while (lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];

            /* lhs values are loop-invariant across rhs columns; widen and add the
             * offset once, outside the inner loop. */
            const int32_t lhs_value00 = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;

            const int32_t lhs_value10 = lhs_ptr[rhs_rows] + lhs_offset;
            const int32_t lhs_value11 = lhs_ptr[rhs_rows + 1] + lhs_offset;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                const int32_t rhs_value0 = rhs_ptr[0];
                const int32_t rhs_value1 = rhs_ptr[1];

                dst_ptr[0] += lhs_value00 * rhs_value0;
                dst_ptr[0] += lhs_value01 * rhs_value1;

                dst_ptr[dst_idx_col_offset] += lhs_value10 * rhs_value0;
                dst_ptr[dst_idx_col_offset] += lhs_value11 * rhs_value1;
                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            dst_ptr += dst_idx_col_offset;

            lhs_ptr += rhs_rows << 1;

            lhs_rows_idx--;
        }
        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *rhs_ptr = &rhs[0];
            const int32_t lhs_value = lhs_ptr[0] + lhs_offset;
            const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                const int32_t rhs_value0 = rhs_ptr[0];
                const int32_t rhs_value01 = rhs_ptr[1];

                dst_ptr[0] += lhs_value * rhs_value0;
                dst_ptr[0] += lhs_value01 * rhs_value01;
                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
        }

        rhs += 2;
        lhs += 2;
    }
#endif
#if !defined(ARM_MATH_MVEI)
    /* Scalar tail for the last odd inner-dimension element (DSP and pure-C paths
     * consume elements in multiples of 2; the MVE path handled its own tail above). */
    if (rhs_rows_idx)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int32_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
        {
            const int8_t *rhs_ptr = &rhs[0];
            const int32_t lhs_value = lhs_ptr[0] + lhs_offset;

            for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
            {
                const int32_t rhs_value = rhs_ptr[0];

                *dst_ptr += lhs_value * rhs_value;

                dst_ptr += dst_idx_offset;
                rhs_ptr += rhs_rows;
            }
            lhs_ptr += rhs_rows;
        }
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}
467
468 /**
469 * @} end of Doxygen group
470 */
471