1 /*
2 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_depthwise_conv_wrapper_s8.c
22 * Description: Wrapper API to select appropriate depthwise conv API based
23 * on dimensions.
24 *
25 * $Date: 04 November 2024
26 * $Revision: V.2.2.0
27 *
28 * Target : Arm(R) M-Profile Architecture
29 *
30 * -------------------------------------------------------------------- */
31
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34
35 /**
36 * @ingroup Public
37 */
38
39 /**
40 * @addtogroup NNConv
41 * @{
42 */
43
44 #if defined(ARM_MATH_MVEI)
arm_depthwise_conv_to_conv_s8(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * filter,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)45 static arm_cmsis_nn_status arm_depthwise_conv_to_conv_s8(const cmsis_nn_context *ctx,
46 const cmsis_nn_dw_conv_params *dw_conv_params,
47 const cmsis_nn_per_channel_quant_params *quant_params,
48 const cmsis_nn_dims *input_dims,
49 const int8_t *input,
50 const cmsis_nn_dims *filter_dims,
51 const int8_t *filter,
52 const cmsis_nn_dims *bias_dims,
53 const int32_t *bias,
54 const cmsis_nn_dims *output_dims,
55 int8_t *output)
56 {
57 const cmsis_nn_conv_params conv_params = {dw_conv_params->input_offset,
58 dw_conv_params->output_offset,
59 dw_conv_params->stride,
60 dw_conv_params->padding,
61 dw_conv_params->dilation,
62 dw_conv_params->activation};
63 const cmsis_nn_dims filter_output_dims = {filter_dims->c, filter_dims->h, filter_dims->w, filter_dims->n};
64 int8_t *w_buf =
65 ctx->buf + arm_convolve_wrapper_s8_get_buffer_size(&conv_params, input_dims, &filter_output_dims, output_dims);
66 const uint32_t perm[4] = {3, 1, 2, 0};
67 const cmsis_nn_transpose_params transpose_params = {4, perm};
68
69 arm_cmsis_nn_status status = arm_transpose_s8(filter, w_buf, filter_dims, &filter_output_dims, &transpose_params);
70
71 if (status == ARM_CMSIS_NN_SUCCESS)
72 {
73 status = arm_convolve_wrapper_s8(ctx,
74 &conv_params,
75 quant_params,
76 input_dims,
77 input,
78 &filter_output_dims,
79 (const int8_t *)w_buf,
80 bias_dims,
81 bias,
82 output_dims,
83 output);
84 }
85 return status;
86 }
87 #endif
88
/*
 * s8 depthwise convolution wrapper function
 *
 * Refer to the header file for details.
 *
 */
arm_depthwise_conv_wrapper_s8(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * filter,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)95 arm_cmsis_nn_status arm_depthwise_conv_wrapper_s8(const cmsis_nn_context *ctx,
96 const cmsis_nn_dw_conv_params *dw_conv_params,
97 const cmsis_nn_per_channel_quant_params *quant_params,
98 const cmsis_nn_dims *input_dims,
99 const int8_t *input,
100 const cmsis_nn_dims *filter_dims,
101 const int8_t *filter,
102 const cmsis_nn_dims *bias_dims,
103 const int32_t *bias,
104 const cmsis_nn_dims *output_dims,
105 int8_t *output)
106 {
107 arm_cmsis_nn_status status = ARM_CMSIS_NN_SUCCESS;
108
109 #if defined(ARM_MATH_MVEI)
110 if (input_dims->c == 1 && output_dims->c > CONVERT_DW_CONV_WITH_ONE_INPUT_CH_AND_OUTPUT_CH_ABOVE_THRESHOLD)
111 {
112 return arm_depthwise_conv_to_conv_s8(ctx,
113 dw_conv_params,
114 quant_params,
115 input_dims,
116 input,
117 filter_dims,
118 filter,
119 bias_dims,
120 bias,
121 output_dims,
122 output);
123 }
124 #endif
125
126 if (1 == dw_conv_params->ch_mult && input_dims->n == 1 && dw_conv_params->dilation.w == 1 &&
127 dw_conv_params->dilation.h == 1)
128 {
129 #if !defined(ARM_MATH_MVEI)
130 if (filter_dims->w == 3 && filter_dims->h == 3 && dw_conv_params->padding.h <= 1 &&
131 dw_conv_params->padding.w <= 1)
132 {
133 status = arm_depthwise_conv_3x3_s8(ctx,
134 dw_conv_params,
135 quant_params,
136 input_dims,
137 input,
138 filter_dims,
139 filter,
140 bias_dims,
141 bias,
142 output_dims,
143 output);
144 }
145 else
146 #endif
147 {
148 status = arm_depthwise_conv_s8_opt(ctx,
149 dw_conv_params,
150 quant_params,
151 input_dims,
152 input,
153 filter_dims,
154 filter,
155 bias_dims,
156 bias,
157 output_dims,
158 output);
159 }
160 }
161 else
162 {
163 status = arm_depthwise_conv_s8(ctx,
164 dw_conv_params,
165 quant_params,
166 input_dims,
167 input,
168 filter_dims,
169 filter,
170 bias_dims,
171 bias,
172 output_dims,
173 output);
174 }
175
176 /* Return to application */
177 return status;
178 }
179
180 /**
181 * @} end of NNConv group
182 */
183