1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_convolve_get_buffer_sizes_s16.c
22 * Description: Collection of get buffer size functions for the various s16 convolution layer functions.
23 *
24 * $Date: 20 March 2024
25 * $Revision: V.2.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "Internal/arm_nn_compiler.h"
32 #include "arm_nnfunctions.h"
33
34 /**
35 * @ingroup NNConv
36 */
37
38 /**
39 * @addtogroup GetBufferSizeNNConv
40 * @{
41 */
42
arm_convolve_s16_get_buffer_size_mve(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)43 __STATIC_INLINE int32_t arm_convolve_s16_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
44 const cmsis_nn_dims *filter_dims)
45 {
46 int32_t col_length = input_dims->c * filter_dims->w * filter_dims->h;
47 // Get number of complete lanes with int16 elements (multiple of 8) for given col_length. This is dependent on
48 // implementation of arm_nn_mat_mult_nt_t_s16
49 col_length = (col_length + 7) / 8;
50 // 4 -> number of im2col buffers, 8 -> 8 elements per Q register
51 return 4 * col_length * 8 * (int32_t)sizeof(int16_t);
52 }
53
arm_convolve_s16_get_buffer_size(const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims)54 int32_t arm_convolve_s16_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
55 {
56 #if defined(ARM_MATH_MVEI)
57 return arm_convolve_s16_get_buffer_size_mve(input_dims, filter_dims);
58 #else
59 return (2 * input_dims->c * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int16_t);
60 #endif
61 }
62
63 /*
64 * Get the required buffer size for arm_convolve_wrapper_s16. This is the recommended function convolve wrapper s16
65 * function.
66 *
67 * Refer to header file for details.
68 *
69 */
arm_convolve_wrapper_s16_get_buffer_size(const cmsis_nn_conv_params * conv_params,const cmsis_nn_dims * input_dims,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims)70 int32_t arm_convolve_wrapper_s16_get_buffer_size(const cmsis_nn_conv_params *conv_params,
71 const cmsis_nn_dims *input_dims,
72 const cmsis_nn_dims *filter_dims,
73 const cmsis_nn_dims *output_dims)
74 {
75 (void)conv_params;
76 (void)output_dims;
77
78 return arm_convolve_s16_get_buffer_size(input_dims, filter_dims);
79 }
80
/*
 * Buffer size for arm_convolve_wrapper_s16, presumably for sizing a DSP (non-MVE) target.
 *
 * This simply forwards to arm_convolve_wrapper_s16_get_buffer_size, so the result follows
 * this build's compile-time configuration.
 *
 * NOTE(review): in an ARM_MATH_MVEI build this returns the MVE size. That size is at least
 * twice the non-MVE size, so the result over-allocates rather than under-allocates.
 */
int32_t arm_convolve_wrapper_s16_get_buffer_size_dsp(const cmsis_nn_conv_params *conv_params,
                                                     const cmsis_nn_dims *input_dims,
                                                     const cmsis_nn_dims *filter_dims,
                                                     const cmsis_nn_dims *output_dims)
{
    const int32_t required_bytes =
        arm_convolve_wrapper_s16_get_buffer_size(conv_params, input_dims, filter_dims, output_dims);
    return required_bytes;
}
88
/*
 * Buffer size for arm_convolve_wrapper_s16 using the MVE sizing rule.
 *
 * The MVE rule is applied whether or not ARM_MATH_MVEI is defined for this build.
 * conv_params and output_dims are unused.
 */
int32_t arm_convolve_wrapper_s16_get_buffer_size_mve(const cmsis_nn_conv_params *conv_params,
                                                     const cmsis_nn_dims *input_dims,
                                                     const cmsis_nn_dims *filter_dims,
                                                     const cmsis_nn_dims *output_dims)
{
    const int32_t required_bytes = arm_convolve_s16_get_buffer_size_mve(input_dims, filter_dims);

    (void)output_dims;
    (void)conv_params;

    return required_bytes;
}
99
100 /**
101 * @} end of GetBufferSizeNNConv group
102 */
103