/*
 * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_s8.c
 * Description:  s8 version of depthwise convolution.
 *
 * $Date:        26 October 2022
 * $Revision:    V.3.0.4
 *
 * Target Processor:  Cortex-M CPUs
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */

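/*
 * Specialized depthwise convolution for a channel multiplier that is a
 * multiple of 4, unit dilation and a single batch (the conditions checked
 * by the caller below). Four outputs per input channel are accumulated at a
 * time so that the per-channel bias, multiplier and shift values can be
 * consumed in blocks of four.
 */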
#if !defined(__ARMCC_VERSION)
__attribute__((optimize("no-unroll-loops")))
#endif
static void
depthwise_conv_s8_mult_4(const int8_t *input,
                         const int32_t input_x,
                         const int32_t input_y,
                         const int32_t input_ch,
                         const int8_t *kernel,
                         const int32_t output_ch,
                         const int32_t ch_mult,
                         const int32_t kernel_x,
                         const int32_t kernel_y,
                         const int32_t pad_x,
                         const int32_t pad_y,
                         const int32_t stride_x,
                         const int32_t stride_y,
                         const int32_t *bias,
                         int8_t *output,
                         const int32_t *output_shift,
                         const int32_t *output_mult,
                         const int32_t output_x,
                         const int32_t output_y,
                         const int32_t output_offset,
                         const int32_t input_offset,
                         const int32_t output_activation_min,
                         const int32_t output_activation_max)
{
    const int32_t *bias_base = bias;
    const int32_t *mult_base = output_mult;
    const int32_t *shift_base = output_shift;
    const int8_t *kernel_base = kernel;

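    /* Walk the output plane; in_h/in_w track the top-left corner of the
       receptive field in the input, starting at -pad. The per-channel bias,
       multiplier and shift pointers are rewound for every output pixel. */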
    for (int32_t in_h = -pad_y, out_h = 0; out_h < output_y; in_h += stride_y, ++out_h)
    {
        for (int32_t in_w = -pad_x, out_w = 0, ker_h_start = MAX(0, -in_h); out_w < output_x; in_w += stride_x, ++out_w)
        {
            bias = bias_base;
            output_mult = mult_base;
            output_shift = shift_base;
            for (int32_t in_ch = 0, out_ch = 0, ker_w_start = MAX(0, -in_w); out_ch < output_ch;
                 ++in_ch, out_ch += ch_mult)
            {
                for (int mult_tile = 0; mult_tile < ch_mult; mult_tile += 4)
                {
                    int32_t out_buff[4] = {0, 0, 0, 0};
                    if (bias)
                    {
                        out_buff[0] = *bias++;
                        out_buff[1] = *bias++;
                        out_buff[2] = *bias++;
                        out_buff[3] = *bias++;
                    }

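                    /* Accumulate over the kernel window; ker_h_start/ker_w_start
                       and the MIN bounds skip taps that fall in the padding. */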
                    for (int32_t ker_h = ker_h_start; ker_h < MIN(kernel_y, input_y - in_h); ++ker_h)
                    {
                        int32_t ker_idx = ker_h * (output_ch * kernel_x) + ker_w_start * output_ch + out_ch;
                        kernel = kernel_base + mult_tile + ker_idx;
                        int32_t in_idx = (in_h + ker_h) * (input_ch * input_x) + in_w * input_ch + in_ch;
#if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
#pragma clang loop unroll(disable)
#endif
                        for (int32_t ker_w = ker_w_start; ker_w < MIN(kernel_x, input_x - in_w);
                             ++ker_w, kernel += output_ch)
                        {
                            int32_t in_val = input[in_idx + ker_w * input_ch] + input_offset;
                            out_buff[0] += in_val * kernel[0];
                            out_buff[1] += in_val * kernel[1];
                            out_buff[2] += in_val * kernel[2];
                            out_buff[3] += in_val * kernel[3];
                        }
                    }
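                    /* Requantize, offset and clamp four accumulators at once
                       with MVE when available, otherwise with the scalar
                       helper. */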
#if defined(ARM_MATH_MVEI)
                    int32x4_t res = vldrwq_s32(out_buff);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(output_mult), vldrwq_s32(output_shift));
                    output_mult += 4;
                    output_shift += 4;
                    res = vaddq_n_s32(res, output_offset);

                    res = vmaxq_s32(res, vdupq_n_s32(output_activation_min));
                    res = vminq_s32(res, vdupq_n_s32(output_activation_max));
                    vstrbq_s32(output, res);
                    output += 4;
#else
                    out_buff[0] = arm_nn_requantize(out_buff[0], *output_mult++, *output_shift++);
                    out_buff[1] = arm_nn_requantize(out_buff[1], *output_mult++, *output_shift++);
                    out_buff[2] = arm_nn_requantize(out_buff[2], *output_mult++, *output_shift++);
                    out_buff[3] = arm_nn_requantize(out_buff[3], *output_mult++, *output_shift++);

                    out_buff[0] += output_offset;
                    out_buff[1] += output_offset;
                    out_buff[2] += output_offset;
                    out_buff[3] += output_offset;

                    out_buff[0] = MIN(MAX(out_buff[0], output_activation_min), output_activation_max);
                    out_buff[1] = MIN(MAX(out_buff[1], output_activation_min), output_activation_max);
                    out_buff[2] = MIN(MAX(out_buff[2], output_activation_min), output_activation_max);
                    out_buff[3] = MIN(MAX(out_buff[3], output_activation_min), output_activation_max);

                    *output++ = (int8_t)out_buff[0];
                    *output++ = (int8_t)out_buff[1];
                    *output++ = (int8_t)out_buff[2];
                    *output++ = (int8_t)out_buff[3];

#endif
                }
            }
        }
    }
}

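/*
 * Reference depthwise convolution covering the general case: any batch
 * count, any channel multiplier and dilated kernels. Input and output use
 * HWC layout within each batch.
 */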
static void depthwise_conv_s8_generic(const int8_t *input,
                                      const uint16_t input_batches,
                                      const uint16_t input_x,
                                      const uint16_t input_y,
                                      const uint16_t input_ch,
                                      const int8_t *kernel,
                                      const uint16_t output_ch,
                                      const uint16_t ch_mult,
                                      const uint16_t kernel_x,
                                      const uint16_t kernel_y,
                                      const uint16_t pad_x,
                                      const uint16_t pad_y,
                                      const uint16_t stride_x,
                                      const uint16_t stride_y,
                                      const int32_t *bias,
                                      int8_t *output,
                                      const int32_t *output_shift,
                                      const int32_t *output_mult,
                                      const uint16_t output_x,
                                      const uint16_t output_y,
                                      const int32_t output_offset,
                                      const int32_t input_offset,
                                      const int32_t output_activation_min,
                                      const int32_t output_activation_max,
                                      const uint16_t dilation_x,
                                      const uint16_t dilation_y)

{
    (void)output_ch;
    int i_out = 0;
    int i_batch;

    for (i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
            for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;
                for (int i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
                {
                    for (int i_ch_mult = 0; i_ch_mult < ch_mult; i_ch_mult++)
                    {
                        const int idx_out_ch = i_ch_mult + i_input_ch * ch_mult;
                        int32_t acc_0 = 0;

                        int ker_y_start;
                        int ker_x_start;
                        int ker_y_end;
                        int ker_x_end;

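                        /* Clip the kernel window so that base_idx + dilation * ker
                           stays inside the input; the "+ dilation - 1" term
                           rounds the division up for dilated strides. */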
                        if (dilation_x > 1)
                        {
                            const int32_t start_x_max = (-base_idx_x + dilation_x - 1) / dilation_x;
                            ker_x_start = MAX(0, start_x_max);
                            const int32_t end_min_x = (input_x - base_idx_x + dilation_x - 1) / dilation_x;
                            ker_x_end = MIN(kernel_x, end_min_x);
                        }
                        else
                        {
                            ker_x_start = MAX(0, -base_idx_x);
                            ker_x_end = MIN(kernel_x, input_x - base_idx_x);
                        }

                        if (dilation_y > 1)
                        {
                            const int32_t start_y_max = (-base_idx_y + dilation_y - 1) / dilation_y;
                            ker_y_start = MAX(0, start_y_max);
                            const int32_t end_min_y = (input_y - base_idx_y + dilation_y - 1) / dilation_y;
                            ker_y_end = MIN(kernel_y, end_min_y);
                        }
                        else
                        {
                            ker_y_start = MAX(0, -base_idx_y);
                            ker_y_end = MIN(kernel_y, input_y - base_idx_y);
                        }

                        if (bias)
                        {
                            acc_0 = bias[idx_out_ch];
                        }

                        for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                        {
                            const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                            for (int i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                            {
                                const int32_t idx_x = base_idx_x + dilation_x * i_ker_x;
                                int32_t idx_0 = (idx_y * input_x + idx_x) * input_ch + i_input_ch;
                                int32_t ker_idx_0 = (i_ker_y * kernel_x + i_ker_x) * (input_ch * ch_mult) + idx_out_ch;

                                acc_0 += (input[idx_0] + input_offset) * kernel[ker_idx_0];
                            }
                        }

                        /* Requantize and clamp output to provided range */
                        acc_0 = arm_nn_requantize(acc_0, output_mult[idx_out_ch], output_shift[idx_out_ch]);
                        acc_0 += output_offset;
                        acc_0 = MAX(acc_0, output_activation_min);
                        acc_0 = MIN(acc_0, output_activation_max);

                        output[i_out++] = acc_0;
                    }
                }
            }
        }
        /* Advance to the next batch */
        input += (input_x * input_y * input_ch);
    }
}

/*
 * Basic s8 depthwise convolution function.
 *
 * Refer to the header file for details.
 * Optimization using the DSP extension is not available for the generic case where the channel multiplier is > 1.
 *
 */
arm_cmsis_nn_status arm_depthwise_conv_s8(const cmsis_nn_context *ctx,
                                          const cmsis_nn_dw_conv_params *dw_conv_params,
                                          const cmsis_nn_per_channel_quant_params *quant_params,
                                          const cmsis_nn_dims *input_dims,
                                          const int8_t *input,
                                          const cmsis_nn_dims *filter_dims,
                                          const int8_t *kernel,
                                          const cmsis_nn_dims *bias_dims,
                                          const int32_t *bias,
                                          const cmsis_nn_dims *output_dims,
                                          int8_t *output)
{
    const uint16_t dilation_x = dw_conv_params->dilation.w;
    const uint16_t dilation_y = dw_conv_params->dilation.h;

    (void)bias_dims;
    (void)ctx;

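    /* Use the unrolled channel-multiplier-of-4 kernel when its preconditions
       (single batch, unit dilation, ch_mult divisible by 4) hold; otherwise
       fall back to the generic reference loop. */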
    if (dw_conv_params->ch_mult % 4 == 0 && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
        dw_conv_params->dilation.h == 1)
    {
        depthwise_conv_s8_mult_4(input,
                                 input_dims->w,
                                 input_dims->h,
                                 input_dims->c,
                                 kernel,
                                 output_dims->c,
                                 dw_conv_params->ch_mult,
                                 filter_dims->w,
                                 filter_dims->h,
                                 dw_conv_params->padding.w,
                                 dw_conv_params->padding.h,
                                 dw_conv_params->stride.w,
                                 dw_conv_params->stride.h,
                                 bias,
                                 output,
                                 quant_params->shift,
                                 quant_params->multiplier,
                                 output_dims->w,
                                 output_dims->h,
                                 dw_conv_params->output_offset,
                                 dw_conv_params->input_offset,
                                 dw_conv_params->activation.min,
                                 dw_conv_params->activation.max);
    }
    else
    {
        depthwise_conv_s8_generic(input,
                                  input_dims->n,
                                  input_dims->w,
                                  input_dims->h,
                                  input_dims->c,
                                  kernel,
                                  output_dims->c,
                                  dw_conv_params->ch_mult,
                                  filter_dims->w,
                                  filter_dims->h,
                                  dw_conv_params->padding.w,
                                  dw_conv_params->padding.h,
                                  dw_conv_params->stride.w,
                                  dw_conv_params->stride.h,
                                  bias,
                                  output,
                                  quant_params->shift,
                                  quant_params->multiplier,
                                  output_dims->w,
                                  output_dims->h,
                                  dw_conv_params->output_offset,
                                  dw_conv_params->input_offset,
                                  dw_conv_params->activation.min,
                                  dw_conv_params->activation.max,
                                  dilation_x,
                                  dilation_y);
    }

    /* Return to application */
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of NNConv group
 */

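/*
 * Usage sketch (illustration only, not part of the library): how an
 * application might call arm_depthwise_conv_s8 for a 1x8x8x4 input, a 3x3
 * kernel and unit channel multiplier. The tensor data, the per-channel
 * quantization arrays and the helper name example_run_dw_conv are
 * hypothetical placeholders supplied by the application; only the struct
 * fields and the call itself come from this file and its header. The guard
 * macro keeps the sketch out of library builds.
 */
#if defined(DEPTHWISE_CONV_S8_USAGE_EXAMPLE)
static arm_cmsis_nn_status example_run_dw_conv(const int8_t *input_data,
                                               const int8_t *kernel_data,
                                               const int32_t *bias_data,
                                               const int32_t *per_ch_mult,
                                               const int32_t *per_ch_shift,
                                               int8_t *output_data)
{
    const cmsis_nn_dims input_dims = {.n = 1, .h = 8, .w = 8, .c = 4};
    const cmsis_nn_dims filter_dims = {.h = 3, .w = 3};
    const cmsis_nn_dims output_dims = {.n = 1, .h = 8, .w = 8, .c = 4};
    const cmsis_nn_dims bias_dims = {.c = 4};

    cmsis_nn_dw_conv_params dw_conv_params;
    dw_conv_params.ch_mult = 1;
    dw_conv_params.stride.h = 1;
    dw_conv_params.stride.w = 1;
    dw_conv_params.padding.h = 1; /* SAME padding for a 3x3 kernel */
    dw_conv_params.padding.w = 1;
    dw_conv_params.dilation.h = 1;
    dw_conv_params.dilation.w = 1;
    dw_conv_params.input_offset = 0;  /* -(input zero point), network specific */
    dw_conv_params.output_offset = 0; /* output zero point, network specific */
    dw_conv_params.activation.min = -128;
    dw_conv_params.activation.max = 127;

    cmsis_nn_per_channel_quant_params quant_params;
    quant_params.multiplier = (int32_t *)per_ch_mult;
    quant_params.shift = (int32_t *)per_ch_shift;

    /* This kernel does not use the context buffer; see (void)ctx above. */
    cmsis_nn_context ctx = {.buf = NULL, .size = 0};

    return arm_depthwise_conv_s8(&ctx,
                                 &dw_conv_params,
                                 &quant_params,
                                 &input_dims,
                                 input_data,
                                 &filter_dims,
                                 kernel_data,
                                 &bias_dims,
                                 bias_data,
                                 &output_dims,
                                 output_data);
}
#endif /* DEPTHWISE_CONV_S8_USAGE_EXAMPLE */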