/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s8_s32
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed
 *
 * $Date:        31 January 2024
 * $Revision:    V.1.1.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
/**
 * @ingroup groupSupport
 */
36 
/**
 * @addtogroup supportConvolution
 * @{
 */
41 
42 /*
43  * s32 matrix multiplication with the right-hand-side matrix transposed
44  *
45  * Refer header file for details.
46  *
47  */
arm_nn_mat_mult_nt_t_s8_s32(const int8_t * lhs,const int8_t * rhs,int32_t * dst,const int32_t lhs_rows,const int32_t rhs_rows,const int32_t rhs_cols,const int32_t lhs_offset,const int32_t dst_idx_offset)48 arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs,
49                                                 const int8_t *rhs,
50                                                 int32_t *dst,
51                                                 const int32_t lhs_rows,
52                                                 const int32_t rhs_rows,
53                                                 const int32_t rhs_cols,
54                                                 const int32_t lhs_offset,
55                                                 const int32_t dst_idx_offset)
56 {
57     int32_t rhs_rows_idx = rhs_rows;
58     const int32_t dst_idx_col_offset = dst_idx_offset * rhs_cols;
59 #if defined(ARM_MATH_MVEI)
60     for (; rhs_rows_idx >= 16; rhs_rows_idx -= 16)
61     {
62         int32_t *dst_ptr = &dst[0];
63         const int8_t *lhs_ptr = &lhs[0];
64         int32_t lhs_rows_idx = lhs_rows;
65 
66         for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4)
67         {
68             const int8_t *rhs_ptr = &rhs[0];
69             int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr);
70             lhs_ptr += rhs_rows;
71             int8x16_t v_lhs1 = vldrbq_s8(lhs_ptr);
72             lhs_ptr += rhs_rows;
73             int8x16_t v_lhs2 = vldrbq_s8(lhs_ptr);
74             lhs_ptr += rhs_rows;
75             int8x16_t v_lhs3 = vldrbq_s8(lhs_ptr);
76             lhs_ptr += rhs_rows;
77 
78             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
79             {
80                 int32_t *ip_dst = dst_ptr;
81 
82                 int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr);
83                 int32_t rhs_sum = vaddvq_s8(v_rhs0);
84                 rhs_sum *= lhs_offset;
85 
86                 *ip_dst += rhs_sum;
87                 *ip_dst = vmladavaq_s8(*ip_dst, v_lhs0, v_rhs0);
88                 ip_dst += dst_idx_col_offset;
89 
90                 *ip_dst += rhs_sum;
91                 *ip_dst = vmladavaq_s8(*ip_dst, v_lhs1, v_rhs0);
92                 ip_dst += dst_idx_col_offset;
93 
94                 *ip_dst += rhs_sum;
95                 *ip_dst = vmladavaq_s8(*ip_dst, v_lhs2, v_rhs0);
96                 ip_dst += dst_idx_col_offset;
97 
98                 *ip_dst += rhs_sum;
99                 *ip_dst = vmladavaq_s8(*ip_dst, v_lhs3, v_rhs0);
100 
101                 dst_ptr += dst_idx_offset;
102                 rhs_ptr += rhs_rows;
103             }
104 
105             dst_ptr += 3 * dst_idx_col_offset;
106         }
107         for (; lhs_rows_idx > 0; lhs_rows_idx--)
108         {
109             const int8_t *rhs_ptr = &rhs[0];
110             int8x16_t v_lhs0 = vldrbq_s8(lhs_ptr);
111 
112             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
113             {
114                 int8x16_t v_rhs0 = vldrbq_s8(rhs_ptr);
115 
116                 int32_t offset_sum = vaddvq_s8(v_rhs0);
117                 *dst_ptr += offset_sum * lhs_offset;
118 
119                 *dst_ptr = vmladavaq_s8(*dst_ptr, v_lhs0, v_rhs0);
120 
121                 dst_ptr += dst_idx_offset;
122                 rhs_ptr += rhs_rows;
123             }
124             lhs_ptr += rhs_rows;
125         }
126 
127         rhs += 16;
128         lhs += 16;
129     }
130     if (rhs_rows_idx)
131     {
132         mve_pred16_t rmdr = (1 << rhs_rows_idx) - 1;
133         int32_t *dst_ptr = &dst[0];
134         const int8_t *lhs_ptr = &lhs[0];
135         int32_t lhs_rows_idx = lhs_rows;
136 
137         for (; lhs_rows_idx >= 4; lhs_rows_idx -= 4)
138         {
139             const int8_t *rhs_ptr = &rhs[0];
140             int8x16_t v_lhs0 = vldrbq_z_s8(lhs_ptr, rmdr);
141             lhs_ptr += rhs_rows;
142             int8x16_t v_lhs1 = vldrbq_z_s8(lhs_ptr, rmdr);
143             lhs_ptr += rhs_rows;
144             int8x16_t v_lhs2 = vldrbq_z_s8(lhs_ptr, rmdr);
145             lhs_ptr += rhs_rows;
146             int8x16_t v_lhs3 = vldrbq_z_s8(lhs_ptr, rmdr);
147             lhs_ptr += rhs_rows;
148 
149             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
150             {
151                 int32_t *ip_dst = dst_ptr;
152                 int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr);
153 
154                 int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr);
155                 rhs_sum *= lhs_offset;
156 
157                 *ip_dst += rhs_sum;
158                 *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs0, v_rhs0, rmdr);
159                 ip_dst += dst_idx_col_offset;
160 
161                 *ip_dst += rhs_sum;
162                 *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs1, v_rhs0, rmdr);
163                 ip_dst += dst_idx_col_offset;
164 
165                 *ip_dst += rhs_sum;
166                 *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs2, v_rhs0, rmdr);
167                 ip_dst += dst_idx_col_offset;
168 
169                 *ip_dst += rhs_sum;
170                 *ip_dst = vmladavaq_p_s8(*ip_dst, v_lhs3, v_rhs0, rmdr);
171 
172                 dst_ptr += dst_idx_offset;
173                 rhs_ptr += rhs_rows;
174             }
175 
176             dst_ptr += 3 * dst_idx_col_offset;
177         }
178         for (; lhs_rows_idx > 0; lhs_rows_idx--)
179         {
180             const int8_t *rhs_ptr = &rhs[0];
181             int8x16_t v_lhs0 = vldrbq_z_s8(lhs_ptr, rmdr);
182 
183             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
184             {
185                 int8x16_t v_rhs0 = vldrbq_z_s8(rhs_ptr, rmdr);
186 
187                 int32_t rhs_sum = vaddvq_p_s8(v_rhs0, rmdr);
188                 *dst_ptr += rhs_sum * lhs_offset;
189 
190                 *dst_ptr = vmladavaq_p_s8(*dst_ptr, v_lhs0, v_rhs0, rmdr);
191 
192                 dst_ptr += dst_idx_offset;
193                 rhs_ptr += rhs_rows;
194             }
195             lhs_ptr += rhs_rows;
196         }
197     }
198 
199 #elif defined(ARM_MATH_DSP)
200     int16_t lhs_offset_s16 = (int16_t)lhs_offset;
201     const uint32_t lhs_offset_s16x2 = PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
202     for (; rhs_rows_idx >= 8; rhs_rows_idx -= 8)
203     {
204         int32_t *dst_ptr = &dst[0];
205         const int8_t *lhs_ptr = &lhs[0];
206         int32_t lhs_rows_idx = lhs_rows >> 1;
207 
208         while (lhs_rows_idx)
209         {
210             const int8_t *rhs_ptr = &rhs[0];
211 
212             int32_t lhs000, lhs001, lhs010, lhs011, lhs100, lhs101, lhs110, lhs111;
213             read_pad_and_add_s8(lhs_ptr, &lhs000, &lhs001, lhs_offset_s16x2);
214             read_pad_and_add_s8(&lhs_ptr[4], &lhs010, &lhs011, lhs_offset_s16x2);
215             read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs100, &lhs101, lhs_offset_s16x2);
216             read_pad_and_add_s8(&lhs_ptr[rhs_rows + 4], &lhs110, &lhs111, lhs_offset_s16x2);
217 
218             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
219             {
220                 int32_t rhs_val00, rhs_val01;
221                 read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01);
222 
223                 dst_ptr[0] = SMLAD(lhs000, rhs_val00, dst_ptr[0]);
224                 dst_ptr[0] = SMLAD(lhs001, rhs_val01, dst_ptr[0]);
225                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs100, rhs_val00, dst_ptr[dst_idx_col_offset]);
226                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs101, rhs_val01, dst_ptr[dst_idx_col_offset]);
227 
228                 read_and_pad(&rhs_ptr[4], &rhs_val00, &rhs_val01);
229 
230                 dst_ptr[0] = SMLAD(lhs010, rhs_val00, dst_ptr[0]);
231                 dst_ptr[0] = SMLAD(lhs011, rhs_val01, dst_ptr[0]);
232                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs110, rhs_val00, dst_ptr[dst_idx_col_offset]);
233                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs111, rhs_val01, dst_ptr[dst_idx_col_offset]);
234 
235                 dst_ptr += dst_idx_offset;
236                 rhs_ptr += rhs_rows;
237             }
238             dst_ptr += dst_idx_col_offset;
239 
240             lhs_ptr += rhs_rows << 1;
241 
242             lhs_rows_idx--;
243         }
244         // Left-over rows
245         if (lhs_rows % 2)
246         {
247             const int8_t *rhs_ptr = &rhs[0];
248             int32_t lhs00, lhs01, lhs10, lhs11;
249             read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);
250             read_pad_and_add_s8(&lhs_ptr[4], &lhs10, &lhs11, lhs_offset_s16x2);
251 
252             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
253             {
254                 int32_t rhs_val00, rhs_val01, rhs_val10, rhs_val11;
255                 read_and_pad(rhs_ptr, &rhs_val00, &rhs_val01);
256                 read_and_pad(&rhs_ptr[4], &rhs_val10, &rhs_val11);
257 
258                 dst_ptr[0] = SMLAD(lhs00, rhs_val00, dst_ptr[0]);
259                 dst_ptr[0] = SMLAD(lhs01, rhs_val01, dst_ptr[0]);
260                 dst_ptr[0] = SMLAD(lhs10, rhs_val10, dst_ptr[0]);
261                 dst_ptr[0] = SMLAD(lhs11, rhs_val11, dst_ptr[0]);
262 
263                 dst_ptr += dst_idx_offset;
264                 rhs_ptr += rhs_rows;
265             }
266         }
267 
268         rhs += 8;
269         lhs += 8;
270     }
271     for (; rhs_rows_idx >= 4; rhs_rows_idx -= 4)
272     {
273         int32_t *dst_ptr = &dst[0];
274         const int8_t *lhs_ptr = &lhs[0];
275 
276         int32_t lhs_rows_idx = lhs_rows >> 1;
277 
278         while (lhs_rows_idx)
279         {
280             const int8_t *rhs_ptr = &rhs[0];
281 
282             int32_t lhs00, lhs01, lhs10, lhs11;
283             read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);
284             read_pad_and_add_s8(&lhs_ptr[rhs_rows], &lhs10, &lhs11, lhs_offset_s16x2);
285 
286             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
287             {
288                 int32_t rhs_val0, rhs_val1;
289                 read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1);
290 
291                 dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]);
292                 dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]);
293                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs10, rhs_val0, dst_ptr[dst_idx_col_offset]);
294                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs11, rhs_val1, dst_ptr[dst_idx_col_offset]);
295                 dst_ptr += dst_idx_offset;
296                 rhs_ptr += rhs_rows;
297             }
298             dst_ptr += dst_idx_col_offset;
299 
300             lhs_ptr += rhs_rows << 1;
301 
302             lhs_rows_idx--;
303         }
304         // Left-over rows
305         if (lhs_rows % 2)
306         {
307             const int8_t *rhs_ptr = &rhs[0];
308             int32_t lhs00, lhs01;
309             read_pad_and_add_s8(lhs_ptr, &lhs00, &lhs01, lhs_offset_s16x2);
310 
311             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
312             {
313                 int32_t rhs_val0, rhs_val1;
314                 read_and_pad(rhs_ptr, &rhs_val0, &rhs_val1);
315 
316                 dst_ptr[0] = SMLAD(lhs00, rhs_val0, dst_ptr[0]);
317                 dst_ptr[0] = SMLAD(lhs01, rhs_val1, dst_ptr[0]);
318 
319                 dst_ptr += dst_idx_offset;
320                 rhs_ptr += rhs_rows;
321             }
322         }
323 
324         rhs += 4;
325         lhs += 4;
326     }
327     for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2)
328     {
329         int32_t *dst_ptr = &dst[0];
330         const int8_t *lhs_ptr = &lhs[0];
331 
332         int32_t lhs_rows_idx = lhs_rows >> 1;
333 
334         while (lhs_rows_idx)
335         {
336             const int8_t *rhs_ptr = &rhs[0];
337 
338             int32_t lhs0, lhs1;
339             read_pad_and_add_s8x2(lhs_ptr, &lhs0, lhs_offset_s16x2);
340             read_pad_and_add_s8x2(&lhs_ptr[rhs_rows], &lhs1, lhs_offset_s16x2);
341 
342             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
343             {
344                 int32_t rhs_val;
345                 read_and_pad_s8x2(rhs_ptr, &rhs_val);
346 
347                 dst_ptr[0] = SMLAD(lhs0, rhs_val, dst_ptr[0]);
348                 dst_ptr[dst_idx_col_offset] = SMLAD(lhs1, rhs_val, dst_ptr[dst_idx_col_offset]);
349 
350                 dst_ptr += dst_idx_offset;
351                 rhs_ptr += rhs_rows;
352             }
353             dst_ptr += dst_idx_col_offset;
354 
355             lhs_ptr += rhs_rows << 1;
356 
357             lhs_rows_idx--;
358         }
359         // Left-over rows
360         if (lhs_rows % 2)
361         {
362             const int8_t *rhs_ptr = &rhs[0];
363             const int32_t lhs_value = lhs_ptr[0] + lhs_offset;
364             const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;
365 
366             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
367             {
368                 const int32_t rhs_value0 = rhs_ptr[0];
369                 const int32_t rhs_value01 = rhs_ptr[1];
370 
371                 dst_ptr[0] += lhs_value * rhs_value0;
372                 dst_ptr[0] += lhs_value01 * rhs_value01;
373                 dst_ptr += dst_idx_offset;
374                 rhs_ptr += rhs_rows;
375             }
376         }
377 
378         rhs += 2;
379         lhs += 2;
380     }
381 #else
382     for (; rhs_rows_idx >= 2; rhs_rows_idx -= 2)
383     {
384         int32_t *dst_ptr = &dst[0];
385         const int8_t *lhs_ptr = &lhs[0];
386 
387         int32_t lhs_rows_idx = lhs_rows >> 1;
388 
389         while (lhs_rows_idx)
390         {
391             const int8_t *rhs_ptr = &rhs[0];
392 
393             const int32_t lhs_value00 = lhs_ptr[0] + lhs_offset;
394             const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;
395 
396             const int32_t lhs_value10 = lhs_ptr[rhs_rows] + lhs_offset;
397             const int32_t lhs_value11 = lhs_ptr[rhs_rows + 1] + lhs_offset;
398 
399             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
400             {
401                 const int32_t rhs_value0 = rhs_ptr[0];
402                 const int32_t rhs_value1 = rhs_ptr[1];
403 
404                 dst_ptr[0] += lhs_value00 * rhs_value0;
405                 dst_ptr[0] += lhs_value01 * rhs_value1;
406 
407                 dst_ptr[dst_idx_col_offset] += lhs_value10 * rhs_value0;
408                 dst_ptr[dst_idx_col_offset] += lhs_value11 * rhs_value1;
409                 dst_ptr += dst_idx_offset;
410                 rhs_ptr += rhs_rows;
411             }
412             dst_ptr += dst_idx_col_offset;
413 
414             lhs_ptr += rhs_rows << 1;
415 
416             lhs_rows_idx--;
417         }
418         // Left-over rows
419         if (lhs_rows % 2)
420         {
421             const int8_t *rhs_ptr = &rhs[0];
422             const int32_t lhs_value = lhs_ptr[0] + lhs_offset;
423             const int32_t lhs_value01 = lhs_ptr[1] + lhs_offset;
424 
425             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
426             {
427                 const int32_t rhs_value0 = rhs_ptr[0];
428                 const int32_t rhs_value01 = rhs_ptr[1];
429 
430                 dst_ptr[0] += lhs_value * rhs_value0;
431                 dst_ptr[0] += lhs_value01 * rhs_value01;
432                 dst_ptr += dst_idx_offset;
433                 rhs_ptr += rhs_rows;
434             }
435         }
436 
437         rhs += 2;
438         lhs += 2;
439     }
440 #endif
441 #if !defined(ARM_MATH_MVEI)
442     if (rhs_rows_idx)
443     {
444         const int8_t *lhs_ptr = &lhs[0];
445         int32_t *dst_ptr = &dst[0];
446 
447         for (int32_t lhs_rows_idx = 0; lhs_rows_idx < lhs_rows; ++lhs_rows_idx)
448         {
449             const int8_t *rhs_ptr = &rhs[0];
450             const int32_t lhs_value = lhs_ptr[0] + lhs_offset;
451 
452             for (int32_t rhs_cols_idx = rhs_cols; rhs_cols_idx != 0; rhs_cols_idx--)
453             {
454                 const int32_t rhs_value = rhs_ptr[0];
455 
456                 *dst_ptr += lhs_value * rhs_value;
457 
458                 dst_ptr += dst_idx_offset;
459                 rhs_ptr += rhs_rows;
460             }
461             lhs_ptr += rhs_rows;
462         }
463     }
464 #endif
465     return ARM_CMSIS_NN_SUCCESS;
466 }
467 
/**
 * @} end of Doxygen group
 */
471