1 /*
2  * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_mat_mult_kernel_s8_s16.c
22  * Description:  Matrix-multiplication function for convolution
23  *
24  * $Date:        09. October 2020
25  * $Revision:    V.1.0.3
26  *
27  * Target Processor:  Cortex-M cores
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
/*
 * Matrix-multiplication function for convolution with per-channel requantization.
 *
 * Refer to the header file for details.
 *
 */
39 
/**
 * Matrix-multiplication kernel for convolution with per-channel
 * requantization (s8 weights, s16 inputs, s8 outputs).
 *
 * Two consecutive columns of input_b (each num_col_a elements long, the second
 * one starting at input_b + num_col_a) are multiplied with every row of
 * input_a. For each output channel c in [0, output_ch) and each column
 * n in {0, 1} the kernel computes
 *
 *   acc = output_bias[c] + sum_k input_a[c * num_col_a + k] * input_b[n * num_col_a + k]
 *
 * requantizes acc with out_mult[c] / out_shift[c], adds out_offset, clamps the
 * result to [activation_min, activation_max] and stores it narrowed to q7_t at
 * out_0[n * output_ch + c].
 *
 * @param input_a         s8 matrix, output_ch rows of num_col_a elements
 * @param input_b         s16 data, two columns of num_col_a elements each
 * @param output_ch       number of rows of input_a / outputs per column
 * @param out_shift       per-channel requantization shifts (output_ch entries)
 * @param out_mult        per-channel requantization multipliers (output_ch entries)
 * @param out_offset      offset added after requantization
 * @param activation_min  lower clamp bound
 * @param activation_max  upper clamp bound
 * @param num_col_a       number of columns of input_a
 * @param output_bias     per-channel bias (output_ch entries); dereferenced
 *                        unconditionally, so it must not be NULL
 * @param out_0           output buffer with room for 2 * output_ch values
 *
 * @return With ARM_MATH_MVEI or ARM_MATH_DSP: the initial out_0 advanced by
 *         2 * output_ch, i.e. one past the last value written.
 *         Otherwise: NULL, because no implementation is provided.
 */
q7_t *arm_nn_mat_mult_kernel_s8_s16(const q7_t *input_a,
                                    const q15_t *input_b,
                                    const uint16_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t out_offset,
                                    const int16_t activation_min,
                                    const int16_t activation_max,
                                    const uint16_t num_col_a,
                                    const int32_t *const output_bias,
                                    q7_t *out_0)
{
#if defined(ARM_MATH_MVEI)
#define ROW_PER_LOOP (4)
#define COL_PER_LOOP (8)

    const q7_t *ip_a0_s8 = input_a;
    /* Results for the second column of input_b go output_ch entries after the first */
    q7_t *out_1 = out_0 + output_ch;

    const int32_t *bias = output_bias;

    /* Main loop: four rows of input_a (four output channels) per iteration */
    int32_t row_count = output_ch / ROW_PER_LOOP;

    while (row_count)
    {
        const q15_t *ip_b0_s16 = input_b;
        const q15_t *ip_b1_s16 = input_b + num_col_a;

        /* The three rows following the current one */
        const q7_t *ip_a1_s8 = ip_a0_s8 + num_col_a;
        const q7_t *ip_a2_s8 = ip_a0_s8 + num_col_a * 2;
        const q7_t *ip_a3_s8 = ip_a0_s8 + num_col_a * 3;

        /* Accumulators for column n start from the channel bias ... */
        q31_t ch_0_out_n = bias[0];
        q31_t ch_1_out_n = bias[1];
        q31_t ch_2_out_n = bias[2];
        q31_t ch_3_out_n = bias[3];

        /* ... and so do the accumulators for column n + 1 */
        q31_t ch_0_out_n1 = ch_0_out_n;
        q31_t ch_1_out_n1 = ch_1_out_n;
        q31_t ch_2_out_n1 = ch_2_out_n;
        q31_t ch_3_out_n1 = ch_3_out_n;
        bias += 4;

        /* Vectorized part: eight columns per iteration */
        int32_t col_count = num_col_a / COL_PER_LOOP;

        while (col_count)
        {
            // Load inputs
            const int16x8_t ip_b0 = vld1q_s16(ip_b0_s16);
            ip_b0_s16 += COL_PER_LOOP;
            const int16x8_t ip_b1 = vld1q_s16(ip_b1_s16);
            ip_b1_s16 += COL_PER_LOOP;

            // Load filters (s8 values widened to s16 lanes)
            const int16x8_t ip_a0 = vldrbq_s16(ip_a0_s8);
            ip_a0_s8 += COL_PER_LOOP;
            const int16x8_t ip_a1 = vldrbq_s16(ip_a1_s8);
            ip_a1_s8 += COL_PER_LOOP;
            const int16x8_t ip_a2 = vldrbq_s16(ip_a2_s8);
            ip_a2_s8 += COL_PER_LOOP;
            const int16x8_t ip_a3 = vldrbq_s16(ip_a3_s8);
            ip_a3_s8 += COL_PER_LOOP;

            // MAC
            ch_0_out_n += vmladavq_s16(ip_b0, ip_a0);
            ch_1_out_n += vmladavq_s16(ip_b0, ip_a1);
            ch_2_out_n += vmladavq_s16(ip_b0, ip_a2);
            ch_3_out_n += vmladavq_s16(ip_b0, ip_a3);
            ch_0_out_n1 += vmladavq_s16(ip_b1, ip_a0);
            ch_1_out_n1 += vmladavq_s16(ip_b1, ip_a1);
            ch_2_out_n1 += vmladavq_s16(ip_b1, ip_a2);
            ch_3_out_n1 += vmladavq_s16(ip_b1, ip_a3);

            col_count--;
        }

        /* Handle tail: the remaining (num_col_a % 8) columns, walked from the
         * last one down to index 0. All pointers already sit just past the
         * vectorized part, so they are indexed and not advanced here. */
        col_count = (num_col_a & (COL_PER_LOOP - 1)) - 1;
        while (col_count >= 0)
        {
            const int32_t b0 = ip_b0_s16[col_count];
            const int32_t b1 = ip_b1_s16[col_count];

            ch_0_out_n += b0 * ip_a0_s8[col_count];
            ch_1_out_n += b0 * ip_a1_s8[col_count];
            ch_2_out_n += b0 * ip_a2_s8[col_count];
            ch_3_out_n += b0 * ip_a3_s8[col_count];

            ch_0_out_n1 += b1 * ip_a0_s8[col_count];
            ch_1_out_n1 += b1 * ip_a1_s8[col_count];
            ch_2_out_n1 += b1 * ip_a2_s8[col_count];
            ch_3_out_n1 += b1 * ip_a3_s8[col_count];
            col_count--;
        }
        /* Step over the tail columns so ip_a0_s8 points at the end of its row */
        ip_a0_s8 += (num_col_a & (COL_PER_LOOP - 1));

        int32x4_t out_vec_0;
        int32x4_t out_vec_1;
        out_vec_0[0] = ch_0_out_n;
        out_vec_0[1] = ch_1_out_n;
        out_vec_0[2] = ch_2_out_n;
        out_vec_0[3] = ch_3_out_n;

        out_vec_1[0] = ch_0_out_n1;
        out_vec_1[1] = ch_1_out_n1;
        out_vec_1[2] = ch_2_out_n1;
        out_vec_1[3] = ch_3_out_n1;

        /* Per-channel multiplier and shift for the four channels just computed */
        int32x4_t mult = vldrwq_s32(out_mult);
        int32x4_t shift = vldrwq_s32(out_shift);
        out_mult += ROW_PER_LOOP;
        out_shift += ROW_PER_LOOP;

        out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult, shift);
        out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult, shift);

        /* Offset, clamp and store the low byte of each 32-bit lane */
        out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
        out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
        out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
        vstrbq_s32(out_0, out_vec_0);
        out_0 += ROW_PER_LOOP;

        out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
        out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
        out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));
        vstrbq_s32(out_1, out_vec_1);
        out_1 += ROW_PER_LOOP;
        row_count--;
        /* Skip the three rows already consumed through ip_a1..ip_a3 */
        ip_a0_s8 += (num_col_a * 3);
    }

    /* Remaining 1..3 output channels that do not fill a group of four */
    row_count = output_ch & (ROW_PER_LOOP - 1);

    if (row_count)
    {
        /* First row not covered by the main loop */
        ip_a0_s8 = input_a + num_col_a * (output_ch & ~3);
        /* Predicate enabling only the first row_count 32-bit lanes */
        const mve_pred16_t p = vctp32q((uint32_t)row_count);
        int32x4_t out_vec_0 = vdupq_n_s32(0);
        int32x4_t out_vec_1 = vdupq_n_s32(0);
        /* Zero-initialized so that lanes >= row_count, which are never filled
         * below, do not feed indeterminate values into the requantization.
         * Those lanes are discarded by the predicated stores anyway. */
        int32x4_t mult_tail = vdupq_n_s32(0);
        int32x4_t shift_tail = vdupq_n_s32(0);

        for (int i_ch = 0; i_ch < row_count; i_ch++)
        {
            int32_t output_0 = bias[i_ch];
            int32_t output_1 = bias[i_ch];
            const q15_t *ip_b0_s16 = input_b;
            const q15_t *ip_b1_s16 = input_b + num_col_a;

            /* Scalar dot product of this row with both columns */
            for (int i_idx = 0; i_idx < num_col_a; i_idx++)
            {
                output_0 += ip_b0_s16[i_idx] * ip_a0_s8[i_idx];
                output_1 += ip_b1_s16[i_idx] * ip_a0_s8[i_idx];
            }

            ip_a0_s8 += num_col_a;
            out_vec_0[i_ch] = output_0;
            out_vec_1[i_ch] = output_1;
            mult_tail[i_ch] = out_mult[i_ch];
            shift_tail[i_ch] = out_shift[i_ch];
        }
        out_vec_0 = arm_requantize_mve_32x4(out_vec_0, mult_tail, shift_tail);
        out_vec_1 = arm_requantize_mve_32x4(out_vec_1, mult_tail, shift_tail);

        out_vec_0 = vaddq_n_s32(out_vec_0, out_offset);
        out_vec_0 = vmaxq_s32(out_vec_0, vdupq_n_s32(activation_min));
        out_vec_0 = vminq_s32(out_vec_0, vdupq_n_s32(activation_max));
        /* Predicated store: only row_count bytes are written */
        vstrbq_p_s32(out_0, out_vec_0, p);

        out_vec_1 = vaddq_n_s32(out_vec_1, out_offset);
        out_vec_1 = vmaxq_s32(out_vec_1, vdupq_n_s32(activation_min));
        out_vec_1 = vminq_s32(out_vec_1, vdupq_n_s32(activation_max));

        vstrbq_p_s32(out_1, out_vec_1, p);
        out_1 += row_count;
    }

    /* out_1 started at out_0 + output_ch and advanced by output_ch in total */
    return out_1;

#elif defined(ARM_MATH_DSP)
    /* set up the second output pointers */
    q7_t *out_1 = out_0 + output_ch;
    const int32_t *bias = output_bias;

    /* Two rows of A (two output channels) are processed per iteration */
    uint16_t row_count = output_ch / 2;
    const q7_t *ip_a0 = input_a;
    /* this loop over rows in A */
    while (row_count)
    {
        /* setup pointers for B */
        const q15_t *ip_b0 = input_b;
        const q15_t *ip_b1 = ip_b0 + num_col_a;

        /* align the second pointer for A */
        const q7_t *ip_a1 = ip_a0 + num_col_a;

        /* Init accumulator with bias for channel N and N + 1 */
        q31_t ch_0_out_0 = *bias;
        q31_t ch_0_out_1 = *bias++;
        q31_t ch_1_out_0 = *bias;
        q31_t ch_1_out_1 = *bias++;

        /* Four columns per iteration.
         * NOTE(review): read_and_pad is presumably consuming four s8 values and
         * returning them as two packed s16 pairs (a01 = first pair, a02 = second
         * pair) - confirm against its definition. */
        uint16_t col_count = num_col_a / 4;
        /* accumulate over the vector */
        while (col_count)
        {
            q31_t a01, a02, a11, a12;
            q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);
            ip_a1 = read_and_pad(ip_a1, &a11, &a12);

            ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = __SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = __SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = __SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = __SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        } /* while over col_count */
        /* Remaining (num_col_a % 4) columns, one at a time */
        col_count = num_col_a & 0x3;
        while (col_count)
        {
            q7_t a0 = *ip_a0++;
            q15_t b0 = *ip_b0++;
            q7_t a1 = *ip_a1++;
            q15_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            ch_1_out_0 += a1 * b0;
            ch_1_out_1 += a1 * b1;
            col_count--;
        } /* while over col_count */

        /* Channel N: both columns share the same multiplier and shift */
        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (q7_t)ch_0_out_0;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (q7_t)ch_0_out_1;
        out_mult++;
        out_shift++;

        /* Channel N + 1 */
        ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
        ch_1_out_0 += out_offset;
        ch_1_out_0 = MAX(ch_1_out_0, activation_min);
        ch_1_out_0 = MIN(ch_1_out_0, activation_max);
        *out_0++ = (q7_t)ch_1_out_0;

        ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
        ch_1_out_1 += out_offset;
        ch_1_out_1 = MAX(ch_1_out_1, activation_min);
        ch_1_out_1 = MIN(ch_1_out_1, activation_max);
        *out_1++ = (q7_t)ch_1_out_1;
        out_mult++;
        out_shift++;

        /* skip row: ip_a0 is at the end of row N, row N + 1 was read via ip_a1 */
        ip_a0 += num_col_a;
        row_count--;
    }

    /* compute the last odd numbered row if any */
    if (output_ch & 0x1)
    {
        /* setup pointers for B */
        const q15_t *ip_b0 = input_b;
        const q15_t *ip_b1 = ip_b0 + num_col_a;

        /* load the bias */
        q31_t ch_0_out_0 = *bias;
        q31_t ch_0_out_1 = *bias++;

        uint16_t col_count = num_col_a >> 2;
        while (col_count)
        {
            q31_t a01, a02;
            q31_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            q31_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);

            ch_0_out_0 = __SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = __SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = __SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = __SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        /* Remaining (num_col_a % 4) columns, one at a time */
        col_count = num_col_a & 0x3;
        while (col_count)
        {
            q7_t a0 = *ip_a0++;
            q15_t b0 = *ip_b0++;
            q15_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            col_count--;
        }
        ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
        ch_0_out_0 += out_offset;
        ch_0_out_0 = MAX(ch_0_out_0, activation_min);
        ch_0_out_0 = MIN(ch_0_out_0, activation_max);
        *out_0++ = (q7_t)ch_0_out_0;

        ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
        ch_0_out_1 += out_offset;
        ch_0_out_1 = MAX(ch_0_out_1, activation_min);
        ch_0_out_1 = MIN(ch_0_out_1, activation_max);
        *out_1++ = (q7_t)ch_0_out_1;
        out_mult++;
        out_shift++;
    }

    /* out_0 has advanced by output_ch; step over the second column's results too */
    out_0 += output_ch;

    /* return the new output pointer with offset */
    return out_0;
#else
    /* No generic C implementation: silence unused-parameter warnings */
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)num_col_a;
    (void)output_bias;
    (void)out_0;
    /* To be completed */
    return NULL;
#endif
}
392