1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_convolve_get_buffer_sizes_s4.c
22 * Description: Collection of get buffer size functions for the various s4 convolution layer functions.
23 *
24 * $Date: 10 April 2024
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34
35 /**
36 * @ingroup NNConv
37 */
38
39 /**
40 * @addtogroup GetBufferSizeNNConv
41 * @{
42 */
43
arm_convolve_s4_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)44 __STATIC_INLINE int32_t arm_convolve_s4_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
45 const cmsis_nn_dims *filter_dims)
46 {
47 int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
48 // Get number of complete lanes with int8 elements (multiple of 16) for given col_length. This is dependent on
49 // implementation of arm_nn_mat_mult_nt_t_s4
50 col_length = (col_length + 15) / 16;
51 // 4 -> number of im2col buffers, 16 -> 16 elements per Q register
52 return 4 * col_length * 16 * (int32_t)sizeof(int8_t);
53 }
54
arm_convolve_1_x_n_s4_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)55 __STATIC_INLINE int32_t arm_convolve_1_x_n_s4_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
56 const cmsis_nn_dims *input_dims,
57 const cmsis_nn_dims *filter_dims,
58 const cmsis_nn_dims *output_dims)
59 {
60 const int32_t input_x = input_dims->w;
61 const int32_t pad_x = conv_params->padding.w;
62 const int32_t kernel_x = filter_dims->w;
63 const int32_t output_x = output_dims->w;
64 const int32_t stride_x = conv_params->stride.w;
65 const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
66 const int32_t asym_pad = total_pad % 2;
67
68 const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
69 const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
70 const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
71
72 if (right_pad_num + no_pad_num + left_pad_num != output_x)
73 {
74 return arm_convolve_s4_get_buffer_size_mve(input_dims, filter_dims);
75 }
76
77 return 0;
78 }
79
arm_convolve_s4_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)80 int32_t arm_convolve_s4_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
81 {
82 const int32_t rhs_cols = filter_dims->w * filter_dims->h * input_dims->c;
83 return (2 * rhs_cols) * (int32_t)sizeof(int16_t);
84 }
85
/**
 * @brief Scratch buffer size, in bytes, for the 1xN s4 convolution.
 *
 * With ARM_MATH_MVEI defined, the result comes from
 * arm_convolve_1_x_n_s4_get_buffer_size_mve(). Otherwise it comes from
 * arm_convolve_s4_get_buffer_size(), and conv_params and output_dims are not used.
 *
 * @param[in] conv_params  Convolution parameters (read only in the MVE build).
 * @param[in] input_dims   Input dimensions.
 * @param[in] filter_dims  Filter dimensions.
 * @param[in] output_dims  Output dimensions (read only in the MVE build).
 *
 * @return Buffer size in bytes.
 */
int32_t arm_convolve_1_x_n_s4_get_buffer_size(const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_dims *input_dims,
                                              const cmsis_nn_dims *filter_dims,
                                              const cmsis_nn_dims *output_dims)
{
#if defined(ARM_MATH_MVEI)
    return arm_convolve_1_x_n_s4_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
#else
    // Not needed by the generic sizing; cast to silence unused-parameter warnings.
    (void)conv_params;
    (void)output_dims;

    return arm_convolve_s4_get_buffer_size(input_dims, filter_dims);
#endif
}
100
/**
 * @brief Scratch buffer size, in bytes, for the fast 1x1 s4 convolution.
 *
 * @param[in] input_dims  Accepted for interface uniformity; never read.
 *
 * @return Always 0.
 */
int32_t arm_convolve_1x1_s4_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
{
    // The argument is unused; the cast silences unused-parameter warnings.
    (void)input_dims;

    const int32_t scratch_bytes = 0;
    return scratch_bytes;
}
106
107 /*
108 * Get the required buffer size for arm_convolve_wrapper_s4. This is the
109 * recommended convolve wrapper s4 function.
110 *
111 * Refer to header file for details.
112 *
113 */
arm_convolve_wrapper_s4_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)114 int32_t arm_convolve_wrapper_s4_get_buffer_size(const cmsis_nn_conv_params *conv_params,
115 const cmsis_nn_dims *input_dims,
116 const cmsis_nn_dims *filter_dims,
117 const cmsis_nn_dims *output_dims)
118 {
119 #if defined(ARM_MATH_MVEI)
120 return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
121 #else
122 (void)output_dims;
123 if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
124 (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
125 {
126 if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
127 {
128 return arm_convolve_1x1_s4_fast_get_buffer_size(input_dims);
129 }
130 else
131 {
132 return 0;
133 }
134 }
135 else
136 {
137 return arm_convolve_s4_get_buffer_size(input_dims, filter_dims);
138 }
139 #endif
140 }
141
arm_convolve_wrapper_s4_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)142 int32_t arm_convolve_wrapper_s4_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
143 const cmsis_nn_dims *input_dims,
144 const cmsis_nn_dims *filter_dims,
145 const cmsis_nn_dims *output_dims)
146
147 {
148 (void)output_dims;
149 if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
150 (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
151 {
152 if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
153 {
154 return arm_convolve_1x1_s4_fast_get_buffer_size(input_dims);
155 }
156 else
157 {
158 return 0;
159 }
160 }
161 else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
162 (conv_params->stride.w * input_dims->c % 4 == 0))
163 {
164 return arm_convolve_1_x_n_s4_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
165 }
166 else
167 {
168 return arm_convolve_s4_get_buffer_size_mve(input_dims, filter_dims);
169 }
170 }
171
/**
 * @brief Buffer size, in bytes, for the s4 convolve wrapper (DSP-named entry point).
 *
 * Forwards all arguments unchanged to arm_convolve_wrapper_s4_get_buffer_size()
 * and returns its result.
 *
 * @param[in] conv_params  Convolution parameters.
 * @param[in] input_dims   Input dimensions.
 * @param[in] filter_dims  Filter dimensions.
 * @param[in] output_dims  Output dimensions.
 *
 * @return Buffer size in bytes as reported by arm_convolve_wrapper_s4_get_buffer_size().
 */
int32_t arm_convolve_wrapper_s4_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
                                                    const cmsis_nn_dims *input_dims,
                                                    const cmsis_nn_dims *filter_dims,
                                                    const cmsis_nn_dims *output_dims)
{
    const int32_t buffer_bytes =
        arm_convolve_wrapper_s4_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);

    return buffer_bytes;
}
179 /**
180 * @} end of GetBufferSizeNNConv group
181 */
182