1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_1_x_n_s8.c
22  * Description:  s8 version of 1xN convolution using symmetric quantization.
23  *
24  * $Date:        04 November 2024
25  * $Revision:    V.3.6.1
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34  *  @ingroup Public
35  */
36 
37 /**
38  * @addtogroup NNConv
39  * @{
40  */
41 
42 /*
43  * 1xN s8 convolution function.
44  *
45  * Refer header file for details.
46  *
47  */
arm_convolve_1_x_n_s8(const cmsis_nn_context * ctx,const cmsis_nn_conv_params * conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * filter_dims,const int8_t * filter_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)48 arm_cmsis_nn_status arm_convolve_1_x_n_s8(const cmsis_nn_context *ctx,
49                                           const cmsis_nn_conv_params *conv_params,
50                                           const cmsis_nn_per_channel_quant_params *quant_params,
51                                           const cmsis_nn_dims *input_dims,
52                                           const int8_t *input_data,
53                                           const cmsis_nn_dims *filter_dims,
54                                           const int8_t *filter_data,
55                                           const cmsis_nn_dims *bias_dims,
56                                           const int32_t *bias_data,
57                                           const cmsis_nn_dims *output_dims,
58                                           int8_t *output_data)
59 {
60     arm_cmsis_nn_status status = ARM_CMSIS_NN_SUCCESS;
61 
62     /* The wrapper API is the ultimate reference for argument check */
63     if ((input_dims->h != 1) || conv_params->dilation.w != 1 || ctx->buf == NULL || conv_params->stride.w == 0 ||
64         (conv_params->stride.w * input_dims->c % 4 != 0))
65     {
66         return ARM_CMSIS_NN_ARG_ERROR;
67     }
68 
69 #if defined(ARM_MATH_MVEI)
70     (void)bias_dims;
71 
72     const int32_t input_x = input_dims->w;
73     const int32_t kernel_x = filter_dims->w;
74     const int32_t output_x = output_dims->w;
75     const int32_t input_ch = input_dims->c;
76     const int32_t pad_x = conv_params->padding.w;
77     const int32_t stride_x = conv_params->stride.w;
78 
79     // Total pad for dilation of 1
80     const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
81     const int32_t asym_pad = total_pad % 2;
82 
83     if (pad_x * 2 + asym_pad != total_pad)
84     {
85         return ARM_CMSIS_NN_FAILURE;
86     }
87 
88     const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
89     const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
90     const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
91 
92     const int32_t pad_size_left = pad_x * input_ch;
93     const int32_t pad_size_right = asym_pad ? right_pad_num * input_ch : pad_size_left;
94 
95     const int32_t rhs_cols = kernel_x * input_ch;
96     const int32_t rhs_rows = output_dims->c;
97     const int32_t lhs_offset = input_ch * stride_x;
98 
99     if (right_pad_num + no_pad_num + left_pad_num != output_x)
100     {
101         return arm_convolve_s8(ctx,
102                                conv_params,
103                                quant_params,
104                                input_dims,
105                                input_data,
106                                filter_dims,
107                                filter_data,
108                                bias_dims,
109                                bias_data,
110                                NULL,
111                                output_dims,
112                                output_data);
113     }
114 
115     const uint32_t num_elem_left = kernel_x * input_ch;
116     const uint32_t num_elem_right = num_elem_left - input_ch;
117 
118     for (int i_batch = 0; i_batch < input_dims->n; i_batch++)
119     {
120         /* Handle left padded sections */
121         int32_t lhs_rows = left_pad_num;
122         int8_t *im2col = ctx->buf;
123 
124         arm_memset_s8(im2col, (int8_t)-conv_params->input_offset, sizeof(int8_t) * (uint32_t)pad_size_left);
125         im2col += pad_size_left;
126         arm_memcpy_s8(im2col, input_data, sizeof(int8_t) * num_elem_left);
127 
128         arm_nn_mat_mult_nt_t_s8((int8_t *)ctx->buf,
129                                 filter_data,
130                                 bias_data,
131                                 output_data,
132                                 quant_params->multiplier,
133                                 quant_params->shift,
134                                 lhs_rows,
135                                 rhs_rows,
136                                 rhs_cols,
137                                 conv_params->input_offset,
138                                 conv_params->output_offset,
139                                 conv_params->activation.min,
140                                 conv_params->activation.max,
141                                 rhs_rows,
142                                 lhs_offset);
143 
144         output_data += lhs_rows * rhs_rows;
145 
146         /* Non padded elements */
147         int32_t out_idx = lhs_rows;
148         int32_t input_start = stride_x * lhs_rows - pad_x;
149 
150         if (input_start < 0)
151         {
152             return ARM_CMSIS_NN_FAILURE;
153         }
154 
155         input_start *= input_ch;
156         lhs_rows = no_pad_num;
157 
158         arm_nn_mat_mult_nt_t_s8(input_data + input_start,
159                                 filter_data,
160                                 bias_data,
161                                 output_data,
162                                 quant_params->multiplier,
163                                 quant_params->shift,
164                                 lhs_rows,
165                                 rhs_rows,
166                                 rhs_cols,
167                                 conv_params->input_offset,
168                                 conv_params->output_offset,
169                                 conv_params->activation.min,
170                                 conv_params->activation.max,
171                                 rhs_rows,
172                                 lhs_offset);
173 
174         output_data += lhs_rows * rhs_rows;
175         out_idx += lhs_rows;
176 
177         /* Right padded elements */
178         lhs_rows = output_x - out_idx;
179 
180         if (lhs_rows < 0)
181         {
182             return ARM_CMSIS_NN_FAILURE;
183         }
184 
185         im2col = ctx->buf;
186         input_start = (stride_x * (left_pad_num + no_pad_num) - pad_x) * input_ch;
187 
188         arm_memcpy_s8(im2col, input_data + input_start, sizeof(int8_t) * num_elem_right);
189         im2col += num_elem_right;
190         arm_memset_s8(im2col, (int8_t)-conv_params->input_offset, sizeof(int8_t) * (uint32_t)pad_size_right);
191 
192         arm_nn_mat_mult_nt_t_s8((int8_t *)ctx->buf,
193                                 filter_data,
194                                 bias_data,
195                                 output_data,
196                                 quant_params->multiplier,
197                                 quant_params->shift,
198                                 lhs_rows,
199                                 rhs_rows,
200                                 rhs_cols,
201                                 conv_params->input_offset,
202                                 conv_params->output_offset,
203                                 conv_params->activation.min,
204                                 conv_params->activation.max,
205                                 rhs_rows,
206                                 lhs_offset);
207 
208         output_data += lhs_rows * rhs_rows;
209 
210         /* Advance to the next batch */
211         input_data += (input_x * input_ch);
212     }
213 #else
214     status = arm_convolve_s8(ctx,
215                              conv_params,
216                              quant_params,
217                              input_dims,
218                              input_data,
219                              filter_dims,
220                              filter_data,
221                              bias_dims,
222                              bias_data,
223                              NULL,
224                              output_dims,
225                              output_data);
226 
227 #endif
228 
229     /* Return to application */
230     return status;
231 }
232 
233 /**
234  * @} end of NNConv group
235  */
236