1 /*
2 * Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_vec_mat_mult_t_s8
22 * Description: s8 vector by matrix (transposed) multiplication
23 *
24 * $Date: 02. May 2021
25 * $Revision: V.2.5.0
26 *
27 * Target Processor: Cortex-M
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnsupportfunctions.h"
32
33 /**
34 * @ingroup groupSupport
35 */
36
37 /**
38 * @addtogroup NNBasicMath
39 * @{
40 */
41
42 /*
43 * s8 vector(lhs) by matrix (transposed) multiplication
44 *
45 * Refer header file for details.
46 *
47 */
arm_nn_vec_mat_mult_t_s8(const q7_t * lhs,const q7_t * rhs,const q31_t * bias,q7_t * dst,const int32_t lhs_offset,const int32_t rhs_offset,const int32_t dst_offset,const int32_t dst_multiplier,const int32_t dst_shift,const int32_t rhs_cols,const int32_t rhs_rows,const int32_t activation_min,const int32_t activation_max)48 arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
49 const q7_t *rhs,
50 const q31_t *bias,
51 q7_t *dst,
52 const int32_t lhs_offset,
53 const int32_t rhs_offset,
54 const int32_t dst_offset,
55 const int32_t dst_multiplier,
56 const int32_t dst_shift,
57 const int32_t rhs_cols,
58 const int32_t rhs_rows,
59 const int32_t activation_min,
60 const int32_t activation_max)
61 {
62 (void)rhs_offset;
63 #if defined(ARM_MATH_MVEI)
64 int32_t row_loop_cnt = rhs_rows / 3;
65
66 for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
67 {
68 int32_t acc_0 = 0;
69 int32_t acc_1 = 0;
70 int32_t acc_2 = 0;
71
72 const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
73
74 const int8_t *lhs_vec = lhs;
75 const int8_t *rhs_0 = rhs;
76 const int8_t *rhs_1 = rhs + rhs_cols;
77 const int8_t *rhs_2 = rhs + 2 * rhs_cols;
78
79 int32_t rhs_sum_0 = 0;
80 int32_t rhs_sum_1 = 0;
81 int32_t rhs_sum_2 = 0;
82
83 uint32_t col_cnt = (uint32_t)rhs_cols;
84
85 for (int i = 0; i < col_loop_cnt; i++)
86 {
87 mve_pred16_t p = vctp8q(col_cnt);
88 col_cnt -= 16;
89
90 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
91
92 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
93 rhs_sum_0 = vaddvaq_p_s8(rhs_sum_0, ker_0, p);
94 acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
95
96 const int8x16_t ker_1 = vldrbq_z_s8(rhs_1, p);
97 rhs_sum_1 = vaddvaq_p_s8(rhs_sum_1, ker_1, p);
98 acc_1 = vmladavaq_p_s8(acc_1, ker_1, input, p);
99
100 const int8x16_t ker_2 = vldrbq_z_s8(rhs_2, p);
101 rhs_sum_2 = vaddvaq_p_s8(rhs_sum_2, ker_2, p);
102 acc_2 = vmladavaq_p_s8(acc_2, ker_2, input, p);
103
104 lhs_vec += 16;
105 rhs_0 += 16;
106 rhs_1 += 16;
107 rhs_2 += 16;
108 }
109 rhs += 3 * rhs_cols;
110
111 int32x4_t acc = {acc_0, acc_1, acc_2, 0};
112 mve_pred16_t p = vctp32q(3);
113 if (bias)
114 {
115 int32x4_t b = vldrwq_z_s32(bias, p);
116 acc = vaddq_m_s32(vuninitializedq_s32(), acc, b, p);
117 bias += 3;
118 }
119 const int32x4_t rhs_sum = {rhs_sum_0, rhs_sum_1, rhs_sum_2, 0};
120 acc += vdupq_n_s32(lhs_offset) * rhs_sum;
121
122 acc = arm_requantize_mve(acc, dst_multiplier, dst_shift);
123 acc = vaddq_s32(acc, vdupq_n_s32(dst_offset));
124 acc = vmaxq_s32(acc, vdupq_n_s32(activation_min));
125 acc = vminq_s32(acc, vdupq_n_s32(activation_max));
126 vstrbq_p_s32(dst, acc, p);
127 dst += 3;
128 }
129
130 const int loop_cnt = rhs_rows % 3;
131 for (int i_row_loop_cnt = 0; i_row_loop_cnt < loop_cnt; i_row_loop_cnt++)
132 {
133 int32_t acc_0 = 0;
134 const int32_t col_loop_cnt = (rhs_cols + 15) / 16;
135 const int8_t *lhs_vec = lhs;
136 const int8_t *rhs_0 = rhs;
137 int32_t rhs_sum_0 = 0;
138 uint32_t col_cnt = (uint32_t)rhs_cols;
139
140 for (int i = 0; i < col_loop_cnt; i++)
141 {
142 mve_pred16_t p = vctp8q(col_cnt);
143 col_cnt -= 16;
144 const int8x16_t input = vldrbq_z_s8(lhs_vec, p);
145
146 const int8x16_t ker_0 = vldrbq_z_s8(rhs_0, p);
147 rhs_sum_0 = vaddvaq_p_s8(rhs_sum_0, ker_0, p);
148 acc_0 = vmladavaq_p_s8(acc_0, ker_0, input, p);
149
150 lhs_vec += 16;
151 rhs_0 += 16;
152 }
153 rhs += rhs_cols;
154
155 if (bias)
156 {
157 acc_0 += *bias;
158 bias++;
159 }
160 const int32_t offsets = rhs_sum_0 * lhs_offset;
161 acc_0 += offsets;
162 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
163 acc_0 += dst_offset;
164
165 // Clamp the result
166 acc_0 = MAX(acc_0, activation_min);
167 *dst = MIN(acc_0, activation_max);
168 dst++;
169 }
170
171 #elif defined(ARM_MATH_DSP)
172 int32_t row_loop_cnt = rhs_rows / 2;
173
174 const int16_t lhs_offset_s16 = (int16_t)lhs_offset;
175
176 const uint32_t lhs_offset_s16x2 = __PKHBT(lhs_offset_s16, lhs_offset_s16, 16);
177
178 for (int32_t i = 0; i < row_loop_cnt; i++)
179 {
180 int32_t acc_0 = 0;
181 int32_t acc_1 = 0;
182 if (bias)
183 {
184 acc_0 = *bias++;
185 acc_1 = *bias++;
186 }
187
188 const int32_t col_loop_cnt = rhs_cols / 4;
189
190 const int8_t *lhs_vec = lhs;
191 const int8_t *rhs_0 = rhs;
192 const int8_t *rhs_1 = rhs + rhs_cols;
193 rhs += 2 * rhs_cols;
194
195 for (int j = col_loop_cnt; j != 0; j--)
196 {
197 int32_t vec_0 = arm_nn_read_q7x4_ia(&lhs_vec);
198 int32_t vec_1 = __SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
199
200 vec_0 = __SXTAB16(lhs_offset_s16x2, vec_0);
201
202 int32_t ker_0 = arm_nn_read_q7x4_ia(&rhs_0);
203 int32_t ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
204 ker_0 = __SXTB16(ker_0);
205
206 acc_0 = __SMLAD(ker_1, vec_1, acc_0);
207 acc_0 = __SMLAD(ker_0, vec_0, acc_0);
208
209 ker_0 = arm_nn_read_q7x4_ia(&rhs_1);
210 ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
211 ker_0 = __SXTB16(ker_0);
212
213 acc_1 = __SMLAD(ker_1, vec_1, acc_1);
214 acc_1 = __SMLAD(ker_0, vec_0, acc_1);
215 }
216
217 for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
218 {
219 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
220 lhs_vec++;
221 acc_0 += lhs_temp * (*rhs_0);
222 rhs_0++;
223 acc_1 += lhs_temp * (*rhs_1);
224 rhs_1++;
225 }
226
227 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
228 acc_1 = arm_nn_requantize(acc_1, dst_multiplier, dst_shift);
229
230 // Add offset
231 acc_0 += dst_offset;
232 acc_1 += dst_offset;
233 // Clamp the result
234 acc_0 = MAX(acc_0, activation_min);
235 acc_0 = MIN(acc_0, activation_max);
236 acc_1 = MAX(acc_1, activation_min);
237 acc_1 = MIN(acc_1, activation_max);
238
239 *dst++ = (q7_t)acc_0;
240 *dst++ = (q7_t)acc_1;
241 }
242
243 if (rhs_rows & 0x1)
244 {
245 int32_t acc_0 = 0;
246 if (bias)
247 {
248 acc_0 = *bias++;
249 }
250 const int32_t col_loop_cnt = rhs_cols / 4;
251
252 const int8_t *lhs_vec = lhs;
253 const int8_t *rhs_0 = rhs;
254
255 for (int i = col_loop_cnt; i != 0; i--)
256 {
257 int32_t vec_0 = arm_nn_read_q7x4_ia(&lhs_vec);
258 int32_t vec_1 = __SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)vec_0, 8);
259 vec_0 = __SXTAB16(lhs_offset_s16x2, vec_0);
260
261 int32_t ker_0 = arm_nn_read_q7x4_ia(&rhs_0);
262 int32_t ker_1 = __SXTB16_RORn((uint32_t)ker_0, 8);
263 ker_0 = __SXTB16(ker_0);
264
265 acc_0 = __SMLAD(ker_1, vec_1, acc_0);
266 acc_0 = __SMLAD(ker_0, vec_0, acc_0);
267 }
268
269 for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
270 {
271 const int32_t lhs_temp = (*lhs_vec + lhs_offset);
272 lhs_vec++;
273 acc_0 += lhs_temp * (*rhs_0);
274 rhs_0++;
275 }
276
277 acc_0 = arm_nn_requantize(acc_0, dst_multiplier, dst_shift);
278
279 // Add offset
280 acc_0 += dst_offset;
281 // Clamp the result
282 acc_0 = MAX(acc_0, activation_min);
283 acc_0 = MIN(acc_0, activation_max);
284
285 *dst++ = (q7_t)acc_0;
286 }
287
288 #else
289
290 int32_t row_loop_cnt = rhs_rows / 3;
291
292 for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
293 {
294 const q7_t *lhs_ptr = lhs;
295 const q7_t *rhs_ptr_0 = &rhs[0];
296 const q7_t *rhs_ptr_1 = &rhs[rhs_cols];
297 const q7_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
298
299 q31_t res00 = 0;
300 q31_t res01 = 0;
301 q31_t res02 = 0;
302 if (bias)
303 {
304 res00 = *bias++;
305 res01 = *bias++;
306 res02 = *bias++;
307 }
308 for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
309 {
310 const q31_t rhs_value0 = (int8_t)*rhs_ptr_0;
311 const q31_t rhs_value1 = (int8_t)*rhs_ptr_1;
312 const q31_t rhs_value2 = (int8_t)*rhs_ptr_2;
313 const q31_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;
314
315 res00 += lhs_value * rhs_value0;
316 res01 += lhs_value * rhs_value1;
317 res02 += lhs_value * rhs_value2;
318
319 ++rhs_ptr_0;
320 ++rhs_ptr_1;
321 ++rhs_ptr_2;
322 ++lhs_ptr;
323 }
324 // Quantize down
325 res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
326 res01 = arm_nn_requantize(res01, dst_multiplier, dst_shift);
327 res02 = arm_nn_requantize(res02, dst_multiplier, dst_shift);
328
329 // Add offset
330 res00 += dst_offset;
331 res01 += dst_offset;
332 res02 += dst_offset;
333
334 // Clamp the result
335 res00 = MAX(res00, activation_min);
336 res00 = MIN(res00, activation_max);
337 res01 = MAX(res01, activation_min);
338 res01 = MIN(res01, activation_max);
339 res02 = MAX(res02, activation_min);
340 res02 = MIN(res02, activation_max);
341
342 *dst++ = (q7_t)res00;
343 *dst++ = (q7_t)res01;
344 *dst++ = (q7_t)res02;
345
346 rhs += 3 * rhs_cols;
347 }
348
349 const int loop_cnt = rhs_rows % 3;
350
351 for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
352 {
353 const q7_t *lhs_ptr = &lhs[0];
354 const q7_t *rhs_ptr = &rhs[0];
355
356 q31_t res00 = 0;
357 if (bias)
358 {
359 res00 = *bias++;
360 }
361
362 for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
363 {
364 q31_t rhs_value0 = (int8_t)rhs_ptr[0] + rhs_offset;
365 q31_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;
366
367 res00 += lhs_value * rhs_value0;
368
369 ++rhs_ptr;
370 ++lhs_ptr;
371 }
372
373 // Quantize down
374 res00 = arm_nn_requantize(res00, dst_multiplier, dst_shift);
375
376 // Add offset
377 res00 += dst_offset;
378
379 // Clamp the result
380 res00 = MAX(res00, activation_min);
381 res00 = MIN(res00, activation_max);
382
383 *dst++ = (q7_t)res00;
384 rhs += rhs_cols;
385 }
386 #endif
387 return ARM_MATH_SUCCESS;
388 }
389
390 /**
391 * @} end of NNBasicMath group
392 */
393