1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_transpose_conv_s8.c
22  * Description:  s8 version of transposed convolution using symmetric quantization.
23  *
24  * $Date:        29 October 2024
25  * $Revision:    V.2.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 /**
34  *  @ingroup Public
35  */
36 
37 /**
38  * @addtogroup NNConv
39  * @{
40  */
41 
42 /*
43  * Basic s8 transpose convolution function.
44  *
45  * Refer header file for details.
46  *
47  */
arm_transpose_conv_s8(const cmsis_nn_context * ctx,const cmsis_nn_context * output_ctx,const cmsis_nn_transpose_conv_params * transpose_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * filter_dims,const int8_t * filter_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)48 arm_cmsis_nn_status arm_transpose_conv_s8(const cmsis_nn_context *ctx,
49                                           const cmsis_nn_context *output_ctx,
50                                           const cmsis_nn_transpose_conv_params *transpose_conv_params,
51                                           const cmsis_nn_per_channel_quant_params *quant_params,
52                                           const cmsis_nn_dims *input_dims,
53                                           const int8_t *input_data,
54                                           const cmsis_nn_dims *filter_dims,
55                                           const int8_t *filter_data,
56                                           const cmsis_nn_dims *bias_dims,
57                                           const int32_t *bias_data,
58                                           const cmsis_nn_dims *output_dims,
59                                           int8_t *output_data)
60 {
61     (void)bias_dims;
62     (void)output_ctx;
63 
64     const int32_t activation_min = transpose_conv_params->activation.min;
65     const int32_t activation_max = transpose_conv_params->activation.max;
66 
67     const int32_t input_ch = input_dims->c;
68     const int32_t input_x = input_dims->w;
69     const int32_t input_y = input_dims->h;
70 
71     const int32_t output_x = output_dims->w;
72     const int32_t output_y = output_dims->h;
73 
74     const int32_t output_ch = output_dims->c;
75 
76     const int32_t filter_x = filter_dims->w;
77     const int32_t filter_y = filter_dims->h;
78 
79     const int32_t pad_x = transpose_conv_params->padding.w;
80     const int32_t pad_y = transpose_conv_params->padding.h;
81 
82     const int32_t stride_x = transpose_conv_params->stride.w;
83     const int32_t stride_y = transpose_conv_params->stride.h;
84 
85     const int32_t *output_multiplier = quant_params->multiplier;
86     const int32_t *output_shift = quant_params->shift;
87 
88     const int32_t out_offset = transpose_conv_params->output_offset;
89     const int32_t input_offset = transpose_conv_params->input_offset;
90 
91     const int32_t buf_x_elements = ((input_x - 1) * stride_x + MAX(filter_x, stride_x));
92     const int32_t buf_x = buf_x_elements * output_ch;
93     const int32_t buf_y = MAX(filter_y, stride_y);
94     const int32_t buf_size = buf_y * buf_x;
95     int32_t *buf = ctx->buf;
96     int32_t batch_cnt = input_dims->n;
97 
98     const int8_t *filter = filter_data;
99     const int8_t *input = input_data;
100     int8_t *output = output_data;
101 
102     while (batch_cnt)
103     {
104         // Reset buf
105         if (bias_data)
106         {
107             for (int x = 0; x < buf_x_elements * buf_y; x++)
108             {
109                 arm_memcpy_s8((int8_t *)(buf + x * output_ch), (const int8_t *)bias_data, output_ch * sizeof(int32_t));
110             }
111         }
112         else
113         {
114             arm_memset_s8((int8_t *)buf, 0, buf_size * sizeof(int32_t));
115         }
116 
117         int32_t buf_row = 0;
118         for (int j = 0; j < input_y; j++)
119         {
120             int skip_rows_top = MAX(0, pad_y - j * stride_y);
121             int skip_rows_bottom = MAX(0, (j * stride_y + filter_y) - (pad_y + output_y) - 1);
122 
123             // Compute output for one row of input
124             arm_nn_transpose_conv_row_s8_s32(input,
125                                              filter,
126                                              buf,
127                                              buf_row,
128                                              buf_size,
129                                              filter_y,
130                                              filter_x,
131                                              input_ch,
132                                              output_ch,
133                                              input_offset,
134                                              buf_x,
135                                              input_x,
136                                              stride_x,
137                                              skip_rows_top,
138                                              skip_rows_bottom);
139             input += input_ch * input_x;
140 
141             if (skip_rows_top == 0)
142             {
143                 for (int y = 0; y < stride_y; y++)
144                 {
145                     int32_t *buf_out = buf + buf_row;
146                     buf_out += output_ch * pad_x;
147 
148 #if defined(ARM_MATH_MVEI)
149                     for (int x = 0; x < output_x; x++)
150                     {
151                         const int32_t *mult_ptr = output_multiplier;
152                         const int32_t *shift_ptr = output_shift;
153 
154                         int channel_count = output_ch;
155                         for (; channel_count > 0; channel_count -= 4)
156                         {
157                             mve_pred16_t p = vctp32q((uint32_t)channel_count);
158 
159                             int32x4_t result = vldrwq_z_s32(buf_out, p);
160                             buf_out += 4;
161                             result =
162                                 arm_requantize_mve_32x4(result, vldrwq_z_s32(mult_ptr, p), vldrwq_z_s32(shift_ptr, p));
163                             mult_ptr += 4;
164                             shift_ptr += 4;
165                             result = vaddq_n_s32(result, out_offset);
166                             result = vmaxq_s32(result, vdupq_n_s32(activation_min));
167                             result = vminq_s32(result, vdupq_n_s32(activation_max));
168                             vstrbq_p_s32(output, result, p);
169                             output += 4;
170                         }
171 
172                         // Correct pointer overshoot due to predication
173                         buf_out += channel_count;
174                         output += channel_count;
175                     }
176 #else
177 
178                     for (int x = 0; x < output_x; x++)
179                     {
180                         const int32_t *output_multiplier_ptr = output_multiplier;
181                         const int32_t *output_shift_ptr = output_shift;
182                         for (int z = 0; z < output_ch; z++)
183                         {
184                             int32_t result = *buf_out++;
185                             result = arm_nn_requantize(result, *output_multiplier_ptr++, *output_shift_ptr++);
186                             result += out_offset;
187                             result = MAX(result, activation_min);
188                             result = MIN(result, activation_max);
189                             *output++ = result;
190                         }
191                     }
192 #endif
193 
194                     // Reset the buffer which was just written
195                     if (bias_data)
196                     {
197                         for (int x = 0; x < buf_x_elements; x++)
198                         {
199                             arm_memcpy_s8((int8_t *)(buf + buf_row + x * output_ch),
200                                           (const int8_t *)bias_data,
201                                           output_ch * sizeof(int32_t));
202                         }
203                     }
204                     else
205                     {
206                         arm_memset_s8((int8_t *)(buf + buf_row), 0, buf_x * sizeof(int32_t));
207                     }
208 
209                     // Next row in the rolling buffer
210                     buf_row = (buf_row + buf_x) % buf_size;
211                 }
212             }
213         }
214 
215         // Write leftover rows
216         for (int y = 0; y < filter_y - stride_y; y++)
217         {
218             int32_t *buf_out = buf + buf_row;
219             if ((input_y * stride_y + y >= pad_y) && (input_y * stride_y + y < pad_y + output_y))
220             {
221                 buf_out += output_ch * pad_x;
222 #if defined(ARM_MATH_MVEI)
223                 for (int x = 0; x < output_x; x++)
224                 {
225                     const int32_t *mult_ptr = output_multiplier;
226                     const int32_t *shift_ptr = output_shift;
227 
228                     int channel_count = output_ch;
229                     for (; channel_count > 0; channel_count -= 4)
230                     {
231                         mve_pred16_t p = vctp32q((uint32_t)channel_count);
232 
233                         int32x4_t result = vldrwq_z_s32(buf_out, p);
234                         buf_out += 4;
235                         result = arm_requantize_mve_32x4(result, vldrwq_z_s32(mult_ptr, p), vldrwq_z_s32(shift_ptr, p));
236                         mult_ptr += 4;
237                         shift_ptr += 4;
238                         result = vaddq_n_s32(result, out_offset);
239                         result = vmaxq_s32(result, vdupq_n_s32(activation_min));
240                         result = vminq_s32(result, vdupq_n_s32(activation_max));
241                         vstrbq_p_s32(output, result, p);
242                         output += 4;
243                     }
244 
245                     // Correct pointer overshoot due to predication
246                     buf_out += channel_count;
247                     output += channel_count;
248                 }
249 #else
250                 for (int x = 0; x < output_x; x++)
251                 {
252                     const int32_t *output_multiplier_ptr = output_multiplier;
253                     const int32_t *output_shift_ptr = output_shift;
254 
255                     for (int z = 0; z < output_ch; z++)
256                     {
257                         int32_t result = *buf_out++;
258 
259                         result = arm_nn_requantize(result, *output_multiplier_ptr++, *output_shift_ptr++);
260                         result += out_offset;
261                         result = MAX(result, activation_min);
262                         result = MIN(result, activation_max);
263                         *output++ = result;
264                     }
265                 }
266 #endif
267             }
268             buf_row = (buf_row + buf_x) % buf_size;
269         }
270 
271         batch_cnt--;
272     }
273 
274     /* Return to application */
275     return ARM_CMSIS_NN_SUCCESS;
276 }
277 
278 /**
279  * @} end of NNConv group
280  */
281