1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_3x3_s8.c
22  * Description:  Optimized s8 depthwise convolution function for channel
23  *               multiplier of 1 and 3x3 kernel size.
24  *
25  * $Date:        5 January 2023
26  * $Revision:    V.3.2.0
27  *
28  * Target :  Arm(R) M-Profile Architecture
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup Public
37  */
38 
39 /**
40  * @addtogroup NNConv
41  * @{
42  */
43 
44 /*
45  * Optimized s8 depthwise convolution function with constraint that
46  * in_channel == out_channel and kernel_x == kernel_y == 3 with pads at most 1
47  *
48  *  Refer prototype header file for details.
49  *
50  */
51 
arm_depthwise_conv_3x3_s8(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)52 arm_cmsis_nn_status arm_depthwise_conv_3x3_s8(const cmsis_nn_context *ctx,
53                                               const cmsis_nn_dw_conv_params *dw_conv_params,
54                                               const cmsis_nn_per_channel_quant_params *quant_params,
55                                               const cmsis_nn_dims *input_dims,
56                                               const int8_t *input,
57                                               const cmsis_nn_dims *filter_dims,
58                                               const int8_t *kernel,
59                                               const cmsis_nn_dims *bias_dims,
60                                               const int32_t *bias,
61                                               const cmsis_nn_dims *output_dims,
62                                               int8_t *output)
63 {
64     (void)ctx;
65     (void)bias_dims;
66 
67     const int32_t input_x = input_dims->w;
68     const int32_t input_y = input_dims->h;
69     const int32_t input_ch = input_dims->c;
70     const int32_t output_ch = output_dims->c;
71     const int32_t pad_x = dw_conv_params->padding.w;
72     const int32_t pad_y = dw_conv_params->padding.h;
73     const int32_t stride_x = dw_conv_params->stride.w;
74     const int32_t stride_y = dw_conv_params->stride.h;
75     const int32_t *output_shift = quant_params->shift;
76     const int32_t *output_mult = quant_params->multiplier;
77     const int32_t output_x = output_dims->w;
78     const int32_t output_y = output_dims->h;
79     const int32_t output_offset = dw_conv_params->output_offset;
80     const int32_t input_offset = dw_conv_params->input_offset;
81     const int32_t output_activation_min = dw_conv_params->activation.min;
82     const int32_t output_activation_max = dw_conv_params->activation.max;
83 
84     /* Check input constraints input_ch == output_ch */
85     if (input_ch != output_ch)
86     {
87         return ARM_CMSIS_NN_ARG_ERROR;
88     }
89     /* Check input constraints pad_x <= 1 */
90     if (pad_x > 1 || filter_dims->w != 3 || filter_dims->h != 3)
91     {
92         return ARM_CMSIS_NN_ARG_ERROR;
93     }
94     const int32_t *bias_base = bias;
95     for (int32_t in_h = -pad_y, out_h = 0, out_idx = 0; out_h < output_y; in_h += stride_y, ++out_h)
96     {
97         for (int32_t in_w = -pad_x, out_w = 0, ker_h_start = MAX(0, -in_h); out_w < output_x; in_w += stride_x, ++out_w)
98         {
99             int32_t in_ch = 0;
100             int32_t ker_w_start = MAX(0, -in_w);
101 
102             bias = bias_base;
103             for (; in_ch <= (input_ch - 4); in_ch += 4)
104             {
105                 int32_t out_buff0 = 0;
106                 int32_t out_buff1 = 0;
107                 int32_t out_buff2 = 0;
108                 int32_t out_buff3 = 0;
109                 if (bias)
110                 {
111                     out_buff0 = *bias++;
112                     out_buff1 = *bias++;
113                     out_buff2 = *bias++;
114                     out_buff3 = *bias++;
115                 }
116 
117                 const int8_t *input_ptr = input + (in_h + ker_h_start) * (input_ch * input_x) + in_w * input_ch + in_ch;
118                 const int8_t *kernel_ptr = kernel + ker_h_start * (input_ch * 3) + in_ch;
119 #if defined(ARM_MATH_DSP)
120                 const uint32_t lhs_offset_s16x2 = PKHBT(input_offset, input_offset, 16);
121 
122                 for (int32_t ker_h = ker_h_start; ker_h < MIN(3, input_y - in_h); ++ker_h)
123                 {
124                     int32_t in_val = 0;
125                     int32_t ker_val = 0;
126                     int32_t in_val_1 = 0;
127                     int32_t ker_val_1 = 0;
128 
129                     if (ker_w_start == 0)
130                     {
131                         in_val = arm_nn_read_s8x4(input_ptr);
132                         ker_val = arm_nn_read_s8x4(kernel_ptr);
133 
134                         in_val_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)in_val, 8);
135                         ker_val_1 = SXTB16_RORn((uint32_t)ker_val, 8);
136 
137                         out_buff1 = SMLABB(in_val_1, ker_val_1, out_buff1);
138                         in_val = SXTAB16(lhs_offset_s16x2, (uint32_t)in_val);
139                         out_buff3 = SMLATT(in_val_1, ker_val_1, out_buff3);
140                         ker_val = SXTB16((uint32_t)ker_val);
141                         out_buff0 = SMLABB(in_val, ker_val, out_buff0);
142                         out_buff2 = SMLATT(in_val, ker_val, out_buff2);
143                     }
144 
145                     in_val = arm_nn_read_s8x4(input_ptr + input_ch);
146                     ker_val = arm_nn_read_s8x4(kernel_ptr + input_ch);
147                     in_val_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)in_val, 8);
148                     ker_val_1 = SXTB16_RORn((uint32_t)ker_val, 8);
149 
150                     out_buff1 = SMLABB(in_val_1, ker_val_1, out_buff1);
151                     in_val = SXTAB16(lhs_offset_s16x2, (uint32_t)in_val);
152                     out_buff3 = SMLATT(in_val_1, ker_val_1, out_buff3);
153                     ker_val = SXTB16((uint32_t)ker_val);
154                     out_buff0 = SMLABB(in_val, ker_val, out_buff0);
155                     out_buff2 = SMLATT(in_val, ker_val, out_buff2);
156 
157                     if ((input_x - in_w) >= 3)
158                     {
159                         in_val = arm_nn_read_s8x4(input_ptr + (input_ch << 1));
160                         ker_val = arm_nn_read_s8x4(kernel_ptr + (input_ch << 1));
161                         in_val_1 = SXTAB16_RORn(lhs_offset_s16x2, (uint32_t)in_val, 8);
162                         ker_val_1 = SXTB16_RORn((uint32_t)ker_val, 8);
163 
164                         out_buff1 = SMLABB(in_val_1, ker_val_1, out_buff1);
165                         in_val = SXTAB16(lhs_offset_s16x2, (uint32_t)in_val);
166                         out_buff3 = SMLATT(in_val_1, ker_val_1, out_buff3);
167                         ker_val = SXTB16((uint32_t)ker_val);
168                         out_buff0 = SMLABB(in_val, ker_val, out_buff0);
169                         out_buff2 = SMLATT(in_val, ker_val, out_buff2);
170                     }
171 
172                     input_ptr += (input_ch * input_x);
173                     kernel_ptr += (input_ch * 3);
174                 }
175 
176 #else
177 
178                 for (int32_t ker_h = ker_h_start; ker_h < MIN(3, input_y - in_h); ++ker_h)
179                 {
180                     int32_t in_val = 0;
181                     int32_t ker_val = 0;
182 
183                     if (ker_w_start == 0)
184                     {
185                         in_val = arm_nn_read_s8x4(input_ptr);
186                         ker_val = arm_nn_read_s8x4(kernel_ptr);
187                         out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
188                         out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
189                         out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
190                         out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
191                     }
192 
193                     in_val = arm_nn_read_s8x4(input_ptr + input_ch);
194                     ker_val = arm_nn_read_s8x4(kernel_ptr + input_ch);
195 
196                     out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
197                     out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
198                     out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
199                     out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
200 
201                     if ((input_x - in_w) >= 3)
202                     {
203                         in_val = arm_nn_read_s8x4(input_ptr + (input_ch << 1));
204                         ker_val = arm_nn_read_s8x4(kernel_ptr + (input_ch << 1));
205 
206                         out_buff0 += ((int8_t)in_val + input_offset) * (int8_t)ker_val;
207                         out_buff1 += ((int8_t)(in_val >> 8) + input_offset) * (int8_t)(ker_val >> 8);
208                         out_buff2 += ((int8_t)(in_val >> 16) + input_offset) * (int8_t)(ker_val >> 16);
209                         out_buff3 += ((int8_t)(in_val >> 24) + input_offset) * (int8_t)(ker_val >> 24);
210                     }
211 
212                     input_ptr += (input_ch * input_x);
213                     kernel_ptr += (input_ch * 3);
214                 }
215 #endif
216 
217                 out_buff0 = arm_nn_requantize(out_buff0, output_mult[in_ch + 0], output_shift[in_ch + 0]);
218                 out_buff1 = arm_nn_requantize(out_buff1, output_mult[in_ch + 1], output_shift[in_ch + 1]);
219                 out_buff2 = arm_nn_requantize(out_buff2, output_mult[in_ch + 2], output_shift[in_ch + 2]);
220                 out_buff3 = arm_nn_requantize(out_buff3, output_mult[in_ch + 3], output_shift[in_ch + 3]);
221 
222                 out_buff0 += output_offset;
223                 out_buff1 += output_offset;
224                 out_buff2 += output_offset;
225                 out_buff3 += output_offset;
226 
227                 out_buff0 = MIN(MAX(out_buff0, output_activation_min), output_activation_max);
228                 out_buff1 = MIN(MAX(out_buff1, output_activation_min), output_activation_max);
229                 out_buff2 = MIN(MAX(out_buff2, output_activation_min), output_activation_max);
230                 out_buff3 = MIN(MAX(out_buff3, output_activation_min), output_activation_max);
231 
232                 output[out_idx++] = (int8_t)out_buff0;
233                 output[out_idx++] = (int8_t)out_buff1;
234                 output[out_idx++] = (int8_t)out_buff2;
235                 output[out_idx++] = (int8_t)out_buff3;
236             }
237 
238             // Leftover
239             for (; in_ch < input_ch; ++in_ch)
240             {
241                 int32_t out_buff = 0;
242                 if (bias)
243                 {
244                     out_buff = *bias++;
245                 }
246 
247                 const int8_t *input_ptr = input + (in_h + ker_h_start) * (input_ch * input_x) + in_w * input_ch + in_ch;
248                 const int8_t *kernel_ptr = kernel + ker_h_start * (input_ch * 3) + in_ch;
249 
250                 for (int32_t ker_h = ker_h_start; ker_h < MIN(3, input_y - in_h); ++ker_h)
251                 {
252                     if (ker_w_start == 0)
253                     {
254                         out_buff += (*(input_ptr) + input_offset) * *(kernel_ptr);
255                     }
256 
257                     out_buff += (*(input_ptr + input_ch) + input_offset) * *(kernel_ptr + input_ch);
258 
259                     if ((input_x - in_w) >= 3)
260                     {
261                         out_buff += (*(input_ptr + (input_ch << 1)) + input_offset) * *(kernel_ptr + (input_ch << 1));
262                     }
263 
264                     input_ptr += (input_ch * input_x);
265                     kernel_ptr += (input_ch * 3);
266                 }
267 
268                 out_buff = arm_nn_requantize(out_buff, output_mult[in_ch], output_shift[in_ch]);
269                 out_buff += output_offset;
270                 out_buff = MIN(MAX(out_buff, output_activation_min), output_activation_max);
271                 output[out_idx++] = (int8_t)out_buff;
272             }
273         }
274     }
275 
276     /* Return to application */
277     return ARM_CMSIS_NN_SUCCESS;
278 }
279 
280 /**
281  * @} end of NNConv group
282  */
283