1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_get_buffer_sizes_s8.c
22  * Description:  Collection of get buffer size functions for the various s8 convolution layer functions.
23  *
24  * $Date:        31 October 2024
25  * $Revision:    V.2.2.1
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup NNConv
37  */
38 
39 /**
40  * @addtogroup GetBufferSizeNNConv
41  * @{
42  */
arm_convolve_1x1_s8_fast_get_buffer_size_dsp(const cmsis_nn_dims * input_dims)43 __STATIC_INLINE int32_t arm_convolve_1x1_s8_fast_get_buffer_size_dsp(const cmsis_nn_dims *input_dims)
44 {
45 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
46     return (2 * input_dims->c) * (int32_t)sizeof(int16_t);
47 #else
48     (void)input_dims;
49     return 0;
50 #endif
51 }
52 
arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)53 __STATIC_INLINE int32_t arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
54                                                             const cmsis_nn_dims *filter_dims)
55 {
56     int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
57     // Get number of complete lanes with int8 elements (multiple of 16) for given col_length. This is dependent on
58     // implementation of arm_nn_mat_mult_nt_t_s8
59     col_length = (col_length + 15) / 16;
60     // 4 -> number of im2col buffers, 16 -> 16 elements per Q register
61     return 4 * col_length * 16 * (int32_t)sizeof(int8_t);
62 }
63 
arm_convolve_1_x_n_s8_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)64 __STATIC_INLINE int32_t arm_convolve_1_x_n_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
65                                                                   const cmsis_nn_dims *input_dims,
66                                                                   const cmsis_nn_dims *filter_dims,
67                                                                   const cmsis_nn_dims *output_dims)
68 {
69     const int32_t input_x = input_dims->w;
70     const int32_t pad_x = conv_params->padding.w;
71     const int32_t kernel_x = filter_dims->w;
72     const int32_t output_x = output_dims->w;
73     const int32_t stride_x = conv_params->stride.w;
74     const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
75     const int32_t asym_pad = total_pad % 2;
76 
77     const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
78     const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
79     const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
80 
81     if (right_pad_num + no_pad_num + left_pad_num != output_x)
82     {
83         return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
84     }
85 
86     const int32_t pad_size_left = pad_x * input_dims->c;
87     const int32_t pad_size_right = asym_pad ? right_pad_num * input_dims->c : pad_size_left;
88     const int32_t num_elem_left = kernel_x * input_dims->c;
89     const int32_t num_elem_right = num_elem_left - input_dims->c;
90     const int32_t size_1_x_n = MAX(num_elem_left + pad_size_left, num_elem_right + pad_size_right);
91 
92     return size_1_x_n;
93 }
94 
arm_convolve_s8_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)95 int32_t arm_convolve_s8_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
96 {
97 #if defined(ARM_MATH_MVEI)
98     return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
99 #else
100     const int32_t rhs_cols = filter_dims->w * filter_dims->h * input_dims->c;
101     const int32_t remainder = rhs_cols % 4;
102     const int32_t aligned_rhs_cols = remainder != 0 ? rhs_cols + 4 - remainder : rhs_cols;
103     return (2 * aligned_rhs_cols) * (int32_t)sizeof(int16_t);
104 #endif
105 }
106 
/*
 * Scratch buffer size, in bytes, for arm_convolve_1_x_n_s8 on the current build target.
 *
 * MVE builds use the dedicated 1xN sizing. Other builds need the same buffer as arm_convolve_s8, so the
 * padding/stride parameters and output dimensions are not used there.
 */
int32_t arm_convolve_1_x_n_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_dims *input_dims,
                                              const cmsis_nn_dims *filter_dims,
                                              const cmsis_nn_dims *output_dims)
{
#if defined(ARM_MATH_MVEI)
    return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
#else
    (void)output_dims;
    (void)conv_params;

    return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
#endif
}
121 
/*
 * Scratch buffer size, in bytes, for arm_convolve_1x1_s8_fast on the current build target.
 *
 * Only DSP builds without MVE may report a non-zero size; MVE and plain-C builds always report 0.
 */
int32_t arm_convolve_1x1_s8_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
{
#if !defined(ARM_MATH_DSP) || defined(ARM_MATH_MVEI)
    (void)input_dims;
    return 0;
#else
    return arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
#endif
}
131 
132 /*
 * Get the required buffer size for arm_convolve_wrapper_s8, which is the recommended entry point for s8
 * convolution.
135  *
136  * Refer to header file for details.
137  *
138  */
arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)139 int32_t arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
140                                                 const cmsis_nn_dims *input_dims,
141                                                 const cmsis_nn_dims *filter_dims,
142                                                 const cmsis_nn_dims *output_dims)
143 {
144 #if defined(ARM_MATH_MVEI)
145     return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
146 #elif defined(ARM_MATH_DSP)
147     return arm_convolve_wrapper_s8_get_buffer_size_dsp(conv_params, input_dims, filter_dims, output_dims);
148 #else
149     (void)output_dims;
150     if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
151         (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
152     {
153         if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
154         {
155             return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
156         }
157         else
158         {
159             return 0;
160         }
161     }
162     else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
163              (conv_params->stride.w * input_dims->c % 4 == 0))
164     {
165         return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
166     }
167     else
168     {
169         return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
170     }
171 #endif
172 }
173 
arm_convolve_wrapper_s8_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)174 int32_t arm_convolve_wrapper_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
175                                                     const cmsis_nn_dims *input_dims,
176                                                     const cmsis_nn_dims *filter_dims,
177                                                     const cmsis_nn_dims *output_dims)
178 {
179     (void)output_dims;
180     if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
181         (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
182     {
183         if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
184         {
185             return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
186         }
187         else
188         {
189             return 0;
190         }
191     }
192     else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
193              (conv_params->stride.w * input_dims->c % 4 == 0))
194     {
195         return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
196     }
197     else
198     {
199         return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
200     }
201 }
202 
arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)203 int32_t arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
204                                                     const cmsis_nn_dims *input_dims,
205                                                     const cmsis_nn_dims *filter_dims,
206                                                     const cmsis_nn_dims *output_dims)
207 {
208     (void)output_dims;
209     if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
210         (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
211     {
212         if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
213         {
214             return arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
215         }
216         else
217         {
218             return 0;
219         }
220     }
221     else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
222              (conv_params->stride.w * input_dims->c % 4 == 0))
223     {
224         return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
225     }
226     else
227     {
228         return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
229     }
230 }
231 
232 /**
233  * @} end of GetBufferSizeNNConv group
234  */
235