1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_transpose_conv_s8.c
22 * Description: s8 version of transposed convolution using symmetric quantization.
23 *
24 * $Date: 29 October 2024
25 * $Revision: V.2.0.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34 * @ingroup Public
35 */
36
37 /**
38 * @addtogroup NNConv
39 * @{
40 */
41
42 /*
43 * Basic s8 transpose convolution function.
44 *
45 * Refer header file for details.
46 *
47 */
arm_transpose_conv_s8(const cmsis_nn_context * ctx,const cmsis_nn_context * output_ctx,const cmsis_nn_transpose_conv_params * transpose_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * filter_dims,const int8_t * filter_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)48 arm_cmsis_nn_status arm_transpose_conv_s8(const cmsis_nn_context *ctx,
49 const cmsis_nn_context *output_ctx,
50 const cmsis_nn_transpose_conv_params *transpose_conv_params,
51 const cmsis_nn_per_channel_quant_params *quant_params,
52 const cmsis_nn_dims *input_dims,
53 const int8_t *input_data,
54 const cmsis_nn_dims *filter_dims,
55 const int8_t *filter_data,
56 const cmsis_nn_dims *bias_dims,
57 const int32_t *bias_data,
58 const cmsis_nn_dims *output_dims,
59 int8_t *output_data)
60 {
61 (void)bias_dims;
62 (void)output_ctx;
63
64 const int32_t activation_min = transpose_conv_params->activation.min;
65 const int32_t activation_max = transpose_conv_params->activation.max;
66
67 const int32_t input_ch = input_dims->c;
68 const int32_t input_x = input_dims->w;
69 const int32_t input_y = input_dims->h;
70
71 const int32_t output_x = output_dims->w;
72 const int32_t output_y = output_dims->h;
73
74 const int32_t output_ch = output_dims->c;
75
76 const int32_t filter_x = filter_dims->w;
77 const int32_t filter_y = filter_dims->h;
78
79 const int32_t pad_x = transpose_conv_params->padding.w;
80 const int32_t pad_y = transpose_conv_params->padding.h;
81
82 const int32_t stride_x = transpose_conv_params->stride.w;
83 const int32_t stride_y = transpose_conv_params->stride.h;
84
85 const int32_t *output_multiplier = quant_params->multiplier;
86 const int32_t *output_shift = quant_params->shift;
87
88 const int32_t out_offset = transpose_conv_params->output_offset;
89 const int32_t input_offset = transpose_conv_params->input_offset;
90
91 const int32_t buf_x_elements = ((input_x - 1) * stride_x + MAX(filter_x, stride_x));
92 const int32_t buf_x = buf_x_elements * output_ch;
93 const int32_t buf_y = MAX(filter_y, stride_y);
94 const int32_t buf_size = buf_y * buf_x;
95 int32_t *buf = ctx->buf;
96 int32_t batch_cnt = input_dims->n;
97
98 const int8_t *filter = filter_data;
99 const int8_t *input = input_data;
100 int8_t *output = output_data;
101
102 while (batch_cnt)
103 {
104 // Reset buf
105 if (bias_data)
106 {
107 for (int x = 0; x < buf_x_elements * buf_y; x++)
108 {
109 arm_memcpy_s8((int8_t *)(buf + x * output_ch), (const int8_t *)bias_data, output_ch * sizeof(int32_t));
110 }
111 }
112 else
113 {
114 arm_memset_s8((int8_t *)buf, 0, buf_size * sizeof(int32_t));
115 }
116
117 int32_t buf_row = 0;
118 for (int j = 0; j < input_y; j++)
119 {
120 int skip_rows_top = MAX(0, pad_y - j * stride_y);
121 int skip_rows_bottom = MAX(0, (j * stride_y + filter_y) - (pad_y + output_y) - 1);
122
123 // Compute output for one row of input
124 arm_nn_transpose_conv_row_s8_s32(input,
125 filter,
126 buf,
127 buf_row,
128 buf_size,
129 filter_y,
130 filter_x,
131 input_ch,
132 output_ch,
133 input_offset,
134 buf_x,
135 input_x,
136 stride_x,
137 skip_rows_top,
138 skip_rows_bottom);
139 input += input_ch * input_x;
140
141 if (skip_rows_top == 0)
142 {
143 for (int y = 0; y < stride_y; y++)
144 {
145 int32_t *buf_out = buf + buf_row;
146 buf_out += output_ch * pad_x;
147
148 #if defined(ARM_MATH_MVEI)
149 for (int x = 0; x < output_x; x++)
150 {
151 const int32_t *mult_ptr = output_multiplier;
152 const int32_t *shift_ptr = output_shift;
153
154 int channel_count = output_ch;
155 for (; channel_count > 0; channel_count -= 4)
156 {
157 mve_pred16_t p = vctp32q((uint32_t)channel_count);
158
159 int32x4_t result = vldrwq_z_s32(buf_out, p);
160 buf_out += 4;
161 result =
162 arm_requantize_mve_32x4(result, vldrwq_z_s32(mult_ptr, p), vldrwq_z_s32(shift_ptr, p));
163 mult_ptr += 4;
164 shift_ptr += 4;
165 result = vaddq_n_s32(result, out_offset);
166 result = vmaxq_s32(result, vdupq_n_s32(activation_min));
167 result = vminq_s32(result, vdupq_n_s32(activation_max));
168 vstrbq_p_s32(output, result, p);
169 output += 4;
170 }
171
172 // Correct pointer overshoot due to predication
173 buf_out += channel_count;
174 output += channel_count;
175 }
176 #else
177
178 for (int x = 0; x < output_x; x++)
179 {
180 const int32_t *output_multiplier_ptr = output_multiplier;
181 const int32_t *output_shift_ptr = output_shift;
182 for (int z = 0; z < output_ch; z++)
183 {
184 int32_t result = *buf_out++;
185 result = arm_nn_requantize(result, *output_multiplier_ptr++, *output_shift_ptr++);
186 result += out_offset;
187 result = MAX(result, activation_min);
188 result = MIN(result, activation_max);
189 *output++ = result;
190 }
191 }
192 #endif
193
194 // Reset the buffer which was just written
195 if (bias_data)
196 {
197 for (int x = 0; x < buf_x_elements; x++)
198 {
199 arm_memcpy_s8((int8_t *)(buf + buf_row + x * output_ch),
200 (const int8_t *)bias_data,
201 output_ch * sizeof(int32_t));
202 }
203 }
204 else
205 {
206 arm_memset_s8((int8_t *)(buf + buf_row), 0, buf_x * sizeof(int32_t));
207 }
208
209 // Next row in the rolling buffer
210 buf_row = (buf_row + buf_x) % buf_size;
211 }
212 }
213 }
214
215 // Write leftover rows
216 for (int y = 0; y < filter_y - stride_y; y++)
217 {
218 int32_t *buf_out = buf + buf_row;
219 if ((input_y * stride_y + y >= pad_y) && (input_y * stride_y + y < pad_y + output_y))
220 {
221 buf_out += output_ch * pad_x;
222 #if defined(ARM_MATH_MVEI)
223 for (int x = 0; x < output_x; x++)
224 {
225 const int32_t *mult_ptr = output_multiplier;
226 const int32_t *shift_ptr = output_shift;
227
228 int channel_count = output_ch;
229 for (; channel_count > 0; channel_count -= 4)
230 {
231 mve_pred16_t p = vctp32q((uint32_t)channel_count);
232
233 int32x4_t result = vldrwq_z_s32(buf_out, p);
234 buf_out += 4;
235 result = arm_requantize_mve_32x4(result, vldrwq_z_s32(mult_ptr, p), vldrwq_z_s32(shift_ptr, p));
236 mult_ptr += 4;
237 shift_ptr += 4;
238 result = vaddq_n_s32(result, out_offset);
239 result = vmaxq_s32(result, vdupq_n_s32(activation_min));
240 result = vminq_s32(result, vdupq_n_s32(activation_max));
241 vstrbq_p_s32(output, result, p);
242 output += 4;
243 }
244
245 // Correct pointer overshoot due to predication
246 buf_out += channel_count;
247 output += channel_count;
248 }
249 #else
250 for (int x = 0; x < output_x; x++)
251 {
252 const int32_t *output_multiplier_ptr = output_multiplier;
253 const int32_t *output_shift_ptr = output_shift;
254
255 for (int z = 0; z < output_ch; z++)
256 {
257 int32_t result = *buf_out++;
258
259 result = arm_nn_requantize(result, *output_multiplier_ptr++, *output_shift_ptr++);
260 result += out_offset;
261 result = MAX(result, activation_min);
262 result = MIN(result, activation_max);
263 *output++ = result;
264 }
265 }
266 #endif
267 }
268 buf_row = (buf_row + buf_x) % buf_size;
269 }
270
271 batch_cnt--;
272 }
273
274 /* Return to application */
275 return ARM_CMSIS_NN_SUCCESS;
276 }
277
278 /**
279 * @} end of NNConv group
280 */
281