1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_get_buffer_sizes_s4.c
22  * Description:  Collection of get buffer size functions for the various s4 convolution layer functions.
23  *
24  * $Date:        10 April 2024
25  * $Revision:    V.1.1.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup NNConv
37  */
38 
39 /**
40  * @addtogroup GetBufferSizeNNConv
41  * @{
42  */
43 
arm_convolve_s4_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)44 __STATIC_INLINE int32_t arm_convolve_s4_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
45                                                             const cmsis_nn_dims *filter_dims)
46 {
47     int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
48     // Get number of complete lanes with int8 elements (multiple of 16) for given col_length. This is dependent on
49     // implementation of arm_nn_mat_mult_nt_t_s4
50     col_length = (col_length + 15) / 16;
51     // 4 -> number of im2col buffers, 16 -> 16 elements per Q register
52     return 4 * col_length * 16 * (int32_t)sizeof(int8_t);
53 }
54 
arm_convolve_1_x_n_s4_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)55 __STATIC_INLINE int32_t arm_convolve_1_x_n_s4_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
56                                                                   const cmsis_nn_dims *input_dims,
57                                                                   const cmsis_nn_dims *filter_dims,
58                                                                   const cmsis_nn_dims *output_dims)
59 {
60     const int32_t input_x = input_dims->w;
61     const int32_t pad_x = conv_params->padding.w;
62     const int32_t kernel_x = filter_dims->w;
63     const int32_t output_x = output_dims->w;
64     const int32_t stride_x = conv_params->stride.w;
65     const int32_t total_pad = ((output_x - 1) * stride_x + kernel_x - input_x);
66     const int32_t asym_pad = total_pad % 2;
67 
68     const int32_t right_pad_num = pad_x + asym_pad != 0 ? MAX(1, (pad_x + asym_pad + stride_x - 1) / stride_x) : 0;
69     const int32_t left_pad_num = pad_x != 0 ? MAX(1, (pad_x + stride_x - 1) / stride_x) : 0;
70     const int32_t no_pad_num = MAX(output_x - (right_pad_num + left_pad_num), 0);
71 
72     if (right_pad_num + no_pad_num + left_pad_num != output_x)
73     {
74         return arm_convolve_s4_get_buffer_size_mve(input_dims, filter_dims);
75     }
76 
77     return 0;
78 }
79 
arm_convolve_s4_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)80 int32_t arm_convolve_s4_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
81 {
82     const int32_t rhs_cols = filter_dims->w * filter_dims->h * input_dims->c;
83     return (2 * rhs_cols) * (int32_t)sizeof(int16_t);
84 }
85 
/*
 * Buffer size in bytes for the 1xN s4 convolution.
 *
 * With ARM_MATH_MVEI defined the MVE-specific calculation is used. Otherwise
 * the result is that of arm_convolve_s4_get_buffer_size, which only looks at
 * the input and filter dimensions.
 */
int32_t arm_convolve_1_x_n_s4_get_buffer_size(const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_dims *input_dims,
                                              const cmsis_nn_dims *filter_dims,
                                              const cmsis_nn_dims *output_dims)
{
#if defined(ARM_MATH_MVEI)
    return arm_convolve_1_x_n_s4_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
#else
    // Neither the convolution parameters nor the output dimensions are used in this configuration.
    (void)output_dims;
    (void)conv_params;

    return arm_convolve_s4_get_buffer_size(input_dims, filter_dims);
#endif
}
100 
/*
 * Buffer size in bytes for the fast 1x1 s4 convolution. Always 0; the input
 * dimensions are accepted but not used.
 */
int32_t arm_convolve_1x1_s4_fast_get_buffer_size(const cmsis_nn_dims *input_dims)
{
    const int32_t buffer_size = 0;

    (void)input_dims;

    return buffer_size;
}
106 
107 /*
108  * Get the required buffer size for arm_convolve_wrapper_s4. This is the
109  * recommended convolve wrapper s4 function.
110  *
111  * Refer to header file for details.
112  *
113  */
arm_convolve_wrapper_s4_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)114 int32_t arm_convolve_wrapper_s4_get_buffer_size(const cmsis_nn_conv_params *conv_params,
115                                                 const cmsis_nn_dims *input_dims,
116                                                 const cmsis_nn_dims *filter_dims,
117                                                 const cmsis_nn_dims *output_dims)
118 {
119 #if defined(ARM_MATH_MVEI)
120     return arm_convolve_wrapper_s8_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
121 #else
122     (void)output_dims;
123     if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
124         (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
125     {
126         if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
127         {
128             return arm_convolve_1x1_s4_fast_get_buffer_size(input_dims);
129         }
130         else
131         {
132             return 0;
133         }
134     }
135     else
136     {
137         return arm_convolve_s4_get_buffer_size(input_dims, filter_dims);
138     }
139 #endif
140 }
141 
arm_convolve_wrapper_s4_get_buffer_size_mve(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)142 int32_t arm_convolve_wrapper_s4_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
143                                                     const cmsis_nn_dims *input_dims,
144                                                     const cmsis_nn_dims *filter_dims,
145                                                     const cmsis_nn_dims *output_dims)
146 
147 {
148     (void)output_dims;
149     if ((conv_params->padding.w == 0) && (conv_params->padding.h == 0) && (filter_dims->w == 1) &&
150         (filter_dims->h == 1) && (conv_params->dilation.w == 1 && conv_params->dilation.h == 1))
151     {
152         if ((conv_params->stride.w == 1) && (conv_params->stride.h == 1))
153         {
154             return arm_convolve_1x1_s4_fast_get_buffer_size(input_dims);
155         }
156         else
157         {
158             return 0;
159         }
160     }
161     else if ((input_dims->h == 1) && (conv_params->dilation.w == 1) && (filter_dims->h == 1) &&
162              (conv_params->stride.w * input_dims->c % 4 == 0))
163     {
164         return arm_convolve_1_x_n_s4_get_buffer_size_mve(conv_params, input_dims, filter_dims, output_dims);
165     }
166     else
167     {
168         return arm_convolve_s4_get_buffer_size_mve(input_dims, filter_dims);
169     }
170 }
171 
/*
 * Buffer size in bytes for arm_convolve_wrapper_s4, DSP-named variant.
 * Forwards all arguments unchanged to arm_convolve_wrapper_s4_get_buffer_size.
 */
int32_t arm_convolve_wrapper_s4_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
                                                    const cmsis_nn_dims *input_dims,
                                                    const cmsis_nn_dims *filter_dims,
                                                    const cmsis_nn_dims *output_dims)
{
    const int32_t buffer_size =
        arm_convolve_wrapper_s4_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);

    return buffer_size;
}
179 /**
180  * @} end of GetBufferSizeNNConv group
181  */
182