/*
 * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18
/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s8_s16.c
 * Description:  Matrix-multiplication function for convolution
 *
 * $Date:        09. October 2020
 * $Revision:    V.1.0.3
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */
29
#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"
32
/*
 * Matrix-multiplication function for convolution with per-channel requantization.
 *
 * Refer to the header file for details.
 *
 */
39
arm_nn_mat_mult_kernel_s8_s16(const q7_t * input_a,const q15_t * input_b,const uint16_t output_ch,const int32_t * out_shift,const int32_t * out_mult,const int32_t out_offset,const int16_t activation_min,const int16_t activation_max,const uint16_t num_col_a,const int32_t * const output_bias,q7_t * out_0)40 q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
41 const q15_t *input_b,
42 const uint16_t output_ch,
43 const int32_t *out_shift,
44 const int32_t *out_mult,
45 const int32_t out_offset,
46 const int16_t activation_min,
47 const int16_t activation_max,
48 const uint16_t num_col_a,
49 const int32_t *const output_bias,
50 q7_t *out_0)
51 {
52 #if defined(ARM_MATH_MVEI)
53 #define ROW_PER_LOOP (4)
54 #define COL_PER_LOOP (8)
55
56 const q7_t *ip_a0_s8 = input_a;
57 q7_t *out_1 = out_0 + output_ch;
58
59 const int32_t *bias = output_bias;
60
61 int32_t row_count = output_ch / ROW_PER_LOOP;
62
63 while (row_count)
64 {
65 const q15_t *ip_b0_s16 = input_b;
66 const q15_t *ip_b1_s16 = input_b + num_col_a;
67
68 const q7_t *ip_a1_s8 = ip_a0_s8 + num_col_a;
69 const q7_t *ip_a2_s8 = ip_a0_s8 + num_col_a * 2;
70 const q7_t *ip_a3_s8 = ip_a0_s8 + num_col_a * 3;
71
72 q31_t ch_0_out_n = bias[0];
73 q31_t ch_1_out_n = bias[1];
74 q31_t ch_2_out_n = bias[2];
75 q31_t ch_3_out_n = bias[3];
76
77 q31_t ch_0_out_n1 = ch_0_out_n;
78 q31_t ch_1_out_n1 = ch_1_out_n;
79 q31_t ch_2_out_n1 = ch_2_out_n;
80 q31_t ch_3_out_n1 = ch_3_out_n;
81 bias += 4;
82
83 int32_t col_count = num_col_a / COL_PER_LOOP;
84
85 while (col_count)
86 {
87 // Load inputs
88 const int16x8_t ip_b0 = vld1q_s16(ip_b0_s16);
89 ip_b0_s16 += COL_PER_LOOP;
90 const int16x8_t ip_b1 = vld1q_s16(ip_b1_s16);
91 ip_b1_s16 += COL_PER_LOOP;
92
93 // Load filters
94 const int16x8_t ip_a0 = vldrbq_s16(ip_a0_s8);
95 ip_a0_s8 += COL_PER_LOOP;
96 const int16x8_t ip_a1 = vldrbq_s16(ip_a1_s8);
97 ip_a1_s8 += COL_PER_LOOP;
98 const int16x8_t ip_a2 = vldrbq_s16(ip_a2_s8);
99 ip_a2_s8 += COL_PER_LOOP;
100 const int16x8_t ip_a3 = vldrbq_s16(ip_a3_s8);
101 ip_a3_s8 += COL_PER_LOOP;
102
103 // MAC
104 ch_0_out_n += vmladavq_s16(ip_b0, ip_a0);
105 ch_1_out_n += vmladavq_s16(ip_b0, ip_a1);
106 ch_2_out_n += vmladavq_s16(ip_b0, ip_a2);
107 ch_3_out_n += vmladavq_s16(ip_b0, ip_a3);
108 ch_0_out_n1 += vmladavq_s16(ip_b1, ip_a0);
109 ch_1_out_n1 += vmladavq_s16(ip_b1, ip_a1);
110 ch_2_out_n1 += vmladavq_s16(ip_b1, ip_a2);
111 ch_3_out_n1 += vmladavq_s16(ip_b1, ip_a3);
112
113 col_count--;
114 }
115
116 /* Handle tail */
117 col_count = (num_col_a & (COL_PER_LOOP - 1)) - 1;
118 while (col_count >= 0)
119 {
120 const int32_t b0 = ip_b0_s16[col_count];
121 const int32_t b1 = ip_b1_s16[col_count];
122
123 ch_0_out_n += b0 * ip_a0_s8[col_count];
124 ch_1_out_n += b0 * ip_a1_s8[col_count];
125 ch_2_out_n += b0 * ip_a2_s8[col_count];
126 ch_3_out_n += b0 * ip_a3_s8[col_count];
127
128 ch_0_out_n1 += b1 * ip_a0_s8[col_count];
129 ch_1_out_n1 += b1 * ip_a1_s8[col_count];
130 ch_2_out_n1 += b1 * ip_a2_s8[col_count];
131 ch_3_out_n1 += b1 * ip_a3_s8[col_count];
132 col_count--;
133 }
134 ip_a0_s8 += (num_col_a & (COL_PER_LOOP - 1));
135
136 int32x4_t out_vec_0;
137 int32x4_t out_vec_1;
138 out_vec_0[0] = ch_0_out_n;
139 out_vec_0[1] = ch_1_out_n;
140 out_vec_0[2] = ch_2_out_n;
141 out_vec_0[3] = ch_3_out_n;
142
143 out_vec_1[0] = ch_0_out_n1;
144 out_vec_1[1] = ch_1_out_n1;
145 out_vec_1[2] = ch_2_out_n1;
146 out_vec_1[3] = ch_3_out_n1;
147
148 int32x4_t mult = vldrwq_s32(out_mult);
149 int32x4_t shift = vldrwq_s32(out_shift);
150 out_mult += ROW_PER_LOOP;
151 out_shift += ROW_PER_LOOP;
152
153 out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult, shift);
154 out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult, shift);
155
156 out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
157 out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
158 out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
159 vstrbq_s32(out_0, out_vec_0);
160 out_0 += ROW_PER_LOOP;
161
162 out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
163 out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
164 out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
165 vstrbq_s32(out_1, out_vec_1);
166 out_1 += ROW_PER_LOOP;
167 row_count--;
168 ip_a0_s8 += (num_col_a * 3);
169 }
170
171 row_count = output_ch & (ROW_PER_LOOP - 1);
172
173 if (row_count)
174 {
175 ip_a0_s8 = input_a + num_col_a * (output_ch & ~3);
176 const mve_pred16_t p = vctp32q((uint32_t)row_count);
177 int32x4_t out_vec_0 = vdupq_n_s32(0);
178 int32x4_t out_vec_1 = vdupq_n_s32(0);
179 int32x4_t mult_tail;
180 int32x4_t shift_tail;
181
182 for (int i_ch = 0; i_ch < row_count; i_ch++)
183 {
184 int32_t output_0 = bias[i_ch];
185 int32_t output_1 = bias[i_ch];
186 const q15_t *ip_b0_s16 = input_b;
187 const q15_t *ip_b1_s16 = input_b + num_col_a;
188
189 for (int i_idx = 0; i_idx < num_col_a; i_idx++)
190 {
191 output_0 += ip_b0_s16[i_idx] * ip_a0_s8[i_idx];
192 output_1 += ip_b1_s16[i_idx] * ip_a0_s8[i_idx];
193 }
194
195 ip_a0_s8 += num_col_a;
196 out_vec_0[i_ch] = output_0;
197 out_vec_1[i_ch] = output_1;
198 mult_tail[i_ch] = out_mult[i_ch];
199 shift_tail[i_ch] = out_shift[i_ch];
200 }
201 out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult_tail, shift_tail);
202 out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult_tail, shift_tail);
203
204 out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
205 out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
206 out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
207 vstrbq_p_s32(out_0, out_vec_0, p);
208
209 out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
210 out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
211 out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
212
213 vstrbq_p_s32(out_1, out_vec_1, p);
214 out_1 += row_count;
215 }
216
217 return out_1;
218
219 #elif defined(ARM_MATH_DSP)
220 /* set up the second output pointers */
221 q7_t *out_1 = out_0 + output_ch;
222 const int32_t *bias = output_bias;
223
224 uint16_t row_count = output_ch / 2;
225 const q7_t *ip_a0 = input_a;
226 /* this loop over rows in A */
227 while (row_count)
228 {
229 /* setup pointers for B */
230 const q15_t *ip_b0 = input_b;
231 const q15_t *ip_b1 = ip_b0 + num_col_a;
232
233 /* align the second pointer for A */
234 const q7_t *ip_a1 = ip_a0 + num_col_a;
235
236 /* Init accumulator with bias for channel N and N + 1 */
237 q31_t ch_0_out_0 = *bias;
238 q31_t ch_0_out_1 = *bias++;
239 q31_t ch_1_out_0 = *bias;
240 q31_t ch_1_out_1 = *bias++;
241
242 uint16_t col_count = num_col_a / 4;
243 /* accumulate over the vector */
244 while (col_count)
245 {
246 q31_t a01, a02, a11, a12;
247 q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
248 q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
249
250 ip_a0 = read_and_pad(ip_a0, &a01, &a02);
251 ip_a1 = read_and_pad(ip_a1, &a11, &a12);
252
253 ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
254 ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
255 ch_1_out_0 = __SMLAD(a11, b0, ch_1_out_0);
256 ch_1_out_1 = __SMLAD(a11, b1, ch_1_out_1);
257
258 b0 = arm_nn_read_q15x2_ia(&ip_b0);
259 b1 = arm_nn_read_q15x2_ia(&ip_b1);
260
261 ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
262 ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
263 ch_1_out_0 = __SMLAD(a12, b0, ch_1_out_0);
264 ch_1_out_1 = __SMLAD(a12, b1, ch_1_out_1);
265
266 col_count--;
267 } /* while over col_count */
268 col_count = num_col_a & 0x3;
269 while (col_count)
270 {
271 q7_t a0 = *ip_a0++;
272 q15_t b0 = *ip_b0++;
273 q7_t a1 = *ip_a1++;
274 q15_t b1 = *ip_b1++;
275
276 ch_0_out_0 += a0 * b0;
277 ch_0_out_1 += a0 * b1;
278 ch_1_out_0 += a1 * b0;
279 ch_1_out_1 += a1 * b1;
280 col_count--;
281 } /* while over col_count */
282
283 ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
284 ch_0_out_0 += out_offset;
285 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
286 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
287 *out_0++ = (q7_t)ch_0_out_0;
288
289 ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
290 ch_0_out_1 += out_offset;
291 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
292 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
293 *out_1++ = (q7_t)ch_0_out_1;
294 out_mult++;
295 out_shift++;
296
297 ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
298 ch_1_out_0 += out_offset;
299 ch_1_out_0 = MAX(ch_1_out_0, activation_min);
300 ch_1_out_0 = MIN(ch_1_out_0, activation_max);
301 *out_0++ = (q7_t)ch_1_out_0;
302
303 ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
304 ch_1_out_1 += out_offset;
305 ch_1_out_1 = MAX(ch_1_out_1, activation_min);
306 ch_1_out_1 = MIN(ch_1_out_1, activation_max);
307 *out_1++ = (q7_t)ch_1_out_1;
308 out_mult++;
309 out_shift++;
310
311 /* skip row */
312 ip_a0 += num_col_a;
313 row_count--;
314 }
315
316 /* compute the last odd numbered row if any */
317 if (output_ch & 0x1)
318 {
319 /* setup pointers for B */
320 const q15_t *ip_b0 = input_b;
321 const q15_t *ip_b1 = ip_b0 + num_col_a;
322
323 /* load the bias */
324 q31_t ch_0_out_0 = *bias;
325 q31_t ch_0_out_1 = *bias++;
326
327 uint16_t col_count = num_col_a >> 2;
328 while (col_count)
329 {
330 q31_t a01, a02;
331 q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
332 q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);
333
334 ip_a0 = read_and_pad(ip_a0, &a01, &a02);
335
336 ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
337 ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
338
339 b0 = arm_nn_read_q15x2_ia(&ip_b0);
340 b1 = arm_nn_read_q15x2_ia(&ip_b1);
341 ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
342 ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
343
344 col_count--;
345 }
346 col_count = num_col_a & 0x3;
347 while (col_count)
348 {
349 q7_t a0 = *ip_a0++;
350 q15_t b0 = *ip_b0++;
351 q15_t b1 = *ip_b1++;
352
353 ch_0_out_0 += a0 * b0;
354 ch_0_out_1 += a0 * b1;
355 col_count--;
356 }
357 ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
358 ch_0_out_0 += out_offset;
359 ch_0_out_0 = MAX(ch_0_out_0, activation_min);
360 ch_0_out_0 = MIN(ch_0_out_0, activation_max);
361 *out_0++ = (q7_t)ch_0_out_0;
362
363 ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
364 ch_0_out_1 += out_offset;
365 ch_0_out_1 = MAX(ch_0_out_1, activation_min);
366 ch_0_out_1 = MIN(ch_0_out_1, activation_max);
367 *out_1++ = (q7_t)ch_0_out_1;
368 out_mult++;
369 out_shift++;
370 }
371
372 out_0 += output_ch;
373
374 /* return the new output pointer with offset */
375 return out_0;
376 #else
377 (void)input_a;
378 (void)input_b;
379 (void)output_ch;
380 (void)out_shift;
381 (void)out_mult;
382 (void)out_offset;
383 (void)activation_min;
384 (void)activation_max;
385 (void)num_col_a;
386 (void)output_bias;
387 (void)out_0;
388 /* To be completed */
389 return NULL;
390 #endif
391 }
392