1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_transpose_conv_get_buffer_sizes_s8.c
22  * Description:  Collection of get buffer size functions for the transpose convolution layer functions.
23  *
24  * $Date:        29 October 2024
25  * $Revision:    V.2.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup NNConv
37  */
38 
39 /**
40  * @addtogroup GetBufferSizeNNConv
41  * @{
42  */
43 
44 /*
45  * Get the required buffer size for arm_transpose_conv_s8. This is the recommended transpose conv s8 get buffer size
46  * function.
47  *
48  * Refer to header file for details.
49  *
50  */
arm_transpose_conv_s8_get_buffer_size(const cmsis_nn_transpose_conv_params * transpose_conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * out_dims)51 int32_t arm_transpose_conv_s8_get_buffer_size(const cmsis_nn_transpose_conv_params *transpose_conv_params,
52                                               const cmsis_nn_dims *input_dims,
53                                               const cmsis_nn_dims *filter_dims,
54                                               const cmsis_nn_dims *out_dims)
55 {
56 
57     const bool reverse_conv_possible =
58         ((transpose_conv_params->stride.w <= 2) && (transpose_conv_params->stride.h <= 2));
59     const bool reverse_conv_efficient = (input_dims->c > REVERSE_TCOL_EFFICIENT_THRESHOLD);
60 
61     if (reverse_conv_possible && reverse_conv_efficient)
62     {
63         const cmsis_nn_dims reverse_conv_input_dims = {input_dims->n,
64                                                        input_dims->h * transpose_conv_params->stride.h,
65                                                        input_dims->w * transpose_conv_params->stride.w,
66                                                        input_dims->c};
67         return arm_convolve_s8_get_buffer_size(&reverse_conv_input_dims, filter_dims);
68     }
69     else
70     {
71         const int32_t buf_x = ((input_dims->w - 1) * transpose_conv_params->stride.w +
72                                MAX(filter_dims->w, transpose_conv_params->stride.h)) *
73             out_dims->c;
74         const int32_t buf_y = MAX(filter_dims->h, transpose_conv_params->stride.h);
75         return buf_x * buf_y * sizeof(int32_t);
76     }
77 }
78 
arm_transpose_conv_s8_get_reverse_conv_buffer_size(const cmsis_nn_transpose_conv_params * transpose_conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)79 int32_t arm_transpose_conv_s8_get_reverse_conv_buffer_size(const cmsis_nn_transpose_conv_params *transpose_conv_params,
80                                                            const cmsis_nn_dims *input_dims,
81                                                            const cmsis_nn_dims *filter_dims)
82 {
83     const bool reverse_conv_possible =
84         ((transpose_conv_params->stride.w <= 2) && (transpose_conv_params->stride.h <= 2));
85     const bool reverse_conv_efficient = (input_dims->c > REVERSE_TCOL_EFFICIENT_THRESHOLD);
86 
87     if (reverse_conv_possible && reverse_conv_efficient)
88     {
89         return input_dims->c * filter_dims->w * filter_dims->h * filter_dims->n;
90     }
91     else
92     {
93         return 0;
94     }
95 }
96 
97 /**
98  * @} end of GetBufferSizeNNConv group
99  */
100