1 /*
2  * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_get_buffer_sizes_s16.c
 * Description:  Collection of get buffer size functions for the various s16 depthwise convolution layer functions.
23  *
24  * $Date:        13 January 2023
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup NNconv
36  */
37 
38 /**
39  * @addtogroup GetBufferSizeNNConv
40  * @{
41  */
arm_depthwise_conv_fast_s16_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)42 __STATIC_INLINE int32_t arm_depthwise_conv_fast_s16_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
43                                                                         const cmsis_nn_dims *filter_dims)
44 {
45     /* The + 8 accounts for a worst case out of bounds read of the lhs buffers in the *_nt_t_* function.  */
46     return 4 * input_dims->c * filter_dims->w * filter_dims->h * sizeof(int16_t) + 8;
47 }
48 
arm_depthwise_conv_fast_s16_get_buffer_size_dsp(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)49 __STATIC_INLINE int32_t arm_depthwise_conv_fast_s16_get_buffer_size_dsp(const cmsis_nn_dims *input_dims,
50                                                                         const cmsis_nn_dims *filter_dims)
51 {
52     return input_dims->c * filter_dims->w * filter_dims->h * sizeof(int16_t);
53 }
54 
/*
 * Buffer size in bytes for the fast s16 depthwise convolution, selected at compile time:
 *   - ARM_MATH_DSP and ARM_MATH_MVEI defined: the MVE computation is used.
 *   - ARM_MATH_DSP defined without ARM_MATH_MVEI: the DSP computation is used.
 *   - ARM_MATH_DSP not defined: 0 is returned and both arguments are ignored.
 */
int32_t arm_depthwise_conv_fast_s16_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
{
#if defined(ARM_MATH_DSP) && defined(ARM_MATH_MVEI)
    return arm_depthwise_conv_fast_s16_get_buffer_size_mve(input_dims, filter_dims);
#elif defined(ARM_MATH_DSP)
    /* DSP extension available, but no MVE. */
    return arm_depthwise_conv_fast_s16_get_buffer_size_dsp(input_dims, filter_dims);
#else
    /* Neither path applies; silence unused-parameter warnings. */
    (void)input_dims;
    (void)filter_dims;
    return 0;
#endif
}
69 
/*
 * Buffer size in bytes for the s16 depthwise convolution wrapper.
 *
 * When USE_FAST_DW_CONV_S16_FUNCTION evaluates true for the given parameters, the result of
 * arm_depthwise_conv_fast_s16_get_buffer_size() is returned; otherwise 0 is returned.
 * output_dims is not used.
 */
int32_t arm_depthwise_conv_wrapper_s16_get_buffer_size(const cmsis_nn_dw_conv_params *dw_conv_params,
                                                       const cmsis_nn_dims *input_dims,
                                                       const cmsis_nn_dims *filter_dims,
                                                       const cmsis_nn_dims *output_dims)
{
    (void)output_dims;

    if (USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims))
    {
        return arm_depthwise_conv_fast_s16_get_buffer_size(input_dims, filter_dims);
    }

    /* The fast variant is not selected, so no buffer is requested. */
    return 0;
}
86 
/*
 * Buffer size in bytes for the s16 depthwise convolution wrapper, always using the MVE size
 * computation regardless of the ARM_MATH_* build configuration.
 *
 * When USE_FAST_DW_CONV_S16_FUNCTION evaluates true for the given parameters, the result of
 * arm_depthwise_conv_fast_s16_get_buffer_size_mve() is returned; otherwise 0 is returned.
 * output_dims is not used.
 */
int32_t arm_depthwise_conv_wrapper_s16_get_buffer_size_mve(const cmsis_nn_dw_conv_params *dw_conv_params,
                                                           const cmsis_nn_dims *input_dims,
                                                           const cmsis_nn_dims *filter_dims,
                                                           const cmsis_nn_dims *output_dims)
{
    (void)output_dims;

    if (USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims))
    {
        return arm_depthwise_conv_fast_s16_get_buffer_size_mve(input_dims, filter_dims);
    }

    /* The fast variant is not selected, so no buffer is requested. */
    return 0;
}
103 
/*
 * Buffer size in bytes for the s16 depthwise convolution wrapper, always using the DSP size
 * computation regardless of the ARM_MATH_* build configuration.
 *
 * When USE_FAST_DW_CONV_S16_FUNCTION evaluates true for the given parameters, the result of
 * arm_depthwise_conv_fast_s16_get_buffer_size_dsp() is returned; otherwise 0 is returned.
 * output_dims is not used.
 */
int32_t arm_depthwise_conv_wrapper_s16_get_buffer_size_dsp(const cmsis_nn_dw_conv_params *dw_conv_params,
                                                           const cmsis_nn_dims *input_dims,
                                                           const cmsis_nn_dims *filter_dims,
                                                           const cmsis_nn_dims *output_dims)
{
    (void)output_dims;

    if (USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims))
    {
        return arm_depthwise_conv_fast_s16_get_buffer_size_dsp(input_dims, filter_dims);
    }

    /* The fast variant is not selected, so no buffer is requested. */
    return 0;
}
120 
121 /**
122  * @} end of GetBufferSizeNNConv group
123  */
124