/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_convolve_s16.c
 * Description:  s16 version of convolution.
 *
 * $Date:        22 April 2024
 * $Revision:    V.4.0.0
 *
 * Target :      Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */

/*
 * Basic s16 convolution function.
 *
 * Refer to the header file for details. The optimal use case for the DSP/MVE implementation is when input and output
 * channels are multiples of 4, or at least greater than 4.
 *
 */
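/*
 * Illustrative call sequence (a minimal sketch, not part of the library). It assumes the caller has already
 * populated the cmsis_nn_* structs and sizes the scratch buffer with the matching helper,
 * arm_convolve_s16_get_buffer_size(), declared in arm_nnfunctions.h. Names such as in_dims, filt_dims, input,
 * filter and bias are placeholders.
 *
 *     cmsis_nn_context ctx;
 *     ctx.size = arm_convolve_s16_get_buffer_size(&in_dims, &filt_dims);
 *     ctx.buf = malloc(ctx.size); // must not be NULL, see the check below
 *
 *     arm_cmsis_nn_status status = arm_convolve_s16(&ctx, &conv_params, &quant_params,
 *                                                   &in_dims, input, &filt_dims, filter,
 *                                                   &bias_dims, &bias, &out_dims, output);
 *     free(ctx.buf);
 *
 *     if (status != ARM_CMSIS_NN_SUCCESS) { ... handle the error ... }
 */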
arm_cmsis_nn_status arm_convolve_s16(const cmsis_nn_context *ctx,
                                     const cmsis_nn_conv_params *conv_params,
                                     const cmsis_nn_per_channel_quant_params *quant_params,
                                     const cmsis_nn_dims *input_dims,
                                     const int16_t *input_data,
                                     const cmsis_nn_dims *filter_dims,
                                     const int8_t *filter_data,
                                     const cmsis_nn_dims *bias_dims,
                                     const cmsis_nn_bias_data *bias_data,
                                     const cmsis_nn_dims *output_dims,
                                     int16_t *output_data)
{
    (void)bias_dims;

    if (ctx->buf == NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }
    int16_t *buffer_a = (int16_t *)ctx->buf;

    const int32_t input_batches = input_dims->n;
    const int32_t input_x = input_dims->w;
    const int32_t input_y = input_dims->h;
    const int32_t input_ch = input_dims->c;
    const int32_t kernel_x = filter_dims->w;
    const int32_t kernel_y = filter_dims->h;
    const int32_t output_x = output_dims->w;
    const int32_t output_y = output_dims->h;
    const int32_t output_ch = output_dims->c;
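    /* Length of one flattened kernel patch gathered by im2col: kernel_y * kernel_x taps, input_ch values per tap */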
    const int32_t rhs_cols = input_ch * kernel_y * kernel_x;

    const int32_t dilation_x = conv_params->dilation.w;
    const int32_t dilation_y = conv_params->dilation.h;
    const int32_t pad_x = conv_params->padding.w;
    const int32_t pad_y = conv_params->padding.h;
    const int32_t stride_x = conv_params->stride.w;
    const int32_t stride_y = conv_params->stride.h;

    const int32_t out_activation_min = conv_params->activation.min;
    const int32_t out_activation_max = conv_params->activation.max;
    int32_t *output_mult = quant_params->multiplier;
    int32_t *output_shift = quant_params->shift;

#if defined(ARM_MATH_MVEI)
    const int32_t rhs_rows = output_dims->c;
#endif

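    /* Each batch is processed by gathering kernel patches into the scratch (im2col) buffer and multiplying them
     * against the filter matrix, a few output pixels at a time. */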
    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        int16_t *im2col = buffer_a;
        int16_t *out = output_data;

        int32_t lhs_rows = 0;

        /* This part implements the im2col function */
        for (int32_t i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            for (int32_t i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int32_t base_idx_x = stride_x * i_out_x - pad_x;
                const int32_t base_idx_y = stride_y * i_out_y - pad_y;

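                /* Gather one kernel patch for this output pixel; out-of-image taps are zero-padded */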
                for (int32_t i_ker_y = 0; i_ker_y < kernel_y; i_ker_y++)
                {
                    for (int32_t i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
                    {
                        const int32_t k_y = base_idx_y + dilation_y * i_ker_y;
                        const int32_t k_x = base_idx_x + dilation_x * i_ker_x;

                        if (k_y < 0 || k_y >= input_y || k_x < 0 || k_x >= input_x)
                        {
                            /* Fill with 0 for out-of-bounds padding */
                            arm_memset_s8((int8_t *)im2col, 0, sizeof(int16_t) * (uint32_t)input_ch);
                        }
                        else
                        {
                            arm_memcpy_s8((int8_t *)im2col,
                                          (const int8_t *)(input_data + (k_y * input_x + k_x) * input_ch),
                                          (uint32_t)input_ch * sizeof(int16_t));
                        }
                        im2col += input_ch;
                    }
                }

                lhs_rows++;
#if defined(ARM_MATH_MVEI)
                /* The matrix multiplication is run once 4 im2col columns (output pixels) have been gathered */
                if (lhs_rows == 4)
                {
                    arm_nn_mat_mult_nt_t_s16(buffer_a,
                                             filter_data,
                                             bias_data,
                                             out,
                                             output_mult,
                                             output_shift,
                                             lhs_rows,
                                             rhs_rows,
                                             rhs_cols,
                                             out_activation_min,
                                             out_activation_max);
                    out += lhs_rows * output_ch;

                    lhs_rows = 0;
                    im2col = buffer_a;
                }
#else
                /* The matrix multiplication is run once 2 im2col columns (output pixels) have been gathered */
                if (lhs_rows == 2)
                {
                    out = arm_nn_mat_mult_kernel_s16(filter_data,
                                                     buffer_a,
                                                     output_ch,
                                                     output_shift,
                                                     output_mult,
                                                     out_activation_min,
                                                     out_activation_max,
                                                     rhs_cols,
                                                     bias_data,
                                                     out);

                    /* Counter reset */
                    im2col = buffer_a;
                    lhs_rows = 0;
                }
#endif
            }

            if (out == NULL)
            {
                return ARM_CMSIS_NN_NO_IMPL_ERROR;
            }
        }

        /* Handle left over columns */
        if (lhs_rows != 0)
        {
#if defined(ARM_MATH_MVEI)
            arm_nn_mat_mult_nt_t_s16(buffer_a,
                                     filter_data,
                                     bias_data,
                                     out,
                                     output_mult,
                                     output_shift,
                                     lhs_rows,
                                     rhs_rows,
                                     rhs_cols,
                                     out_activation_min,
                                     out_activation_max);
            out += lhs_rows * rhs_rows;
            lhs_rows = 0;
            im2col = buffer_a;
#else // #if defined(ARM_MATH_MVEI)

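            /* Scalar/DSP fallback: at most one im2col column can be left over here (the kernel above flushes every
             * two), so compute that single output pixel directly, one output channel at a time. */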
            const int64_t *bias_s64 = (const int64_t *)bias_data->data;
            const int32_t *bias_s32 = (const int32_t *)bias_data->data;
            const bool is_int32_bias = bias_data->is_int32_bias;
            const int8_t *ker_a = filter_data;
            int i;

            for (i = 0; i < output_ch; i++)
            {
                /* Init the accumulator */
                int32_t sum = 0;

                /* Point to the beginning of the im2col buffer where the input is available as a rearranged column */
                const int16_t *ip_as_col = buffer_a;

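                /* With DSP extensions the s8 kernel values are widened to s16 and SMLAD performs two 16-bit
                 * multiply-accumulates per instruction; otherwise a plain scalar loop is used. */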
#if defined(ARM_MATH_DSP)
                /* 4 multiply and accumulates are done in one loop. */
                int32_t col_count = rhs_cols >> 2;

                while (col_count)
                {
                    int32_t ker_a1, ker_a2;
                    int32_t ip_b1, ip_b2;

                    ker_a = read_and_pad(ker_a, &ker_a1, &ker_a2);

                    ip_b1 = arm_nn_read_q15x2_ia(&ip_as_col);
                    sum = SMLAD(ker_a1, ip_b1, sum);
                    ip_b2 = arm_nn_read_q15x2_ia(&ip_as_col);
                    sum = SMLAD(ker_a2, ip_b2, sum);

                    col_count--;
                }
                /* Handle left over mac */
                col_count = rhs_cols & 0x3;
#else
                uint16_t col_count = rhs_cols;

#endif

                while (col_count)
                {
                    int8_t ker_a1 = *ker_a++;
                    int16_t ip_b1 = *ip_as_col++;
                    sum += ker_a1 * ip_b1;
                    col_count--;
                }

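                /* Requantize: with an int32 bias the accumulator stays 32-bit; with an int64 bias it is widened to
                 * 64-bit and requantized with a reduced multiplier. */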
                if (is_int32_bias)
                {
                    if (bias_s32)
                    {
                        sum += bias_s32[i];
                    }

                    sum = arm_nn_requantize(sum, output_mult[i], output_shift[i]);
                }
                else
                {
                    int64_t acc_64 = sum;

                    if (bias_s64)
                    {
                        acc_64 += bias_s64[i];
                    }

                    int32_t reduced_multiplier = REDUCE_MULTIPLIER(output_mult[i]);
                    sum = arm_nn_requantize_s64(acc_64, reduced_multiplier, output_shift[i]);
                }

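                /* Clamp to the activation range and store the s16 result */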
                sum = MAX(sum, out_activation_min);
                sum = MIN(sum, out_activation_max);
                *out++ = (int16_t)sum;
            }
            lhs_rows = 0;

#endif // #if defined(ARM_MATH_MVEI)
        }

        /* Advance to the next batch */
        input_data += (input_x * input_y * input_ch);
        output_data += (output_x * output_y * output_ch);
    }

    /* Return to application */
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of NNConv group
 */