1 /*
2  * Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_vec_mat_mult_t_s8
22  * Description:  s8 vector by matrix (transposed) multiplication
23  *
24  * $Date:        02. May 2021
25  * $Revision:    V.2.5.0
26  *
27  * Target Processor:  Cortex-M
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup groupSupport
35  */
36 
37 /**
38  * @addtogroup NNBasicMath
39  * @{
40  */
41 
/*
 * s8 vector (lhs) by matrix (transposed) multiplication
 *
 * Refer to the header file for details.
 *
 */
arm_nn_vec_mat_mult_t_s8(const q7_t * lhs,const q7_t * rhs,const q31_t * bias,q7_t * dst,const int32_t lhs_offset,const int32_t rhs_offset,const int32_t dst_offset,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)48 arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
49                                     const q7_t *rhs,
50                                     const q31_t *bias,
51                                     q7_t *dst,
52                                     const int32_t lhs_offset,
53                                     const int32_t rhs_offset,
54                                     const int32_t dst_offset,
55                                     const int32_t dst_multiplier,
56                                     const int32_t dst_shift,
57                                     const int32_t rhs_cols,
58                                     const int32_t rhs_rows,
59                                     const int32_t activation_min,
60                                     const int32_t activation_max)
61 {
62     (void)rhs_offset;
63 #if defined(ARM_MATH_MVEI)
64     int32_t row_loop_cnt = rhs_rows / 3;
65 
66     for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
67     {
68         int32_t acc_0 = 0;
69         int32_t acc_1 = 0;
70         int32_t acc_2 = 0;
71 
72         const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
73 
74         const int8_t *lhs_vec = lhs;
75         const int8_t *rhs_0 = rhs;
76         const int8_t *rhs_1 = rhs + rhs_cols;
77         const int8_t *rhs_2 = rhs + 2 * rhs_cols;
78 
79         int32_t rhs_sum_0 = 0;
80         int32_t rhs_sum_1 = 0;
81         int32_t rhs_sum_2 = 0;
82 
83         uint32_t col_cnt = (uint32_t)rhs_cols;
84 
85         for (int i = 0; i < col_loop_cnt; i++)
86         {
87             mve_pred16_t p = vctp8q(col_cnt);
88             col_cnt -= 16;
89 
90             const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
91 
92             const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
93             rhs_sum_0 = vaddvaq_p_s8(rhs_sum_0, ker_0, p);
94             acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
95 
96             const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
97             rhs_sum_1 = vaddvaq_p_s8(rhs_sum_1, ker_1, p);
98             acc_1 = vmladavaq_p_s8(acc_1, ker_1, input, p);
99 
100             const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
101             rhs_sum_2 = vaddvaq_p_s8(rhs_sum_2, ker_2, p);
102             acc_2 = vmladavaq_p_s8(acc_2, ker_2, input, p);
103 
104             lhs_vec += 16;
105             rhs_0 += 16;
106             rhs_1 += 16;
107             rhs_2 += 16;
108         }
109         rhs += 3 * rhs_cols;
110 
111         int32x4_t acc = {acc_0, acc_1, acc_2, 0};
112         mve_pred16_t p = vctp32q(3);
113         if (bias)
114         {
115             int32x4_t b = vldrwq_z_s32(bias, p);
116             acc = vaddq_m_s32(vuninitializedq_s32(), acc, b, p);
117             bias += 3;
118         }
119         const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, 0};
120         acc += vdupq_n_s32(lhs_offset) * rhs_sum;
121 
122         acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
123         acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
124         acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
125         acc = vminq_s32(acc, vdupq_n_s32(activation_max));
126         vstrbq_p_s32(dst, acc, p);
127         dst += 3;
128     }
129 
130     const int loop_cnt = rhs_rows % 3;
131     for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
132     {
133         int32_t acc_0 = 0;
134         const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
135         const int8_t *lhs_vec = lhs;
136         const int8_t *rhs_0 = rhs;
137         int32_t rhs_sum_0 = 0;
138         uint32_t col_cnt = (uint32_t)rhs_cols;
139 
140         for (int i = 0; i < col_loop_cnt; i++)
141         {
142             mve_pred16_t p = vctp8q(col_cnt);
143             col_cnt -= 16;
144             const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
145 
146             const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
147             rhs_sum_0 = vaddvaq_p_s8(rhs_sum_0, ker_0, p);
148             acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
149 
150             lhs_vec += 16;
151             rhs_0 += 16;
152         }
153         rhs += rhs_cols;
154 
155         if (bias)
156         {
157             acc_0 += *bias;
158             bias++;
159         }
160         const int32_t offsets = rhs_sum_0 * lhs_offset;
161         acc_0 += offsets;
162         acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
163         acc_0 += dst_offset;
164 
165         // Clamp the result
166         acc_0 = MAX(acc_0, activation_min);
167         *dst = MIN(acc_0, activation_max);
168         dst++;
169     }
170 
171 #elif defined(ARM_MATH_DSP)
172     int32_t row_loop_cnt = rhs_rows / 2;
173 
174     const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
175 
176     const uint32_t lhs_offset_s16x2 = __PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
177 
178     for (int32_t i = 0; i < row_loop_cnt; i++)
179     {
180         int32_t acc_0 = 0;
181         int32_t acc_1 = 0;
182         if (bias)
183         {
184             acc_0 = *bias++;
185             acc_1 = *bias++;
186         }
187 
188         const int32_t col_loop_cnt = rhs_cols / 4;
189 
190         const int8_t *lhs_vec = lhs;
191         const int8_t *rhs_0 = rhs;
192         const int8_t *rhs_1 = rhs + rhs_cols;
193         rhs += 2 * rhs_cols;
194 
195         for (int j = col_loop_cnt; j != 0; j--)
196         {
197             int32_t vec_0 = arm_nn_read_q7x4_ia(&lhs_vec);
198             int32_t vec_1 = __SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
199 
200             vec_0 = __SXTAB16(lhs_offset_s16x2, vec_0);
201 
202             int32_t ker_0 = arm_nn_read_q7x4_ia(&rhs_0);
203             int32_t ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
204             ker_0 = __SXTB16(ker_0);
205 
206             acc_0 = __SMLAD(ker_1, vec_1, acc_0);
207             acc_0 = __SMLAD(ker_0, vec_0, acc_0);
208 
209             ker_0 = arm_nn_read_q7x4_ia(&rhs_1);
210             ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
211             ker_0 = __SXTB16(ker_0);
212 
213             acc_1 = __SMLAD(ker_1, vec_1, acc_1);
214             acc_1 = __SMLAD(ker_0, vec_0, acc_1);
215         }
216 
217         for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
218         {
219             const int32_t lhs_temp = (*lhs_vec + lhs_offset);
220             lhs_vec++;
221             acc_0 += lhs_temp * (*rhs_0);
222             rhs_0++;
223             acc_1 += lhs_temp * (*rhs_1);
224             rhs_1++;
225         }
226 
227         acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
228         acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
229 
230         // Add offset
231         acc_0 += dst_offset;
232         acc_1 += dst_offset;
233         // Clamp the result
234         acc_0 = MAX(acc_0, activation_min);
235         acc_0 = MIN(acc_0, activation_max);
236         acc_1 = MAX(acc_1, activation_min);
237         acc_1 = MIN(acc_1, activation_max);
238 
239         *dst++ = (q7_t)acc_0;
240         *dst++ = (q7_t)acc_1;
241     }
242 
243     if (rhs_rows & 0x1)
244     {
245         int32_t acc_0 = 0;
246         if (bias)
247         {
248             acc_0 = *bias++;
249         }
250         const int32_t col_loop_cnt = rhs_cols / 4;
251 
252         const int8_t *lhs_vec = lhs;
253         const int8_t *rhs_0 = rhs;
254 
255         for (int i = col_loop_cnt; i != 0; i--)
256         {
257             int32_t vec_0 = arm_nn_read_q7x4_ia(&lhs_vec);
258             int32_t vec_1 = __SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
259             vec_0 = __SXTAB16(lhs_offset_s16x2, vec_0);
260 
261             int32_t ker_0 = arm_nn_read_q7x4_ia(&rhs_0);
262             int32_t ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
263             ker_0 = __SXTB16(ker_0);
264 
265             acc_0 = __SMLAD(ker_1, vec_1, acc_0);
266             acc_0 = __SMLAD(ker_0, vec_0, acc_0);
267         }
268 
269         for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
270         {
271             const int32_t lhs_temp = (*lhs_vec + lhs_offset);
272             lhs_vec++;
273             acc_0 += lhs_temp * (*rhs_0);
274             rhs_0++;
275         }
276 
277         acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
278 
279         // Add offset
280         acc_0 += dst_offset;
281         // Clamp the result
282         acc_0 = MAX(acc_0, activation_min);
283         acc_0 = MIN(acc_0, activation_max);
284 
285         *dst++ = (q7_t)acc_0;
286     }
287 
288 #else
289 
290     int32_t row_loop_cnt = rhs_rows / 3;
291 
292     for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
293     {
294         const q7_t *lhs_ptr = lhs;
295         const q7_t *rhs_ptr_0 = &rhs[0];
296         const q7_t *rhs_ptr_1 = &rhs[rhs_cols];
297         const q7_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
298 
299         q31_t res00 = 0;
300         q31_t res01 = 0;
301         q31_t res02 = 0;
302         if (bias)
303         {
304             res00 = *bias++;
305             res01 = *bias++;
306             res02 = *bias++;
307         }
308         for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
309         {
310             const q31_t rhs_value0 = (int8_t)*rhs_ptr_0;
311             const q31_t rhs_value1 = (int8_t)*rhs_ptr_1;
312             const q31_t rhs_value2 = (int8_t)*rhs_ptr_2;
313             const q31_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
314 
315             res00 += lhs_value * rhs_value0;
316             res01 += lhs_value * rhs_value1;
317             res02 += lhs_value * rhs_value2;
318 
319             ++rhs_ptr_0;
320             ++rhs_ptr_1;
321             ++rhs_ptr_2;
322             ++lhs_ptr;
323         }
324         // Quantize down
325         res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
326         res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
327         res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
328 
329         // Add offset
330         res00 += dst_offset;
331         res01 += dst_offset;
332         res02 += dst_offset;
333 
334         // Clamp the result
335         res00 = MAX(res00, activation_min);
336         res00 = MIN(res00, activation_max);
337         res01 = MAX(res01, activation_min);
338         res01 = MIN(res01, activation_max);
339         res02 = MAX(res02, activation_min);
340         res02 = MIN(res02, activation_max);
341 
342         *dst++ = (q7_t)res00;
343         *dst++ = (q7_t)res01;
344         *dst++ = (q7_t)res02;
345 
346         rhs += 3 * rhs_cols;
347     }
348 
349     const int loop_cnt = rhs_rows % 3;
350 
351     for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
352     {
353         const q7_t *lhs_ptr = &lhs[0];
354         const q7_t *rhs_ptr = &rhs[0];
355 
356         q31_t res00 = 0;
357         if (bias)
358         {
359             res00 = *bias++;
360         }
361 
362         for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
363         {
364             q31_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
365             q31_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
366 
367             res00 += lhs_value * rhs_value0;
368 
369             ++rhs_ptr;
370             ++lhs_ptr;
371         }
372 
373         // Quantize down
374         res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
375 
376         // Add offset
377         res00 += dst_offset;
378 
379         // Clamp the result
380         res00 = MAX(res00, activation_min);
381         res00 = MIN(res00, activation_max);
382 
383         *dst++ = (q7_t)res00;
384         rhs += rhs_cols;
385     }
386 #endif
387     return ARM_MATH_SUCCESS;
388 }
389 
390 /**
391  * @} end of NNBasicMath group
392  */
393