/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates
 * <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_kernel_s16.c
 * Description:  Matrix-multiplication function for 16-bit convolution
 *
 * $Date:        12 April 2024
 * $Revision:    V.3.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * Matrix-multiplication function for convolution with per-channel requantization.
 *
 * Refer to the header file for details.
 *
 */
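/*
 * input_a holds the filter weights (output_ch rows of num_col_a int8 values) and
 * input_b holds two adjacent columns of int16 input data (im2col buffer), so each
 * call produces two output pixels per output channel: the first written through
 * out_0 and the second through out_0 + output_ch.
 */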
int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                    const int16_t *input_b,
                                    const int32_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const int32_t num_col_a,
                                    const cmsis_nn_bias_data *const bias_data,
                                    int16_t *out_0)
{
#if !defined(ARM_MATH_MVEI)
    const int64_t *bias_s64 = (const int64_t *)bias_data->data;
    const int32_t *bias_s32 = (const int32_t *)bias_data->data;
    const bool is_int32_bias = bias_data->is_int32_bias;

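    /* When the bias is int64, int32 accumulation is limited to MAX_COL_COUNT
       columns; any columns beyond that are accumulated in 64-bit further down
       to avoid overflowing the 32-bit accumulators. */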
    const int32_t num_col_a_fast = is_int32_bias ? num_col_a : (num_col_a > MAX_COL_COUNT ? MAX_COL_COUNT : num_col_a);
    const int32_t num_col_a_slow = num_col_a - MAX_COL_COUNT;

    int16_t *out_1 = out_0 + output_ch;
    int32_t row_count = output_ch / 2;
    const int8_t *ip_a0 = input_a;

    /* Loop over the rows of A, two output channels at a time */
    while (row_count)
    {
        /* Setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        /* Align the second pointer for A */
        const int8_t *ip_a1 = ip_a0 + num_col_a;

        /* Init accumulator for channel N and N + 1 */
        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;
        int32_t ch_1_out_0 = 0;
        int32_t ch_1_out_1 = 0;

    #if defined(ARM_MATH_DSP)
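        /* DSP extension path: process four columns per iteration. read_and_pad()
           sign-extends four int8 weights into two packed 16x2 words and SMLAD
           performs a dual 16-bit multiply-accumulate per call. */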
        uint16_t col_count = num_col_a_fast / 4;

        /* Accumulate over the vector */
        while (col_count)
        {
            int32_t a01, a02, a11, a12;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);
            ip_a1 = read_and_pad(ip_a1, &a11, &a12);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a11, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a11, b1, ch_1_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);
            ch_1_out_0 = SMLAD(a12, b0, ch_1_out_0);
            ch_1_out_1 = SMLAD(a12, b1, ch_1_out_1);

            col_count--;
        }
        col_count = num_col_a_fast & 0x3;
    #else
        int32_t col_count = num_col_a_fast;
    #endif

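        /* Handle the columns left over by the SIMD loop (or all columns when the
           DSP extension is not available) one element at a time. */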
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int8_t a1 = *ip_a1++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            ch_1_out_0 += a1 * b0;
            ch_1_out_1 += a1 * b1;
            col_count--;
        }

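        /* Requantize and clamp. With an int32 bias the 32-bit accumulators are used
           directly; with an int64 bias the accumulators are promoted to 64-bit, any
           remaining columns are accumulated, the bias is added and the result is
           requantized with a reduced multiplier. */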
        if (is_int32_bias)
        {
            if (bias_s32)
            {
                ch_0_out_0 += *bias_s32;
                ch_0_out_1 += *bias_s32++;
                ch_1_out_0 += *bias_s32;
                ch_1_out_1 += *bias_s32++;
            }

            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
            out_mult++;
            out_shift++;

            ch_0_out_0 = MAX(ch_0_out_0, activation_min);
            ch_0_out_0 = MIN(ch_0_out_0, activation_max);
            *out_0++ = (int16_t)ch_0_out_0;

            ch_0_out_1 = MAX(ch_0_out_1, activation_min);
            ch_0_out_1 = MIN(ch_0_out_1, activation_max);
            *out_1++ = (int16_t)ch_0_out_1;

            ch_1_out_0 = arm_nn_requantize(ch_1_out_0, *out_mult, *out_shift);
            ch_1_out_1 = arm_nn_requantize(ch_1_out_1, *out_mult, *out_shift);
            out_mult++;
            out_shift++;

            ch_1_out_0 = MAX(ch_1_out_0, activation_min);
            ch_1_out_0 = MIN(ch_1_out_0, activation_max);
            *out_0++ = (int16_t)ch_1_out_0;

            ch_1_out_1 = MAX(ch_1_out_1, activation_min);
            ch_1_out_1 = MIN(ch_1_out_1, activation_max);
            *out_1++ = (int16_t)ch_1_out_1;
        }
        else
        {
            int64_t ch_0_out_0_s64 = ch_0_out_0;
            int64_t ch_0_out_1_s64 = ch_0_out_1;
            int64_t ch_1_out_0_s64 = ch_1_out_0;
            int64_t ch_1_out_1_s64 = ch_1_out_1;

            if (num_col_a > MAX_COL_COUNT)
            {
                col_count = num_col_a_slow;
                while (col_count)
                {
                    int8_t a0 = *ip_a0++;
                    int16_t b0 = *ip_b0++;
                    int8_t a1 = *ip_a1++;
                    int16_t b1 = *ip_b1++;

                    ch_0_out_0_s64 += a0 * b0;
                    ch_0_out_1_s64 += a0 * b1;
                    ch_1_out_0_s64 += a1 * b0;
                    ch_1_out_1_s64 += a1 * b1;
                    col_count--;
                }
            }

            if (bias_s64)
            {
                ch_0_out_0_s64 += *bias_s64;
                ch_0_out_1_s64 += *bias_s64++;
                ch_1_out_0_s64 += *bias_s64;
                ch_1_out_1_s64 += *bias_s64++;
            }

            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            ch_0_out_0 = arm_nn_requantize_s64(ch_0_out_0_s64, reduced_multiplier, *out_shift);
            ch_0_out_1 = arm_nn_requantize_s64(ch_0_out_1_s64, reduced_multiplier, *out_shift);
            out_mult++;
            out_shift++;

            reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            ch_1_out_0 = arm_nn_requantize_s64(ch_1_out_0_s64, reduced_multiplier, *out_shift);
            ch_1_out_1 = arm_nn_requantize_s64(ch_1_out_1_s64, reduced_multiplier, *out_shift);

            ch_0_out_0 = MAX(ch_0_out_0, activation_min);
            ch_0_out_0 = MIN(ch_0_out_0, activation_max);
            *out_0++ = (int16_t)ch_0_out_0;

            ch_0_out_1 = MAX(ch_0_out_1, activation_min);
            ch_0_out_1 = MIN(ch_0_out_1, activation_max);
            *out_1++ = (int16_t)ch_0_out_1;

            ch_1_out_0 = MAX(ch_1_out_0, activation_min);
            ch_1_out_0 = MIN(ch_1_out_0, activation_max);
            *out_0++ = (int16_t)ch_1_out_0;

            ch_1_out_1 = MAX(ch_1_out_1, activation_min);
            ch_1_out_1 = MIN(ch_1_out_1, activation_max);
            *out_1++ = (int16_t)ch_1_out_1;

            out_mult++;
            out_shift++;
        }

        /* Skip the second row of A, which was already consumed through ip_a1 */
        ip_a0 += num_col_a;
        row_count--;
    }

    /* Compute the last row if the number of output channels is odd */
    if (output_ch & 0x1)
    {
        /* Setup pointers for B */
        const int16_t *ip_b0 = input_b;
        const int16_t *ip_b1 = ip_b0 + num_col_a;

        int32_t ch_0_out_0 = 0;
        int32_t ch_0_out_1 = 0;

    #if defined(ARM_MATH_DSP)
        uint16_t col_count = num_col_a_fast >> 2;
        while (col_count)
        {
            int32_t a01, a02;
            int32_t b0 = arm_nn_read_q15x2_ia(&ip_b0);
            int32_t b1 = arm_nn_read_q15x2_ia(&ip_b1);

            ip_a0 = read_and_pad(ip_a0, &a01, &a02);

            ch_0_out_0 = SMLAD(a01, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a01, b1, ch_0_out_1);

            b0 = arm_nn_read_q15x2_ia(&ip_b0);
            b1 = arm_nn_read_q15x2_ia(&ip_b1);
            ch_0_out_0 = SMLAD(a02, b0, ch_0_out_0);
            ch_0_out_1 = SMLAD(a02, b1, ch_0_out_1);

            col_count--;
        }
        col_count = num_col_a_fast & 0x3;
    #else
        int32_t col_count = num_col_a_fast;
    #endif
        while (col_count)
        {
            int8_t a0 = *ip_a0++;
            int16_t b0 = *ip_b0++;
            int16_t b1 = *ip_b1++;

            ch_0_out_0 += a0 * b0;
            ch_0_out_1 += a0 * b1;
            col_count--;
        }

        if (is_int32_bias)
        {
            if (bias_s32)
            {
                ch_0_out_0 += *bias_s32;
                ch_0_out_1 += *bias_s32++;
            }

            ch_0_out_0 = arm_nn_requantize(ch_0_out_0, *out_mult, *out_shift);
            ch_0_out_1 = arm_nn_requantize(ch_0_out_1, *out_mult, *out_shift);
            out_mult++;
            out_shift++;

            ch_0_out_0 = MAX(ch_0_out_0, activation_min);
            ch_0_out_0 = MIN(ch_0_out_0, activation_max);
            *out_0++ = (int16_t)ch_0_out_0;

            ch_0_out_1 = MAX(ch_0_out_1, activation_min);
            ch_0_out_1 = MIN(ch_0_out_1, activation_max);
            *out_1++ = (int16_t)ch_0_out_1;
        }
        else
        {
            int64_t ch_0_out_0_s64 = ch_0_out_0;
            int64_t ch_0_out_1_s64 = ch_0_out_1;

            if (num_col_a > MAX_COL_COUNT)
            {
                col_count = num_col_a_slow;
                while (col_count)
                {
                    int8_t a0 = *ip_a0++;
                    int16_t b0 = *ip_b0++;
                    int16_t b1 = *ip_b1++;

                    ch_0_out_0_s64 += a0 * b0;
                    ch_0_out_1_s64 += a0 * b1;
                    col_count--;
                }
            }

            if (bias_s64)
            {
                ch_0_out_0_s64 += *bias_s64;
                ch_0_out_1_s64 += *bias_s64++;
            }

            int32_t reduced_multiplier = REDUCE_MULTIPLIER(*out_mult);
            ch_0_out_0 = arm_nn_requantize_s64(ch_0_out_0_s64, reduced_multiplier, *out_shift);
            ch_0_out_1 = arm_nn_requantize_s64(ch_0_out_1_s64, reduced_multiplier, *out_shift);

            ch_0_out_0 = MAX(ch_0_out_0, activation_min);
            ch_0_out_0 = MIN(ch_0_out_0, activation_max);
            *out_0++ = (int16_t)ch_0_out_0;

            ch_0_out_1 = MAX(ch_0_out_1, activation_min);
            ch_0_out_1 = MIN(ch_0_out_1, activation_max);
            *out_1++ = (int16_t)ch_0_out_1;
            out_mult++;
            out_shift++;
        }
    }

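    /* out_0 has already advanced past the first output pixel's channels; add
       output_ch once more to also skip the second pixel written through out_1. */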
    out_0 += output_ch;

    /* Return the new output pointer with offset */
    return out_0;
#else
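    /* This kernel is not implemented for MVE targets; the arguments are unused
       and NULL is returned. */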
    (void)input_a;
    (void)input_b;
    (void)output_ch;
    (void)out_shift;
    (void)out_mult;
    (void)activation_min;
    (void)activation_max;
    (void)num_col_a;
    (void)bias_data;
    (void)out_0;

    return NULL;
#endif
}

/**
 * @} end of Doxygen group
 */