1 /*
2  * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_3x3_s8.c
22  * Description:  Optimized s8 depthwise convolution function for channel
23  *               multiplier of 1 and 3x3 kernel size.
24  *
25  * $Date:        09. October 2020
26  * $Revision:    V.2.0.1
27  *
28  * Target Processor:  Cortex-M CPUs
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup groupNN
37  */
38 
39 /**
40  * @addtogroup NNConv
41  * @{
42  */
43 
/*
 * Optimized s8 depthwise convolution function with the constraints that
 * in_channel == out_channel, kernel_x == kernel_y == 3 and the horizontal
 * padding (pad_x) is at most 1. Vertical padding is not restricted.
 *
 * Refer to the function prototype in the header file for details.
 *
 */
51 
arm_depthwise_conv_3x3_s8(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const q7_t * input,const cmsis_nn_dims * filter_dims,const q7_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,q7_t * output)52 arm_status arm_depthwise_conv_3x3_s8(const cmsis_nn_context *ctx,
53                                      const cmsis_nn_dw_conv_params *dw_conv_params,
54                                      const cmsis_nn_per_channel_quant_params *quant_params,
55                                      const cmsis_nn_dims *input_dims,
56                                      const q7_t *input,
57                                      const cmsis_nn_dims *filter_dims,
58                                      const q7_t *kernel,
59                                      const cmsis_nn_dims *bias_dims,
60                                      const int32_t *bias,
61                                      const cmsis_nn_dims *output_dims,
62                                      q7_t *output)
63 {
64     (void)ctx;
65     (void)bias_dims;
66 
67     const int32_t input_x = input_dims->w;
68     const int32_t input_y = input_dims->h;
69     const int32_t input_ch = input_dims->c;
70     const int32_t output_ch = output_dims->c;
71     const int32_t pad_x = dw_conv_params->padding.w;
72     const int32_t pad_y = dw_conv_params->padding.h;
73     const int32_t stride_x = dw_conv_params->stride.w;
74     const int32_t stride_y = dw_conv_params->stride.h;
75     const int32_t *output_shift = quant_params->shift;
76     const int32_t *output_mult = quant_params->multiplier;
77     const int32_t output_x = output_dims->w;
78     const int32_t output_y = output_dims->h;
79     const int32_t output_offset = dw_conv_params->output_offset;
80     const int32_t input_offset = dw_conv_params->input_offset;
81     const int32_t output_activation_min = dw_conv_params->activation.min;
82     const int32_t output_activation_max = dw_conv_params->activation.max;
83 
84     /* Check input constraints input_ch == output_ch */
85     if (input_ch != output_ch)
86     {
87         return ARM_MATH_SIZE_MISMATCH;
88     }
89     /* Check input constraints pad_x <= 1 */
90     if (pad_x > 1 || filter_dims->w != 3 || filter_dims->h != 3)
91     {
92         return ARM_MATH_ARGUMENT_ERROR;
93     }
94 
95     for (int32_t in_h = -pad_y, out_h = 0, out_idx = 0; out_h < output_y; in_h += stride_y, ++out_h)
96     {
97         for (int32_t in_w = -pad_x, out_w = 0, ker_h_start = MAX(0, -in_h); out_w < output_x; in_w += stride_x, ++out_w)
98         {
99             int32_t in_ch = 0;
100             int32_t ker_w_start = MAX(0, -in_w);
101 
102             for (; in_ch <= (input_ch - 4); in_ch += 4)
103             {
104                 int32_t out_buff0 = bias[in_ch + 0];
105                 int32_t out_buff1 = bias[in_ch + 1];
106                 int32_t out_buff2 = bias[in_ch + 2];
107                 int32_t out_buff3 = bias[in_ch + 3];
108 
109                 const int8_t *input_ptr = input + (in_h + ker_h_start) * (input_ch * input_x) + in_w * input_ch + in_ch;
110                 const int8_t *kernel_ptr = kernel + ker_h_start * (input_ch * 3) + in_ch;
111 
112                 for (int32_t ker_h = ker_h_start; ker_h < MIN(3, input_y - in_h); ++ker_h)
113                 {
114                     int32_t in_val = 0;
115                     int32_t ker_val = 0;
116 
117                     if (ker_w_start == 0)
118                     {
119                         in_val = arm_nn_read_q7x4(input_ptr);
120                         ker_val = arm_nn_read_q7x4(kernel_ptr);
121 
122                         out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
123                         out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
124                         out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
125                         out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
126                     }
127 
128                     in_val = arm_nn_read_q7x4(input_ptr + input_ch);
129                     ker_val = arm_nn_read_q7x4(kernel_ptr + input_ch);
130 
131                     out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
132                     out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
133                     out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
134                     out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
135 
136                     if ((input_x - in_w) >= 3)
137                     {
138                         in_val = arm_nn_read_q7x4(input_ptr + (input_ch << 1));
139                         ker_val = arm_nn_read_q7x4(kernel_ptr + (input_ch << 1));
140 
141                         out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
142                         out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
143                         out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
144                         out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
145                     }
146 
147                     input_ptr += (input_ch * input_x);
148                     kernel_ptr += (input_ch * 3);
149                 }
150 
151                 out_buff0 = arm_nn_requantize(out_buff0, output_mult[in_ch + 0], output_shift[in_ch + 0]);
152                 out_buff1 = arm_nn_requantize(out_buff1, output_mult[in_ch + 1], output_shift[in_ch + 1]);
153                 out_buff2 = arm_nn_requantize(out_buff2, output_mult[in_ch + 2], output_shift[in_ch + 2]);
154                 out_buff3 = arm_nn_requantize(out_buff3, output_mult[in_ch + 3], output_shift[in_ch + 3]);
155 
156                 out_buff0 += output_offset;
157                 out_buff1 += output_offset;
158                 out_buff2 += output_offset;
159                 out_buff3 += output_offset;
160 
161                 out_buff0 = MIN(MAX(out_buff0, output_activation_min), output_activation_max);
162                 out_buff1 = MIN(MAX(out_buff1, output_activation_min), output_activation_max);
163                 out_buff2 = MIN(MAX(out_buff2, output_activation_min), output_activation_max);
164                 out_buff3 = MIN(MAX(out_buff3, output_activation_min), output_activation_max);
165 
166                 output[out_idx++] = (int8_t)out_buff0;
167                 output[out_idx++] = (int8_t)out_buff1;
168                 output[out_idx++] = (int8_t)out_buff2;
169                 output[out_idx++] = (int8_t)out_buff3;
170             }
171 
172             // Leftover
173             for (; in_ch < input_ch; ++in_ch)
174             {
175                 int32_t out_buff = bias[in_ch];
176 
177                 const int8_t *input_ptr = input + (in_h + ker_h_start) * (input_ch * input_x) + in_w * input_ch + in_ch;
178                 const int8_t *kernel_ptr = kernel + ker_h_start * (input_ch * 3) + in_ch;
179 
180                 for (int32_t ker_h = ker_h_start; ker_h < MIN(3, input_y - in_h); ++ker_h)
181                 {
182                     if (ker_w_start == 0)
183                     {
184                         out_buff += (*(input_ptr) + input_offset) * *(kernel_ptr);
185                     }
186 
187                     out_buff += (*(input_ptr + input_ch) + input_offset) * *(kernel_ptr + input_ch);
188 
189                     if ((input_x - in_w) >= 3)
190                     {
191                         out_buff += (*(input_ptr + (input_ch << 1)) + input_offset) * *(kernel_ptr + (input_ch << 1));
192                     }
193 
194                     input_ptr += (input_ch * input_x);
195                     kernel_ptr += (input_ch * 3);
196                 }
197 
198                 out_buff = arm_nn_requantize(out_buff, output_mult[in_ch], output_shift[in_ch]);
199                 out_buff += output_offset;
200                 out_buff = MIN(MAX(out_buff, output_activation_min), output_activation_max);
201                 output[out_idx++] = (int8_t)out_buff;
202             }
203         }
204     }
205 
206     /* Return to application */
207     return ARM_MATH_SUCCESS;
208 }
209 
210 /**
211  * @} end of NNConv group
212  */
213