1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_get_buffer_sizes_s8.c
22  * Description:  Collection of get buffer size functions for the s8 depthwise convolution layer functions.
23  *
24  * $Date:        17 April 2024
25  * $Revision:    V.1.2.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup NNConv
36  */
37 
38 /**
39  * @addtogroup GetBufferSizeNNConv
40  * @{
41  */
42 
arm_depthwise_conv_s8_opt_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)43 int32_t arm_depthwise_conv_s8_opt_get_buffer_size_mve(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
44 {
45     (void)input_dims;
46     return (4 * CH_IN_BLOCK_MVE * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int8_t);
47 }
48 
arm_depthwise_conv_s8_opt_get_buffer_size_dsp(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)49 int32_t arm_depthwise_conv_s8_opt_get_buffer_size_dsp(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
50 {
51     return (input_dims->c * filter_dims->w * filter_dims->h) * sizeof(int16_t);
52 }
53 
/**
 * @brief Buffer size, in bytes, for the optimized s8 depthwise convolution on the current build target.
 *
 * Forwards to the MVE variant when ARM_MATH_MVEI is defined. Otherwise forwards to the DSP variant
 * when ARM_MATH_DSP is defined, and returns 0 when neither is defined.
 *
 * @param[in] input_dims  Input tensor dimensions, passed through to the selected variant.
 * @param[in] filter_dims Filter tensor dimensions, passed through to the selected variant.
 * @return Buffer size in bytes from the selected variant, or 0.
 */
int32_t arm_depthwise_conv_s8_opt_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
{
    int32_t buffer_size = 0;

#if defined(ARM_MATH_MVEI)
    buffer_size = arm_depthwise_conv_s8_opt_get_buffer_size_mve(input_dims, filter_dims);
#elif defined(ARM_MATH_DSP)
    buffer_size = arm_depthwise_conv_s8_opt_get_buffer_size_dsp(input_dims, filter_dims);
#else
    /* No variant is selected; silence unused-parameter warnings. */
    (void)input_dims;
    (void)filter_dims;
#endif

    return buffer_size;
}
66 
arm_depthwise_conv_wrapper_s8_get_buffer_size(const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)67 int32_t arm_depthwise_conv_wrapper_s8_get_buffer_size(const cmsis_nn_dw_conv_params *dw_conv_params,
68                                                       const cmsis_nn_dims *input_dims,
69                                                       const cmsis_nn_dims *filter_dims,
70                                                       const cmsis_nn_dims *output_dims)
71 {
72     int32_t size = 0;
73 
74     if (input_dims->c == output_dims->c && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
75         dw_conv_params->dilation.h == 1)
76     {
77 #if !defined(ARM_MATH_MVEI)
78         if (filter_dims->w == 3 && filter_dims->h == 3 && dw_conv_params->padding.h <= 1 &&
79             dw_conv_params->padding.w <= 1)
80         {
81             return size;
82         }
83 #endif
84         size = arm_depthwise_conv_s8_opt_get_buffer_size(input_dims, filter_dims);
85     }
86 
87     return size;
88 }
89 
arm_depthwise_conv_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)90 int32_t arm_depthwise_conv_wrapper_s8_get_buffer_size_dsp(const cmsis_nn_dw_conv_params *dw_conv_params,
91                                                           const cmsis_nn_dims *input_dims,
92                                                           const cmsis_nn_dims *filter_dims,
93                                                           const cmsis_nn_dims *output_dims)
94 {
95     int32_t size = 0;
96 
97     if (input_dims->c == output_dims->c && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
98         dw_conv_params->dilation.h == 1)
99     {
100         if (filter_dims->w == 3 && filter_dims->h == 3 && dw_conv_params->padding.h <= 1 &&
101             dw_conv_params->padding.w <= 1)
102         {
103             return size;
104         }
105         size = arm_depthwise_conv_s8_opt_get_buffer_size_dsp(input_dims, filter_dims);
106     }
107 
108     return size;
109 }
110 
arm_depthwise_conv_wrapper_s8_get_buffer_size_mve(const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)111 int32_t arm_depthwise_conv_wrapper_s8_get_buffer_size_mve(const cmsis_nn_dw_conv_params *dw_conv_params,
112                                                           const cmsis_nn_dims *input_dims,
113                                                           const cmsis_nn_dims *filter_dims,
114                                                           const cmsis_nn_dims *output_dims)
115 {
116     int32_t size = 0;
117 
118     if (input_dims->c == output_dims->c && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
119         dw_conv_params->dilation.h == 1)
120     {
121         size = arm_depthwise_conv_s8_opt_get_buffer_size_mve(input_dims, filter_dims);
122     }
123 
124     return size;
125 }
126 
127 /**
128  * @} end of GetBufferSizeNNConv group
129  */
130