1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_s8.c
22  * Description:  s8 version of depthwise convolution.
23  *
24  * $Date:        26 October 2022
25  * $Revision:    V.3.0.4
26  *
27  * Target Processor:  Cortex-M CPUs
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup Public
36  */
37 
38 /**
39  * @addtogroup NNConv
40  * @{
41  */
42 
/*
 * Depthwise convolution kernel specialised for a channel multiplier that is a
 * multiple of 4. arm_depthwise_conv_s8 only selects it for a single batch with
 * no dilation, so neither batching nor dilation is handled here.
 *
 * Memory layout, as implied by the index arithmetic below:
 *   input  : [input_y][input_x][input_ch]
 *   kernel : [kernel_y][kernel_x][output_ch]
 *   output : written sequentially, output_ch values per output pixel,
 *            row-major over (output_y, output_x).
 * Output channel (c * ch_mult + m) is produced from input channel c.
 *
 * Four output channels are accumulated per iteration. The bias, multiplier and
 * shift tables are consumed four entries at a time and rewound for every output
 * pixel. Kernel taps that fall into the padding region are skipped instead of
 * being multiplied by zero.
 */
#if !defined(__ARMCC_VERSION)
__attribute__((optimize("no-unroll-loops")))
#endif
static void
depthwise_conv_s8_mult_4(const int8_t *input,
                         const int32_t input_x,
                         const int32_t input_y,
                         const int32_t input_ch,
                         const int8_t *kernel,
                         const int32_t output_ch,
                         const int32_t ch_mult,
                         const int32_t kernel_x,
                         const int32_t kernel_y,
                         const int32_t pad_x,
                         const int32_t pad_y,
                         const int32_t stride_x,
                         const int32_t stride_y,
                         const int32_t *bias,
                         int8_t *output,
                         const int32_t *output_shift,
                         const int32_t *output_mult,
                         const int32_t output_x,
                         const int32_t output_y,
                         const int32_t output_offset,
                         const int32_t input_offset,
                         const int32_t output_activation_min,
                         const int32_t output_activation_max)
{
    /* Start of the per-output-channel tables; rewound for every output pixel. */
    const int32_t *const bias_start = bias;
    const int32_t *const mult_start = output_mult;
    const int32_t *const shift_start = output_shift;
    const int8_t *const kernel_start = kernel;

    /* Element distance between vertically adjacent kernel taps / input pixels. */
    const int32_t kernel_row_stride = output_ch * kernel_x;
    const int32_t input_row_stride = input_ch * input_x;

    int32_t in_row = -pad_y;
    for (int32_t out_row = 0; out_row < output_y; ++out_row)
    {
        /* Vertical kernel window, clipped to rows that exist in the input. */
        const int32_t ky_begin = MAX(0, -in_row);
        const int32_t ky_end = MIN(kernel_y, input_y - in_row);

        int32_t in_col = -pad_x;
        for (int32_t out_col = 0; out_col < output_x; ++out_col)
        {
            /* Horizontal kernel window, clipped to columns that exist in the input. */
            const int32_t kx_begin = MAX(0, -in_col);
            const int32_t kx_end = MIN(kernel_x, input_x - in_col);

            bias = bias_start;
            output_mult = mult_start;
            output_shift = shift_start;

            int32_t src_ch = 0;
            for (int32_t dst_ch = 0; dst_ch < output_ch; dst_ch += ch_mult)
            {
                for (int32_t tile = 0; tile < ch_mult; tile += 4)
                {
                    int32_t acc[4] = {0, 0, 0, 0};
                    if (bias)
                    {
                        acc[0] = bias[0];
                        acc[1] = bias[1];
                        acc[2] = bias[2];
                        acc[3] = bias[3];
                        bias += 4;
                    }

                    for (int32_t ky = ky_begin; ky < ky_end; ++ky)
                    {
                        /* First in-bounds tap of this kernel row, for channels dst_ch + tile .. +3. */
                        const int8_t *taps =
                            kernel_start + (ky * kernel_row_stride + kx_begin * output_ch + dst_ch + tile);
                        /* Input offset of column in_col in this row; may be negative, but kx >= -in_col keeps reads in
                         * range. */
                        const int32_t row_base = (in_row + ky) * input_row_stride + in_col * input_ch + src_ch;
#if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
#pragma clang loop unroll(disable)
#endif
                        for (int32_t kx = kx_begin; kx < kx_end; ++kx, taps += output_ch)
                        {
                            const int32_t sample = input[row_base + kx * input_ch] + input_offset;
                            acc[0] += sample * taps[0];
                            acc[1] += sample * taps[1];
                            acc[2] += sample * taps[2];
                            acc[3] += sample * taps[3];
                        }
                    }
#if defined(ARM_MATH_MVEI)
                    int32x4_t res = vldrwq_s32(acc);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(output_mult), vldrwq_s32(output_shift));
                    output_mult += 4;
                    output_shift += 4;
                    res = vaddq_n_s32(res, output_offset);

                    res = vmaxq_s32(res, vdupq_n_s32(output_activation_min));
                    res = vminq_s32(res, vdupq_n_s32(output_activation_max));
                    vstrbq_s32(output, res);
                    output += 4;
#else
                    /* Kept explicitly unrolled: this function is compiled with loop unrolling disabled. */
                    const int32_t q0 = arm_nn_requantize(acc[0], output_mult[0], output_shift[0]) + output_offset;
                    const int32_t q1 = arm_nn_requantize(acc[1], output_mult[1], output_shift[1]) + output_offset;
                    const int32_t q2 = arm_nn_requantize(acc[2], output_mult[2], output_shift[2]) + output_offset;
                    const int32_t q3 = arm_nn_requantize(acc[3], output_mult[3], output_shift[3]) + output_offset;
                    output_mult += 4;
                    output_shift += 4;

                    output[0] = (int8_t)MIN(MAX(q0, output_activation_min), output_activation_max);
                    output[1] = (int8_t)MIN(MAX(q1, output_activation_min), output_activation_max);
                    output[2] = (int8_t)MIN(MAX(q2, output_activation_min), output_activation_max);
                    output[3] = (int8_t)MIN(MAX(q3, output_activation_min), output_activation_max);
                    output += 4;
#endif
                }
                ++src_ch;
            }
            in_col += stride_x;
        }
        in_row += stride_y;
    }
}
153 
/*
 * Compute the half-open range [*start, *end) of kernel tap indices k for which
 * the input coordinate base + dilation * k lies inside [0, input_size).
 * Taps outside this range would read padding, so the caller skips them.
 *
 * For dilation > 1 both bounds are ceiling divisions. Whenever a numerator is
 * negative, the true bound is <= 0, so C's truncation toward zero does not change
 * the result after the MAX/MIN clamp. An empty range comes out as *end <= *start.
 */
static void depthwise_conv_s8_tap_range(const int32_t base,
                                        const uint16_t dilation,
                                        const uint16_t input_size,
                                        const uint16_t kernel_size,
                                        int32_t *start,
                                        int32_t *end)
{
    if (dilation > 1)
    {
        const int32_t first_valid = (-base + dilation - 1) / dilation;
        const int32_t past_last_valid = (input_size - base + dilation - 1) / dilation;
        *start = MAX(0, first_valid);
        *end = MIN(kernel_size, past_last_valid);
    }
    else
    {
        *start = MAX(0, -base);
        *end = MIN(kernel_size, input_size - base);
    }
}

/*
 * Generic s8 depthwise convolution: any batch count, any channel multiplier and
 * any dilation.
 *
 * Layout, as implied by the index arithmetic below:
 *   input  : per batch [input_y][input_x][input_ch]; the pointer advances by
 *            input_x * input_y * input_ch after each batch
 *   kernel : [kernel_y][kernel_x][input_ch * ch_mult]
 *   output : written sequentially in (batch, out_y, out_x, in_ch, ch_mult) order
 * Output channel (c * ch_mult + m) is produced from input channel c.
 *
 * For each output value, the accumulator is:
 *   bias (if non-NULL) + sum over in-bounds taps of (input + input_offset) * kernel.
 * It is then requantized with the per-channel multiplier/shift, offset by
 * output_offset and clamped to [output_activation_min, output_activation_max].
 * output_ch is unused; the channel count is input_ch * ch_mult.
 */
static void depthwise_conv_s8_generic(const int8_t *input,
                                      const uint16_t input_batches,
                                      const uint16_t input_x,
                                      const uint16_t input_y,
                                      const uint16_t input_ch,
                                      const int8_t *kernel,
                                      const uint16_t output_ch,
                                      const uint16_t ch_mult,
                                      const uint16_t kernel_x,
                                      const uint16_t kernel_y,
                                      const uint16_t pad_x,
                                      const uint16_t pad_y,
                                      const uint16_t stride_x,
                                      const uint16_t stride_y,
                                      const int32_t *bias,
                                      int8_t *output,
                                      const int32_t *output_shift,
                                      const int32_t *output_mult,
                                      const uint16_t output_x,
                                      const uint16_t output_y,
                                      const int32_t output_offset,
                                      const int32_t input_offset,
                                      const int32_t output_activation_min,
                                      const int32_t output_activation_max,
                                      const uint16_t dilation_x,
                                      const uint16_t dilation_y)
{
    (void)output_ch;
    int32_t i_out = 0;

    for (int32_t i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int32_t i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            /* int32_t rather than int16_t: with uint16_t dimensions this value can
             * exceed INT16_MAX, and narrowing would yield a bogus negative index. */
            const int32_t base_idx_y = (i_out_y * stride_y) - pad_y;

            /* The valid kernel rows depend only on the output row, so compute them once here. */
            int32_t ker_y_start;
            int32_t ker_y_end;
            depthwise_conv_s8_tap_range(base_idx_y, dilation_y, input_y, kernel_y, &ker_y_start, &ker_y_end);

            for (int32_t i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int32_t base_idx_x = (i_out_x * stride_x) - pad_x;

                /* The valid kernel columns depend only on the output column, not on the channel. */
                int32_t ker_x_start;
                int32_t ker_x_end;
                depthwise_conv_s8_tap_range(base_idx_x, dilation_x, input_x, kernel_x, &ker_x_start, &ker_x_end);

                for (int32_t i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
                {
                    for (int32_t i_ch_mult = 0; i_ch_mult < ch_mult; i_ch_mult++)
                    {
                        const int32_t idx_out_ch = i_ch_mult + i_input_ch * ch_mult;
                        int32_t acc_0 = 0;

                        if (bias)
                        {
                            acc_0 = bias[idx_out_ch];
                        }

                        for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                        {
                            const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                            for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                            {
                                const int32_t idx_x = base_idx_x + dilation_x * i_ker_x;
                                const int32_t idx_0 = (idx_y * input_x + idx_x) * input_ch + i_input_ch;
                                const int32_t ker_idx_0 =
                                    (i_ker_y * kernel_x + i_ker_x) * (input_ch * ch_mult) + idx_out_ch;

                                acc_0 += (input[idx_0] + input_offset) * kernel[ker_idx_0];
                            }
                        }

                        /* Requantize and clamp output to provided range */
                        acc_0 = arm_nn_requantize(acc_0, output_mult[idx_out_ch], output_shift[idx_out_ch]);
                        acc_0 += output_offset;
                        acc_0 = MAX(acc_0, output_activation_min);
                        acc_0 = MIN(acc_0, output_activation_max);

                        output[i_out++] = (int8_t)acc_0;
                    }
                }
            }
        }
        /* Advance to the next batch */
        input += (input_x * input_y * input_ch);
    }
}
265 
266 /*
267  *  Basic s8 depthwise convolution function.
268  *
269  *  Refer header file for details.
270  *  Optimization using DSP extension is not available for the generic case where channel multiplier is > 1.
271  *
272  */
arm_depthwise_conv_s8(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)273 arm_cmsis_nn_status arm_depthwise_conv_s8(const cmsis_nn_context *ctx,
274                                           const cmsis_nn_dw_conv_params *dw_conv_params,
275                                           const cmsis_nn_per_channel_quant_params *quant_params,
276                                           const cmsis_nn_dims *input_dims,
277                                           const int8_t *input,
278                                           const cmsis_nn_dims *filter_dims,
279                                           const int8_t *kernel,
280                                           const cmsis_nn_dims *bias_dims,
281                                           const int32_t *bias,
282                                           const cmsis_nn_dims *output_dims,
283                                           int8_t *output)
284 {
285     const uint16_t dilation_x = dw_conv_params->dilation.w;
286     const uint16_t dilation_y = dw_conv_params->dilation.h;
287 
288     (void)bias_dims;
289     (void)ctx;
290 
291     if (dw_conv_params->ch_mult % 4 == 0 && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
292         dw_conv_params->dilation.h == 1)
293     {
294         depthwise_conv_s8_mult_4(input,
295                                  input_dims->w,
296                                  input_dims->h,
297                                  input_dims->c,
298                                  kernel,
299                                  output_dims->c,
300                                  dw_conv_params->ch_mult,
301                                  filter_dims->w,
302                                  filter_dims->h,
303                                  dw_conv_params->padding.w,
304                                  dw_conv_params->padding.h,
305                                  dw_conv_params->stride.w,
306                                  dw_conv_params->stride.h,
307                                  bias,
308                                  output,
309                                  quant_params->shift,
310                                  quant_params->multiplier,
311                                  output_dims->w,
312                                  output_dims->h,
313                                  dw_conv_params->output_offset,
314                                  dw_conv_params->input_offset,
315                                  dw_conv_params->activation.min,
316                                  dw_conv_params->activation.max);
317     }
318     else
319     {
320         depthwise_conv_s8_generic(input,
321                                   input_dims->n,
322                                   input_dims->w,
323                                   input_dims->h,
324                                   input_dims->c,
325                                   kernel,
326                                   output_dims->c,
327                                   dw_conv_params->ch_mult,
328                                   filter_dims->w,
329                                   filter_dims->h,
330                                   dw_conv_params->padding.w,
331                                   dw_conv_params->padding.h,
332                                   dw_conv_params->stride.w,
333                                   dw_conv_params->stride.h,
334                                   bias,
335                                   output,
336                                   quant_params->shift,
337                                   quant_params->multiplier,
338                                   output_dims->w,
339                                   output_dims->h,
340                                   dw_conv_params->output_offset,
341                                   dw_conv_params->input_offset,
342                                   dw_conv_params->activation.min,
343                                   dw_conv_params->activation.max,
344                                   dilation_x,
345                                   dilation_y);
346     }
347 
348     /* Return to application */
349     return ARM_CMSIS_NN_SUCCESS;
350 }
351 
352 /**
353  * @} end of NNConv group
354  */
355