1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_1_x_n_s8.c
22  * Description:  s8 version of 1xN convolution using symmetric quantization.
23  *
24  * $Date:        19 March 2024
25  * $Revision:    V.3.6.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34  *  @ingroup Public
35  */
36 
37 /**
38  * @addtogroup NNConv
39  * @{
40  */
41 
42 /*
43  * 1xN s8 convolution function.
44  *
45  * Refer header file for details.
46  *
47  */
arm_convolve_1_x_n_s8(const cmsis_nn_context * ctx,const cmsis_nn_conv_params * conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * filter_dims,const int8_t * filter_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)48 arm_cmsis_nn_status arm_convolve_1_x_n_s8(const cmsis_nn_context *ctx,
49                                           const cmsis_nn_conv_params *conv_params,
50                                           const cmsis_nn_per_channel_quant_params *quant_params,
51                                           const cmsis_nn_dims *input_dims,
52                                           const int8_t *input_data,
53                                           const cmsis_nn_dims *filter_dims,
54                                           const int8_t *filter_data,
55                                           const cmsis_nn_dims *bias_dims,
56                                           const int32_t *bias_data,
57                                           const cmsis_nn_dims *output_dims,
58                                           int8_t *output_data)
59 {
60     arm_cmsis_nn_status status = ARM_CMSIS_NN_SUCCESS;
61 
62     /* The wrapper API is the ultimate reference for argument check */
63     if ((input_dims->h != 1) || conv_params->dilation.w != 1 || ctx->buf == NULL || conv_params->stride.w == 0 ||
64         (conv_params->stride.w * input_dims->c % 4 != 0))
65     {
66         return ARM_CMSIS_NN_ARG_ERROR;
67     }
68 
69 #if defined(ARM_MATH_MVEI)
70     (void)bias_dims;
71 
72     const int32_t input_x = input_dims->w;
73     const int32_t kernel_x = filter_dims->w;
74     const int32_t output_x = output_dims->w;
75     const int32_t input_ch = input_dims->c;
76     const int32_t pad_x = conv_params->padding.w;
77     const int32_t stride_x = conv_params->stride.w;
78 
79     // Total pad for dilation of 1
80     const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
81     const int32_t asym_pad = total_pad % 2;
82 
83     if (pad_x * 2 + asym_pad != total_pad)
84     {
85         return ARM_CMSIS_NN_FAILURE;
86     }
87 
88     const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
89     const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
90     const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
91 
92     const int32_t pad_size_left = pad_x * input_ch;
93     const int32_t pad_size_right = asym_pad ? right_pad_num * input_ch : pad_size_left;
94 
95     const int32_t rhs_cols = kernel_x * input_ch;
96     const int32_t rhs_rows = output_dims->c;
97     const int32_t lhs_offset = input_ch * stride_x;
98 
99     if (right_pad_num + no_pad_num + left_pad_num != output_x)
100     {
101         return arm_convolve_s8(ctx,
102                                conv_params,
103                                quant_params,
104                                input_dims,
105                                input_data,
106                                filter_dims,
107                                filter_data,
108                                bias_dims,
109                                bias_data,
110                                output_dims,
111                                output_data);
112     }
113 
114     const uint32_t num_elem_left = kernel_x * input_ch;
115     const uint32_t num_elem_right = num_elem_left - input_ch;
116 
117     for (int i_batch = 0; i_batch < input_dims->n; i_batch++)
118     {
119         /* Handle left padded sections */
120         int32_t lhs_rows = left_pad_num;
121         int8_t *im2col = ctx->buf;
122 
123         arm_memset_s8(im2col, (int8_t)-conv_params->input_offset, sizeof(int8_t) * (uint32_t)pad_size_left);
124         im2col += pad_size_left;
125         arm_memcpy_s8(im2col, input_data, sizeof(int8_t) * num_elem_left);
126 
127         arm_nn_mat_mult_nt_t_s8((int8_t *)ctx->buf,
128                                 filter_data,
129                                 bias_data,
130                                 output_data,
131                                 quant_params->multiplier,
132                                 quant_params->shift,
133                                 lhs_rows,
134                                 rhs_rows,
135                                 rhs_cols,
136                                 conv_params->input_offset,
137                                 conv_params->output_offset,
138                                 conv_params->activation.min,
139                                 conv_params->activation.max,
140                                 rhs_rows,
141                                 lhs_offset);
142 
143         output_data += lhs_rows * rhs_rows;
144 
145         /* Non padded elements */
146         int32_t out_idx = lhs_rows;
147         int32_t input_start = stride_x * lhs_rows - pad_x;
148 
149         if (input_start < 0)
150         {
151             return ARM_CMSIS_NN_FAILURE;
152         }
153 
154         input_start *= input_ch;
155         lhs_rows = no_pad_num;
156 
157         arm_nn_mat_mult_nt_t_s8(input_data + input_start,
158                                 filter_data,
159                                 bias_data,
160                                 output_data,
161                                 quant_params->multiplier,
162                                 quant_params->shift,
163                                 lhs_rows,
164                                 rhs_rows,
165                                 rhs_cols,
166                                 conv_params->input_offset,
167                                 conv_params->output_offset,
168                                 conv_params->activation.min,
169                                 conv_params->activation.max,
170                                 rhs_rows,
171                                 lhs_offset);
172 
173         output_data += lhs_rows * rhs_rows;
174         out_idx += lhs_rows;
175 
176         /* Right padded elements */
177         lhs_rows = output_x - out_idx;
178 
179         if (lhs_rows < 0)
180         {
181             return ARM_CMSIS_NN_FAILURE;
182         }
183 
184         im2col = ctx->buf;
185         input_start = (stride_x * (left_pad_num + no_pad_num) - pad_x) * input_ch;
186 
187         arm_memcpy_s8(im2col, input_data + input_start, sizeof(int8_t) * num_elem_right);
188         im2col += num_elem_right;
189         arm_memset_s8(im2col, (int8_t)-conv_params->input_offset, sizeof(int8_t) * (uint32_t)pad_size_right);
190 
191         arm_nn_mat_mult_nt_t_s8((int8_t *)ctx->buf,
192                                 filter_data,
193                                 bias_data,
194                                 output_data,
195                                 quant_params->multiplier,
196                                 quant_params->shift,
197                                 lhs_rows,
198                                 rhs_rows,
199                                 rhs_cols,
200                                 conv_params->input_offset,
201                                 conv_params->output_offset,
202                                 conv_params->activation.min,
203                                 conv_params->activation.max,
204                                 rhs_rows,
205                                 lhs_offset);
206 
207         output_data += lhs_rows * rhs_rows;
208 
209         /* Advance to the next batch */
210         input_data += (input_x * input_ch);
211     }
212 #else
213     status = arm_convolve_s8(ctx,
214                              conv_params,
215                              quant_params,
216                              input_dims,
217                              input_data,
218                              filter_dims,
219                              filter_data,
220                              bias_dims,
221                              bias_data,
222                              output_dims,
223                              output_data);
224 
225 #endif
226 
227     /* Return to application */
228     return status;
229 }
230 
231 /**
232  * @} end of NNConv group
233  */
234