1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_s16.c
22  * Description:  s16 version of convolution.
23  *
24  * $Date:        22 April 2024
25  * $Revision:    V.4.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup Public
36  */
37 
38 /**
39  * @addtogroup NNConv
40  * @{
41  */
42 
43 /*
44  * Basic s16 convolution function.
45  *
46  * Refer header file for details. Optimal use case for the DSP/MVE implementation is when input and output channels
47  * are multiples of 4 or atleast greater than 4.
48  *
49  */
arm_convolve_s16(const cmsis_nn_context * ctx,const cmsis_nn_conv_params * conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int16_t * input_data,const cmsis_nn_dims * filter_dims,const int8_t * filter_data,const cmsis_nn_dims * bias_dims,const cmsis_nn_bias_data * bias_data,const cmsis_nn_dims * output_dims,int16_t * output_data)50 arm_cmsis_nn_status arm_convolve_s16(const cmsis_nn_context *ctx,
51                                      const cmsis_nn_conv_params *conv_params,
52                                      const cmsis_nn_per_channel_quant_params *quant_params,
53                                      const cmsis_nn_dims *input_dims,
54                                      const int16_t *input_data,
55                                      const cmsis_nn_dims *filter_dims,
56                                      const int8_t *filter_data,
57                                      const cmsis_nn_dims *bias_dims,
58                                      const cmsis_nn_bias_data *bias_data,
59                                      const cmsis_nn_dims *output_dims,
60                                      int16_t *output_data)
61 {
62     (void)bias_dims;
63 
64     if (ctx->buf == NULL)
65     {
66         return ARM_CMSIS_NN_ARG_ERROR;
67     }
68     int16_t *buffer_a = (int16_t *)ctx->buf;
69 
70     const int32_t input_batches = input_dims->n;
71     const int32_t input_x = input_dims->w;
72     const int32_t input_y = input_dims->h;
73     const int32_t input_ch = input_dims->c;
74     const int32_t kernel_x = filter_dims->w;
75     const int32_t kernel_y = filter_dims->h;
76     const int32_t output_x = output_dims->w;
77     const int32_t output_y = output_dims->h;
78     const int32_t output_ch = output_dims->c;
79     const int32_t rhs_cols = input_ch * kernel_y * kernel_x;
80 
81     const int32_t dilation_x = conv_params->dilation.w;
82     const int32_t dilation_y = conv_params->dilation.h;
83     const int32_t pad_x = conv_params->padding.w;
84     const int32_t pad_y = conv_params->padding.h;
85     const int32_t stride_x = conv_params->stride.w;
86     const int32_t stride_y = conv_params->stride.h;
87 
88     const int32_t out_activation_min = conv_params->activation.min;
89     const int32_t out_activation_max = conv_params->activation.max;
90     int32_t *output_mult = quant_params->multiplier;
91     int32_t *output_shift = quant_params->shift;
92 
93 #if defined(ARM_MATH_MVEI)
94     const int32_t rhs_rows = output_dims->c;
95 #endif
96 
97     for (int i_batch = 0; i_batch < input_batches; i_batch++)
98     {
99         int16_t *im2col = buffer_a;
100         int16_t *out = output_data;
101 
102         int32_t lhs_rows = 0;
103 
104         /* This part implements the im2col function */
105         for (int32_t i_out_y = 0; i_out_y < output_y; i_out_y++)
106         {
107             for (int32_t i_out_x = 0; i_out_x < output_x; i_out_x++)
108             {
109                 const int32_t base_idx_x = stride_x * i_out_x - pad_x;
110                 const int32_t base_idx_y = stride_y * i_out_y - pad_y;
111 
112                 for (int32_t i_ker_y = 0; i_ker_y < kernel_y; i_ker_y++)
113                 {
114                     for (int32_t i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
115                     {
116                         const int32_t k_y = base_idx_y + dilation_y * i_ker_y;
117                         const int32_t k_x = base_idx_x + dilation_x * i_ker_x;
118 
119                         if (k_y < 0 || k_y >= input_y || k_x < 0 || k_x >= input_x)
120                         {
121                             /* Filling 0 for out-of-bound paddings */
122                             arm_memset_s8((int8_t *)im2col, 0, sizeof(int16_t) * (uint32_t)input_ch);
123                         }
124                         else
125                         {
126                             arm_memcpy_s8((int8_t *)im2col,
127                                           (const int8_t *)(input_data + (k_y * input_x + k_x) * input_ch),
128                                           (uint32_t)input_ch * sizeof(int16_t));
129                         }
130                         im2col += input_ch;
131                     }
132                 }
133 
134                 lhs_rows++;
135 #if defined(ARM_MATH_MVEI)
136                 /* Computation is filed for every 4 columns */
137                 if (lhs_rows == 4)
138                 {
139                     arm_nn_mat_mult_nt_t_s16(buffer_a,
140                                              filter_data,
141                                              bias_data,
142                                              out,
143                                              output_mult,
144                                              output_shift,
145                                              lhs_rows,
146                                              rhs_rows,
147                                              rhs_cols,
148                                              out_activation_min,
149                                              out_activation_max);
150                     out += lhs_rows * output_ch;
151 
152                     lhs_rows = 0;
153                     im2col = buffer_a;
154                 }
155 #else
156                 /* Computation is filed for every 2 columns */
157                 if (lhs_rows == 2)
158                 {
159                     out = arm_nn_mat_mult_kernel_s16(filter_data,
160                                                      buffer_a,
161                                                      output_ch,
162                                                      output_shift,
163                                                      output_mult,
164                                                      out_activation_min,
165                                                      out_activation_max,
166                                                      rhs_cols,
167                                                      bias_data,
168                                                      out);
169 
170                     /* Counter reset */
171                     im2col = buffer_a;
172                     lhs_rows = 0;
173                 }
174 #endif
175             }
176 
177             if (out == NULL)
178             {
179                 return ARM_CMSIS_NN_NO_IMPL_ERROR;
180             }
181         }
182 
183         /* Handle left over columns */
184         if (lhs_rows != 0)
185         {
186 #if defined(ARM_MATH_MVEI)
187             arm_nn_mat_mult_nt_t_s16(buffer_a,
188                                      filter_data,
189                                      bias_data,
190                                      out,
191                                      output_mult,
192                                      output_shift,
193                                      lhs_rows,
194                                      rhs_rows,
195                                      rhs_cols,
196                                      out_activation_min,
197                                      out_activation_max);
198             out += lhs_rows * rhs_rows;
199             lhs_rows = 0;
200             im2col = buffer_a;
201 #else // #if defined(ARM_MATH_MVEI)
202 
203             const int64_t *bias_s64 = (const int64_t *)bias_data->data;
204             const int32_t *bias_s32 = (const int32_t *)bias_data->data;
205             const bool is_int32_bias = bias_data->is_int32_bias;
206             const int8_t *ker_a = filter_data;
207             int i;
208 
209             for (i = 0; i < output_ch; i++)
210             {
211                 /* Init the accumulator*/
212                 int32_t sum = 0;
213 
214                 /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
215                 const int16_t *ip_as_col = buffer_a;
216 
217     #if defined(ARM_MATH_DSP)
218                 /* 4 multiply and accumulates are done in one loop. */
219                 int32_t col_count = rhs_cols >> 2;
220 
221                 while (col_count)
222                 {
223                     int32_t ker_a1, ker_a2;
224                     int32_t ip_b1, ip_b2;
225 
226                     ker_a = read_and_pad(ker_a, &ker_a1, &ker_a2);
227 
228                     ip_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
229                     sum = SMLAD(ker_a1, ip_b1, sum);
230                     ip_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
231                     sum = SMLAD(ker_a2, ip_b2, sum);
232 
233                     col_count--;
234                 }
235                 /* Handle left over mac */
236                 col_count = rhs_cols & 0x3;
237     #else
238                 uint16_t col_count = rhs_cols;
239 
240     #endif
241 
242                 while (col_count)
243                 {
244                     int8_t ker_a1 = *ker_a++;
245                     int16_t ip_b1 = *ip_as_col++;
246                     sum += ker_a1 * ip_b1;
247                     col_count--;
248                 }
249 
250                 if (is_int32_bias)
251                 {
252                     if (bias_s32)
253                     {
254                         sum += bias_s32[i];
255                     }
256 
257                     sum = arm_nn_requantize(sum, output_mult[i], output_shift[i]);
258                 }
259                 else
260                 {
261                     int64_t acc_64 = sum;
262 
263                     if (bias_s64)
264                     {
265                         acc_64 += bias_s64[i];
266                     }
267 
268                     int32_t reduced_multiplier = REDUCE_MULTIPLIER(output_mult[i]);
269                     sum = arm_nn_requantize_s64(acc_64, reduced_multiplier, output_shift[i]);
270                 }
271 
272                 sum = MAX(sum, out_activation_min);
273                 sum = MIN(sum, out_activation_max);
274                 *out++ = (int16_t)sum;
275             }
276             lhs_rows = 0;
277 
278 #endif // #if defined(ARM_MATH_MVEI)
279         }
280 
281         /* Advance to the next batch */
282         input_data += (input_x * input_y * input_ch);
283         output_data += (output_x * output_y * output_ch);
284     }
285 
286     /* Return to application */
287     return ARM_CMSIS_NN_SUCCESS;
288 }
289 
290 /**
291  * @} end of NNConv group
292  */
293