/*
 * SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_s8_opt.c
 * Description:  Optimized s8 depthwise separable convolution function for
 *               channel multiplier of 1.
 *
 * $Date:        22 March 2023
 * $Revision:    V.3.5.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */

/*
 * Optimized s8 depthwise convolution function with the constraint that
 * in_channel equals out_channel
 *
 * Refer to the function prototype in the header file for details.
 *
 */
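
/*
 * Usage sketch (illustrative only, not part of the library). The variable
 * names below are placeholders; the caller is assumed to have filled in the
 * dims and parameter structs. The key requirement shown here is that the
 * scratch buffer must be sized with arm_depthwise_conv_s8_opt_get_buffer_size()
 * before the call:
 *
 *     cmsis_nn_context ctx;
 *     ctx.size = arm_depthwise_conv_s8_opt_get_buffer_size(&input_dims, &filter_dims);
 *     ctx.buf = scratch_arena; // caller-provided memory of at least ctx.size bytes
 *
 *     arm_cmsis_nn_status status = arm_depthwise_conv_s8_opt(&ctx,
 *                                                            &dw_conv_params,
 *                                                            &quant_params,
 *                                                            &input_dims,
 *                                                            input_data,
 *                                                            &filter_dims,
 *                                                            kernel_data,
 *                                                            &bias_dims,
 *                                                            bias_data,
 *                                                            &output_dims,
 *                                                            output_data);
 */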

arm_cmsis_nn_status arm_depthwise_conv_s8_opt(const cmsis_nn_context *ctx,
                                              const cmsis_nn_dw_conv_params *dw_conv_params,
                                              const cmsis_nn_per_channel_quant_params *quant_params,
                                              const cmsis_nn_dims *input_dims,
                                              const int8_t *input,
                                              const cmsis_nn_dims *filter_dims,
                                              const int8_t *kernel,
                                              const cmsis_nn_dims *bias_dims,
                                              const int32_t *bias,
                                              const cmsis_nn_dims *output_dims,
                                              int8_t *output)
{
    const int32_t input_ch = input_dims->c;
    const int32_t output_ch = output_dims->c;

    /* Check depth multiplier is 1 */
    if (input_ch != output_ch)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (ctx->buf == NULL && arm_depthwise_conv_s8_opt_get_buffer_size(input_dims, filter_dims) > 0)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }
#ifdef ARM_MATH_DSP
    (void)bias_dims;
    const int32_t input_x = input_dims->w;
    const int32_t input_y = input_dims->h;
    const int32_t kernel_x = filter_dims->w;
    const int32_t kernel_y = filter_dims->h;
    const int32_t pad_x = dw_conv_params->padding.w;
    const int32_t pad_y = dw_conv_params->padding.h;
    const int32_t stride_x = dw_conv_params->stride.w;
    const int32_t stride_y = dw_conv_params->stride.h;
    const int32_t *output_shift = quant_params->shift;
    const int32_t *output_mult = quant_params->multiplier;
    const int32_t output_x = output_dims->w;
    const int32_t output_y = output_dims->h;
    const int32_t output_offset = dw_conv_params->output_offset;
    const int32_t input_offset = dw_conv_params->input_offset;
    const int32_t output_activation_min = dw_conv_params->activation.min;
    const int32_t output_activation_max = dw_conv_params->activation.max;
    int16_t *buffer_a = (int16_t *)ctx->buf;

#ifdef ARM_MATH_MVEI
    /* Generate four columns at a time from the input tensor */
    int8_t *lhs_buffer = (int8_t *)buffer_a;
    int8_t *out = output;
    int buffer_count = 0;
    const int32_t kernel_size = kernel_x * kernel_y;

    const int32_t ch_loop = (input_ch + (CH_IN_BLOCK_MVE - 1)) / CH_IN_BLOCK_MVE;
    int32_t remaining_ch = output_ch;
    int32_t active_ch = MIN(CH_IN_BLOCK_MVE, remaining_ch);
    remaining_ch -= CH_IN_BLOCK_MVE;

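    /* Input channels are processed in blocks of CH_IN_BLOCK_MVE. For each
     * block, im2col columns are staged in the scratch buffer and dispatched
     * to the vectorized kernel in batches of four output elements. */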
    for (int i_ch = 0; i_ch < ch_loop; i_ch++)
    {
        out = output + i_ch * CH_IN_BLOCK_MVE;
        const int8_t *input_slice = input + (i_ch * CH_IN_BLOCK_MVE);

        for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
        {
            for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
            {
                for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
                {
                    for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
                    {
                        if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
                        {
                            arm_memset_s8(lhs_buffer, (int8_t)-input_offset, (uint32_t)active_ch);
                        }
                        else
                        {
                            arm_memcpy_s8(lhs_buffer,
                                          input_slice + (i_ker_y * input_x + i_ker_x) * input_ch,
                                          (uint32_t)active_ch);
                        }
                        lhs_buffer += CH_IN_BLOCK_MVE;
                    }
                }
                buffer_count++;

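                /* Once four im2col columns have been staged, hand them to the
                 * vectorized kernel, which produces four output elements per
                 * call for the active channel block. */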
                if (buffer_count == 4)
                {
                    const int32_t block_offset = i_ch * CH_IN_BLOCK_MVE;
                    lhs_buffer = (int8_t *)buffer_a;

                    arm_nn_depthwise_conv_nt_t_s8(lhs_buffer,
                                                  kernel + block_offset,
                                                  input_offset,
                                                  active_ch,
                                                  input_ch,
                                                  output_shift + block_offset,
                                                  output_mult + block_offset,
                                                  output_offset,
                                                  output_activation_min,
                                                  output_activation_max,
                                                  kernel_size,
                                                  bias + block_offset,
                                                  out);

                    out += (4 * input_ch);
                    buffer_count = 0;
                }
            }
        }
        /* Handle left over buffers */
        lhs_buffer = (int8_t *)buffer_a;

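        /* Fewer than four columns remain for this channel block: compute them
         * directly with predicated MVE loads/stores, four channels at a time. */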
        int8_t *out_base = out;
        for (int i_buf = 0; i_buf < buffer_count; i_buf++)
        {
            int32_t loop_count = (active_ch + 3) / 4;
            int32_t num_ch_to_process = active_ch;
            out = out_base + (i_buf * input_ch);
            for (int i_loop_cnt = 0, offset = i_ch * CH_IN_BLOCK_MVE; i_loop_cnt < loop_count;
                 num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
            {
                const int8_t *col_0 = lhs_buffer + (kernel_size * CH_IN_BLOCK_MVE * i_buf) + (i_loop_cnt * 4);
                const int8_t *row_0 = kernel + offset;
                int32x4_t out_0 = vdupq_n_s32(0);
                if (bias)
                {
                    out_0 = vldrwq_s32(&bias[offset]);
                }

                for (int i_ker = 0; i_ker < kernel_size; i_ker++)
                {
                    const int32x4_t ker_0 = vldrbq_s32(row_0);
                    int32x4_t ip_0 = vldrbq_s32(col_0);
                    ip_0 = vaddq_n_s32(ip_0, input_offset);
                    out_0 += vmulq_s32(ip_0, ker_0);

                    col_0 += CH_IN_BLOCK_MVE;
                    row_0 += input_ch;
                }

                const int32x4_t mult = vldrwq_s32(&output_mult[offset]);
                const int32x4_t shift = vldrwq_s32(&output_shift[offset]);

                out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
                out_0 = vaddq_n_s32(out_0, output_offset);
                out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
                out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));
                mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
                vstrbq_p_s32(out, out_0, p);

                out += 4;
            }
        }
        buffer_count = 0;

        active_ch = MIN(CH_IN_BLOCK_MVE, remaining_ch);
        remaining_ch -= CH_IN_BLOCK_MVE;
    }

#else // ARM_MATH_MVEI
    /* Run the following code on cores with the DSP extension */
    int16_t *const col_buffer_start = buffer_a;
    int16_t *col_buffer = col_buffer_start;
    const int32_t *const bias_start_pos = bias;
    const int32_t *const out_mult_start_pos = output_mult;
    const int32_t *const out_shift_start_pos = output_shift;
    uint16_t row_count;
    uint16_t row_shift;

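    /* The im2col buffer holds one kernel patch per output position, widened
     * to int16 with the input offset already folded in by
     * arm_q7_to_q15_with_offset(). Padded positions are therefore written as
     * zero, since (-input_offset + input_offset) == 0. */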
    for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
    {
        const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
        for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
        {
            const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;

            /* Out of bounds is only pre-handled for the y axis, as it offers a
               contiguous zeroing opportunity that the x axis does not */
            const int ker_y_start = MAX(0, -base_idx_y);
            /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
            const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);

            int32_t index = 0;
            if (ker_y_start != 0)
            {
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
                index += (kernel_x * input_ch) * ker_y_start;
            }

            for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
            {
                const int32_t idx_y = base_idx_y + i_ker_y;

                for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
                {
                    const int32_t idx_x = base_idx_x + i_ker_x;
                    if (idx_x < 0 || idx_x >= input_x)
                    {
                        memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
                    }
                    else
                    {
                        arm_q7_to_q15_with_offset((int8_t *)input + (idx_y * input_x + idx_x) * input_ch,
                                                  &col_buffer[index],
                                                  input_ch,
                                                  (int16_t)input_offset);
                    }
                    index += input_ch;
                }
            }

            const int diff = kernel_y - ker_y_end;
            if (diff != 0)
            {
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
            }

            row_count = output_ch / 4;
            row_shift = 0;
            bias = bias_start_pos;
            output_mult = out_mult_start_pos;
            output_shift = out_shift_start_pos;

            while (row_count)
            {
                int32_t sum = 0;
                int32_t sum_2 = 0;
                int32_t sum_3 = 0;
                int32_t sum_4 = 0;
                if (bias)
                {
                    sum = *bias++;
                    sum_2 = *bias++;
                    sum_3 = *bias++;
                    sum_4 = *bias++;
                }

                uint16_t col_count = (kernel_x * kernel_y) / 2;
                int16_t *col_pos = col_buffer_start + row_shift;
                const int8_t *row_pos = kernel + row_shift;
                row_shift += 4;

                while (col_count)
                {
                    /* The general idea is to read a 4 + 4 (kernel, input) pair and re-arrange the values
                       into the order required by the SMLAD instruction. One iteration of this loop produces
                       4 partial outputs with 8 MACs. */
                    /* Note: variable names could be improved here to align with rows and columns. */
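                    /* Packing detail: SXTB16 sign-extends bytes 0 and 2 of a
                     * word into two int16 lanes, and SXTB16(ROR(x, 8)) does the
                     * same for bytes 1 and 3. PKHBT/PKHTB then pair matching
                     * weight and activation halfwords so that each SMLAD
                     * performs two multiply-accumulates for adjacent channels. */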
                    int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
                    /* Read 4 weights */
                    ip_b1 = arm_nn_read_s8x4(row_pos);
                    ip_a1 = arm_nn_read_s8x4(row_pos + input_ch);
                    op_a = arm_nn_read_s16x2(col_pos);
                    op_b = arm_nn_read_s16x2(col_pos + input_ch);

                    ip_a2 = SXTB16(ip_b1);
                    ip_b1 = SXTB16(ROR(ip_b1, 8));

                    ip_b2 = SXTB16(ip_a1);
                    ip_a1 = SXTB16(ROR(ip_a1, 8));

                    op_c = PKHBT(op_b, op_a, 16);
                    op_a = PKHTB(op_b, op_a, 16);
                    op_b = PKHBT(ip_b2, ip_a2, 16);
                    sum = SMLAD(op_c, op_b, sum);

                    op_b = PKHBT(ip_b1, ip_a1, 16);
                    sum_2 = SMLAD(op_a, op_b, sum_2);

                    op_a = arm_nn_read_s16x2(col_pos + 2);
                    op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);

                    op_c = PKHBT(op_b, op_a, 16);
                    op_a = PKHTB(op_b, op_a, 16);
                    op_b = PKHTB(ip_a2, ip_b2, 16);
                    sum_3 = SMLAD(op_c, op_b, sum_3);

                    op_b = PKHTB(ip_a1, ip_b1, 16);
                    sum_4 = SMLAD(op_a, op_b, sum_4);

                    row_pos += input_ch << 1;
                    col_pos += input_ch << 1;
                    col_count--;
                }

                col_count = (kernel_x * kernel_y) & 0x1;
                while (col_count)
                {
                    sum += row_pos[0] * col_pos[0];
                    sum_2 += row_pos[1] * col_pos[1];
                    sum_3 += row_pos[2] * col_pos[2];
                    sum_4 += row_pos[3] * col_pos[3];

                    row_pos += input_ch;
                    col_pos += input_ch;

                    col_count--;
                }
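
                /* arm_nn_requantize() scales each accumulator back to the int8
                 * domain: roughly out = round(sum * multiplier * 2^(shift - 31)),
                 * where multiplier is a per-channel Q31 fixed-point scale. The
                 * output offset and activation clamp are then applied before
                 * the narrowing store. */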
                sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                sum += output_offset;
                sum = MAX(sum, output_activation_min);
                sum = MIN(sum, output_activation_max);
                *output++ = (int8_t)sum;

                sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
                sum_2 += output_offset;
                sum_2 = MAX(sum_2, output_activation_min);
                sum_2 = MIN(sum_2, output_activation_max);
                *output++ = (int8_t)sum_2;

                sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
                sum_3 += output_offset;
                sum_3 = MAX(sum_3, output_activation_min);
                sum_3 = MIN(sum_3, output_activation_max);
                *output++ = (int8_t)sum_3;

                sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
                sum_4 += output_offset;
                sum_4 = MAX(sum_4, output_activation_min);
                sum_4 = MIN(sum_4, output_activation_max);
                *output++ = (int8_t)sum_4;

                row_count--;
            }

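            /* Handle the 1-3 left-over channels that did not fit the 4-wide
             * SMLAD path with a plain scalar loop. */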
            row_count = output_ch & 0x3;
            while (row_count)
            {
                int16_t *col_pos = col_buffer_start + row_shift;
                const int8_t *row_pos = kernel + row_shift;
                int32_t sum = 0;
                if (bias)
                {
                    sum = *bias++;
                }
                const uint16_t col_count = (kernel_x * kernel_y);
                row_shift += 1;

                for (int i = 0; i < col_count; i++)
                {
                    sum += row_pos[i * input_ch] * col_pos[i * input_ch];
                }
                sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                sum += output_offset;
                sum = MAX(sum, output_activation_min);
                sum = MIN(sum, output_activation_max);
                *output++ = (int8_t)sum;

                row_count--;
            }

            /* Reset the column buffer pointer for the next output position */
            col_buffer = col_buffer_start;
        }
    }
#endif /* ARM_MATH_MVEI */
#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    return arm_depthwise_conv_s8(ctx,
                                 dw_conv_params,
                                 quant_params,
                                 input_dims,
                                 input,
                                 filter_dims,
                                 kernel,
                                 bias_dims,
                                 bias,
                                 output_dims,
                                 output);
#endif /* ARM_MATH_MVEI | ARM_MATH_DSP */

    /* Return to application */
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of NNConv group
 */