1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_convolve_get_buffer_sizes_s16.c
22  * Description:  Collection of get buffer size functions for the various s16 convolution layer functions.
23  *
24  * $Date:        20 March 2024
25  * $Revision:    V.2.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33 
34 /**
35  *  @ingroup NNConv
36  */
37 
38 /**
39  * @addtogroup GetBufferSizeNNConv
40  * @{
41  */
42 
arm_convolve_s16_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)43 __STATIC_INLINE int32_t arm_convolve_s16_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
44                                                              const cmsis_nn_dims *filter_dims)
45 {
46     int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
47     // Get number of complete lanes with int16 elements (multiple of 8) for given col_length. This is dependent on
48     // implementation of arm_nn_mat_mult_nt_t_s16
49     col_length = (col_length + 7) / 8;
50     // 4 -> number of im2col buffers, 8 -> 8 elements per Q register
51     return 4 * col_length * 8 * (int32_t)sizeof(int16_t);
52 }
53 
arm_convolve_s16_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)54 int32_t arm_convolve_s16_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
55 {
56 #if defined(ARM_MATH_MVEI)
57     return arm_convolve_s16_get_buffer_size_mve(input_dims, filter_dims);
58 #else
59     return (2 * input_dims->c * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int16_t);
60 #endif
61 }
62 
63 /*
64  * Get the required buffer size for arm_convolve_wrapper_s16. This is the recommended function convolve wrapper s16
65  * function.
66  *
67  * Refer to header file for details.
68  *
69  */
arm_convolve_wrapper_s16_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)70 int32_t arm_convolve_wrapper_s16_get_buffer_size(const cmsis_nn_conv_params *conv_params,
71                                                  const cmsis_nn_dims *input_dims,
72                                                  const cmsis_nn_dims *filter_dims,
73                                                  const cmsis_nn_dims *output_dims)
74 {
75     (void)conv_params;
76     (void)output_dims;
77 
78     return arm_convolve_s16_get_buffer_size(input_dims, filter_dims);
79 }
80 
/*
 * DSP-named variant of the s16 convolve wrapper buffer size query.
 *
 * Forwards all arguments unchanged to arm_convolve_wrapper_s16_get_buffer_size
 * and returns its result.
 */
int32_t arm_convolve_wrapper_s16_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
                                                     const cmsis_nn_dims *input_dims,
                                                     const cmsis_nn_dims *filter_dims,
                                                     const cmsis_nn_dims *output_dims)
{
    const int32_t buffer_bytes =
        arm_convolve_wrapper_s16_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);

    return buffer_bytes;
}
88 
/*
 * MVE-named variant of the s16 convolve wrapper buffer size query.
 *
 * Always uses the MVE size calculation, independent of whether ARM_MATH_MVEI
 * is defined for the current build.
 */
int32_t arm_convolve_wrapper_s16_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
                                                     const cmsis_nn_dims *input_dims,
                                                     const cmsis_nn_dims *filter_dims,
                                                     const cmsis_nn_dims *output_dims)
{
    // Only the input and filter dimensions are used for the size; the remaining
    // arguments are accepted but deliberately ignored.
    (void)output_dims;
    (void)conv_params;

    const int32_t buffer_bytes = arm_convolve_s16_get_buffer_size_mve(input_dims, filter_dims);
    return buffer_bytes;
}
99 
100 /**
101  * @} end of GetBufferSizeNNConv group
102  */
103