1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_convolve_get_buffer_sizes_s8.c
22 * Description: Collection of get buffer size functions for the various s8 convolution layer functions.
23 *
24 * $Date: 31 October 2024
25 * $Revision: V.2.2.1
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34
35 /**
36 * @ingroup NNConv
37 */
38
39 /**
40 * @addtogroup GetBufferSizeNNConv
41 * @{
42 */
arm_convolve_1x1_s8_fast_get_buffer_size_dsp(const cmsis_nn_dims * input_dims)43 __STATIC_INLINE int32_t arm_convolve_1x1_s8_fast_get_buffer_size_dsp(const cmsis_nn_dims *input_dims)
44 {
45 #if defined(__ARMCC_VERSION) && (__ARMCC_VERSION >= 6010050)
46 return (2 * input_dims->c) * (int32_t)sizeof(int16_t);
47 #else
48 (void)input_dims;
49 return 0;
50 #endif
51 }
52
arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)53 __STATIC_INLINE int32_t arm_convolve_s8_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
54 const cmsis_nn_dims *filter_dims)
55 {
56 int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
57 // Get number of complete lanes with int8 elements (multiple of 16) for given col_length. This is dependent on
58 // implementation of arm_nn_mat_mult_nt_t_s8
59 col_length = (col_length + 15) / 16;
60 // 4 -> number of im2col buffers, 16 -> 16 elements per Q register
61 return 4 * col_length * 16 * (int32_t)sizeof(int8_t);
62 }
63
arm_convolve_1_x_n_s8_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)64 __STATIC_INLINE int32_t arm_convolve_1_x_n_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
65 const cmsis_nn_dims *input_dims,
66 const cmsis_nn_dims *filter_dims,
67 const cmsis_nn_dims *output_dims)
68 {
69 const int32_t input_x = input_dims->w;
70 const int32_t pad_x = conv_params->padding.w;
71 const int32_t kernel_x = filter_dims->w;
72 const int32_t output_x = output_dims->w;
73 const int32_t stride_x = conv_params->stride.w;
74 const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
75 const int32_t asym_pad = total_pad % 2;
76
77 const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
78 const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
79 const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
80
81 if (right_pad_num + no_pad_num + left_pad_num != output_x)
82 {
83 return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
84 }
85
86 const int32_t pad_size_left = pad_x * input_dims->c;
87 const int32_t pad_size_right = asym_pad ? right_pad_num * input_dims->c : pad_size_left;
88 const int32_t num_elem_left = kernel_x * input_dims->c;
89 const int32_t num_elem_right = num_elem_left - input_dims->c;
90 const int32_t size_1_x_n = MAX(num_elem_left + pad_size_left, num_elem_right + pad_size_right);
91
92 return size_1_x_n;
93 }
94
arm_convolve_s8_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)95 int32_t arm_convolve_s8_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
96 {
97 #if defined(ARM_MATH_MVEI)
98 return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
99 #else
100 const int32_t rhs_cols = filter_dims->w * filter_dims->h * input_dims->c;
101 const int32_t remainder = rhs_cols % 4;
102 const int32_t aligned_rhs_cols = remainder != 0 ? rhs_cols + 4 - remainder : rhs_cols;
103 return (2 * aligned_rhs_cols) * (int32_t)sizeof(int16_t);
104 #endif
105 }
106
int32_t arm_convolve_1_x_n_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_dims *input_dims,
                                              const cmsis_nn_dims *filter_dims,
                                              const cmsis_nn_dims *output_dims)
{
#if defined(ARM_MATH_MVEI)
    return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
#else
    /* Without MVE the size equals that of arm_convolve_s8_get_buffer_size(); padding/output are irrelevant. */
    (void)conv_params;
    (void)output_dims;
    return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
#endif
}
121
int32_t arm_convolve_1x1_s8_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
{
#if !defined(ARM_MATH_DSP) || defined(ARM_MATH_MVEI)
    /* Pure C and MVE builds need no scratch for the 1x1 fast kernel. */
    (void)input_dims;
    return 0;
#else
    return arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
#endif
}
131
132 /*
 * Get the required buffer size for arm_convolve_wrapper_s8, which is the recommended entry point for s8
 * convolution.
135 *
136 * Refer to header file for details.
137 *
138 */
arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)139 int32_t arm_convolve_wrapper_s8_get_buffer_size(const cmsis_nn_conv_params *conv_params,
140 const cmsis_nn_dims *input_dims,
141 const cmsis_nn_dims *filter_dims,
142 const cmsis_nn_dims *output_dims)
143 {
144 #if defined(ARM_MATH_MVEI)
145 return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
146 #elif defined(ARM_MATH_DSP)
147 return arm_convolve_wrapper_s8_get_buffer_size_dsp(conv_params, input_dims, filter_dims, output_dims);
148 #else
149 (void)output_dims;
150 if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
151 (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
152 {
153 if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
154 {
155 return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
156 }
157 else
158 {
159 return 0;
160 }
161 }
162 else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
163 (conv_params->stride.w * input_dims->c % 4 == 0))
164 {
165 return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
166 }
167 else
168 {
169 return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
170 }
171 #endif
172 }
173
arm_convolve_wrapper_s8_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)174 int32_t arm_convolve_wrapper_s8_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
175 const cmsis_nn_dims *input_dims,
176 const cmsis_nn_dims *filter_dims,
177 const cmsis_nn_dims *output_dims)
178 {
179 (void)output_dims;
180 if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
181 (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
182 {
183 if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
184 {
185 return arm_convolve_1x1_s8_fast_get_buffer_size(input_dims);
186 }
187 else
188 {
189 return 0;
190 }
191 }
192 else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
193 (conv_params->stride.w * input_dims->c % 4 == 0))
194 {
195 return arm_convolve_1_x_n_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
196 }
197 else
198 {
199 return arm_convolve_s8_get_buffer_size_mve(input_dims, filter_dims);
200 }
201 }
202
arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)203 int32_t arm_convolve_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
204 const cmsis_nn_dims *input_dims,
205 const cmsis_nn_dims *filter_dims,
206 const cmsis_nn_dims *output_dims)
207 {
208 (void)output_dims;
209 if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
210 (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
211 {
212 if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
213 {
214 return arm_convolve_1x1_s8_fast_get_buffer_size_dsp(input_dims);
215 }
216 else
217 {
218 return 0;
219 }
220 }
221 else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
222 (conv_params->stride.w * input_dims->c % 4 == 0))
223 {
224 return arm_convolve_1_x_n_s8_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
225 }
226 else
227 {
228 return arm_convolve_s8_get_buffer_size(input_dims, filter_dims);
229 }
230 }
231
232 /**
233 * @} end of GetBufferSizeNNConv group
234 */
235