/*
 * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_fast_s16.c
 * Description:  Optimized s16 depthwise separable convolution function for
 *               channel multiplier of 1.
 *
 * $Date:        19 March 2024
 * $Revision:    V.1.4.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */

/*
 * Optimized s16 depthwise convolution function with constraint that in_channel equals out_channel
 *
 * Refer to the prototype header file for details.
 *
 */
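
/*
 * Illustrative call sequence (a minimal sketch, not part of the original source; the
 * 1x4x4x8 shape, stride/padding and activation range below are hypothetical, and mult,
 * shift, bias_dims, input_data, kernel_data, bias_data and output_data are assumed to
 * be defined elsewhere):
 *
 *   cmsis_nn_dims input_dims = {.n = 1, .h = 4, .w = 4, .c = 8};
 *   cmsis_nn_dims filter_dims = {.h = 3, .w = 3};
 *   cmsis_nn_dims output_dims = {.n = 1, .h = 2, .w = 2, .c = 8};
 *   cmsis_nn_dw_conv_params dw_conv_params = {.input_offset = 0,
 *                                             .output_offset = 0,
 *                                             .ch_mult = 1,
 *                                             .stride = {.w = 1, .h = 1},
 *                                             .padding = {.w = 0, .h = 0},
 *                                             .dilation = {.w = 1, .h = 1},
 *                                             .activation = {.min = INT16_MIN, .max = INT16_MAX}};
 *   cmsis_nn_per_channel_quant_params quant_params = {.multiplier = mult, .shift = shift};
 *
 *   cmsis_nn_context ctx;
 *   ctx.size = arm_depthwise_conv_fast_s16_get_buffer_size(&input_dims, &filter_dims);
 *   ctx.buf = malloc(ctx.size);
 *   arm_cmsis_nn_status status = arm_depthwise_conv_fast_s16(&ctx,
 *                                                            &dw_conv_params,
 *                                                            &quant_params,
 *                                                            &input_dims,
 *                                                            input_data,
 *                                                            &filter_dims,
 *                                                            kernel_data,
 *                                                            &bias_dims,
 *                                                            bias_data,
 *                                                            &output_dims,
 *                                                            output_data);
 */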

arm_cmsis_nn_status arm_depthwise_conv_fast_s16(const cmsis_nn_context *ctx,
                                                const cmsis_nn_dw_conv_params *dw_conv_params,
                                                const cmsis_nn_per_channel_quant_params *quant_params,
                                                const cmsis_nn_dims *input_dims,
                                                const int16_t *input,
                                                const cmsis_nn_dims *filter_dims,
                                                const int8_t *kernel,
                                                const cmsis_nn_dims *bias_dims,
                                                const int64_t *bias,
                                                const cmsis_nn_dims *output_dims,
                                                int16_t *output)
{
    const int32_t input_ch = input_dims->c;
    const int32_t output_ch = output_dims->c;

    /* Check constraint: input_ch == output_ch */
    if (input_ch != output_ch)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (filter_dims->w * filter_dims->h >= MAX_COL_COUNT)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (ctx->buf == NULL && arm_depthwise_conv_fast_s16_get_buffer_size(input_dims, filter_dims) > 0)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

#if defined(ARM_MATH_DSP)
    (void)bias_dims;
    const int32_t input_x = input_dims->w;
    const int32_t input_y = input_dims->h;
    const int32_t input_batches = input_dims->n;
    const int32_t kernel_x = filter_dims->w;
    const int32_t kernel_y = filter_dims->h;
    const int32_t pad_x = dw_conv_params->padding.w;
    const int32_t pad_y = dw_conv_params->padding.h;
    const int32_t stride_x = dw_conv_params->stride.w;
    const int32_t stride_y = dw_conv_params->stride.h;
    const int32_t *output_shift = quant_params->shift;
    const int32_t *output_mult = quant_params->multiplier;
    const int32_t output_x = output_dims->w;
    const int32_t output_y = output_dims->h;
    const int32_t output_activation_min = dw_conv_params->activation.min;
    const int32_t output_activation_max = dw_conv_params->activation.max;
    int16_t *buffer_a = (int16_t *)ctx->buf;

#if defined(ARM_MATH_MVEI)
    int16_t *lhs_buffer = buffer_a;
    int16_t *out = output;
    int buffer_count = 0;
    const int32_t kernel_size = kernel_x * kernel_y;

    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        /* This part implements the im2col function */
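        /* Four im2col columns are gathered in buffer_a; every time four are available they are
           consumed by a single call to arm_nn_depthwise_conv_nt_t_s16 below. */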
        for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
        {
            for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
            {
                for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
                {
                    for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
                    {
                        if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
                        {
                            memset(lhs_buffer, (int16_t)0, (uint32_t)(input_ch * sizeof(int16_t)));
                        }
                        else
                        {
                            arm_memcpy_q15(lhs_buffer,
                                           (int16_t *)(input + (i_ker_y * input_x + i_ker_x) * input_ch),
                                           (uint32_t)(input_ch * sizeof(int16_t)));
                        }
                        lhs_buffer += input_ch;
                    }
                }
                buffer_count++;
                if (buffer_count == 4)
                {
                    lhs_buffer = buffer_a;

                    out = arm_nn_depthwise_conv_nt_t_s16(lhs_buffer,
                                                         kernel,
                                                         input_ch,
                                                         output_shift,
                                                         output_mult,
                                                         output_activation_min,
                                                         output_activation_max,
                                                         kernel_size,
                                                         bias,
                                                         out);
                    buffer_count = 0;
                }
            }
        }
        input += input_x * input_y * input_ch;
    }

    /* Handle left over buffers */
    lhs_buffer = buffer_a;
    for (int i_buf = 0; i_buf < buffer_count; i_buf++)
    {
        int32_t loop_count = (input_ch + 3) / 4;
        int32_t num_ch_to_process = input_ch;

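        /* Process the channels four at a time with MVE intrinsics; the predicated store below
           handles a channel count that is not a multiple of four. */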
        for (int i_loop_cnt = 0, offset = 0; i_loop_cnt < loop_count; num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
        {
            const int8_t *row_0 = kernel + offset;
            const int16_t *col_0 = lhs_buffer + (kernel_size * input_ch * i_buf) + offset;

            int32x4_t out_0 = vdupq_n_s32(0);

            for (int i_ker = 0; i_ker < kernel_size; i_ker++)
            {
                const int32x4_t ker_0 = vldrbq_s32(row_0);

                int32x4_t ip_0 = vldrhq_s32(col_0);
                out_0 += vmulq_s32(ip_0, ker_0);

                col_0 += input_ch;
                row_0 += input_ch;
            }

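            /* Widen the 32-bit accumulators to 64 bits before the s64 bias addition and requantization. */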
            int64_t in_requantize_0 = (int64_t)out_0[0];
            int64_t in_requantize_1 = (int64_t)out_0[1];
            int64_t in_requantize_2 = (int64_t)out_0[2];
            int64_t in_requantize_3 = (int64_t)out_0[3];

            if (bias)
            {
                in_requantize_0 += bias[offset];
                in_requantize_1 += bias[offset + 1];
                in_requantize_2 += bias[offset + 2];
                in_requantize_3 += bias[offset + 3];
            }

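            /* Reduce each 32-bit per-channel multiplier for use with the 64-bit requantization
               helper arm_nn_requantize_s64. */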
            int32_t reduced_multiplier_0 = REDUCE_MULTIPLIER(output_mult[offset]);
            int32_t reduced_multiplier_1 = REDUCE_MULTIPLIER(output_mult[offset + 1]);
            int32_t reduced_multiplier_2 = REDUCE_MULTIPLIER(output_mult[offset + 2]);
            int32_t reduced_multiplier_3 = REDUCE_MULTIPLIER(output_mult[offset + 3]);

            out_0[0] = arm_nn_requantize_s64(in_requantize_0, reduced_multiplier_0, output_shift[offset]);
            out_0[1] = arm_nn_requantize_s64(in_requantize_1, reduced_multiplier_1, output_shift[offset + 1]);
            out_0[2] = arm_nn_requantize_s64(in_requantize_2, reduced_multiplier_2, output_shift[offset + 2]);
            out_0[3] = arm_nn_requantize_s64(in_requantize_3, reduced_multiplier_3, output_shift[offset + 3]);

            out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
            out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));

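            /* Predicate the store so that no more than the remaining num_ch_to_process channels are written. */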
            mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
            vstrhq_p_s32(out, out_0, p);

            out += 4;
        }

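        /* The loop above always advances the output pointer by four, so rewind it when the
           last predicated store wrote fewer than four channels. */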
        const int tail_ch = input_ch & 0x3;
        if (tail_ch != 0)
        {
            out -= (4 - tail_ch);
        }
    }

#else // ARM_MATH_DSP

    /* Run the following code on cores with the DSP extension */
    int16_t *const col_buffer_start = buffer_a;
    int16_t *col_buffer = col_buffer_start;
    const int64_t *const bias_start_pos = bias;
    const int32_t *const out_mult_start_pos = output_mult;
    const int32_t *const out_shift_start_pos = output_shift;
    uint16_t row_count;
    uint16_t row_shift;
    int32_t result;

    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
            for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;

                /* Out of bounds is only handled up front for the y axis, as it provides a contiguous
                   zero'ing opportunity that the x axis does not */
                const int ker_y_start = MAX(0, -base_idx_y);
                /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
                const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);

                int32_t index = 0;
                if (ker_y_start != 0)
                {
                    memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
                    index += (kernel_x * input_ch) * ker_y_start;
                }

                for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                {
                    const int32_t idx_y = base_idx_y + i_ker_y;

                    for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
                    {
                        const int32_t idx_x = base_idx_x + i_ker_x;

                        if (idx_x < 0 || idx_x >= input_x)
                        {
                            memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
                        }
                        else
                        {
                            arm_memcpy_q15(&col_buffer[index],
                                           input + (idx_y * input_x + idx_x) * input_ch,
                                           input_ch * sizeof(int16_t));
                        }
                        index += input_ch;
                    }
                }

                const int diff = kernel_y - ker_y_end;
                if (diff != 0)
                {
                    memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
                }

                row_count = output_ch / 4;
                row_shift = 0;
                bias = bias_start_pos;
                output_mult = out_mult_start_pos;
                output_shift = out_shift_start_pos;

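                /* Compute four output channels per iteration; each accumulates kernel_x * kernel_y
                   products from the column buffer. */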
                while (row_count)
                {
                    int32_t sum_1 = 0;
                    int32_t sum_2 = 0;
                    int32_t sum_3 = 0;
                    int32_t sum_4 = 0;

                    int32_t output_mult_1 = REDUCE_MULTIPLIER(output_mult[0]);
                    int32_t output_mult_2 = REDUCE_MULTIPLIER(output_mult[1]);
                    int32_t output_mult_3 = REDUCE_MULTIPLIER(output_mult[2]);
                    int32_t output_mult_4 = REDUCE_MULTIPLIER(output_mult[3]);
                    output_mult += 4;

                    uint16_t col_count = (kernel_x * kernel_y) / 2;
                    int16_t *col_pos = col_buffer_start + row_shift;
                    const int8_t *row_pos = kernel + row_shift;
                    row_shift += 4;

                    while (col_count)
                    {
                        /* General idea is to read 4 + 4 (input, kernel) pairs and re-arrange them in the
                           right order to use in SMLAD instructions. One run of this loop produces 4 partial
                           outputs with 8 MACs. */
                        int32_t row_a1, row_a2, row_b1, row_b2, col_a, row_c, col_b, col_c;

                        /* Read 4 weights */
                        row_b1 = arm_nn_read_s8x4(row_pos);
                        row_a1 = arm_nn_read_s8x4(row_pos + input_ch);
                        col_a = arm_nn_read_s16x2(col_pos);
                        col_b = arm_nn_read_s16x2(col_pos + input_ch);

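                        /* SXTB16 sign-extends bytes 0 and 2 of the packed word; SXTB16(ROR(x, 8))
                           extracts bytes 1 and 3, yielding 16-bit weight pairs for SMLAD. */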
                        row_a2 = SXTB16(row_b1);
                        row_b1 = SXTB16(ROR(row_b1, 8));

                        row_b2 = SXTB16(row_a1);
                        row_a1 = SXTB16(ROR(row_a1, 8));

                        col_c = PKHBT(col_b, col_a, 16);
                        col_a = PKHTB(col_b, col_a, 16);
                        row_c = PKHBT(row_b2, row_a2, 16);
                        sum_1 = SMLAD(col_c, row_c, sum_1);

                        row_c = PKHBT(row_b1, row_a1, 16);
                        sum_2 = SMLAD(col_a, row_c, sum_2);

                        col_a = arm_nn_read_s16x2(col_pos + 2);
                        col_b = arm_nn_read_s16x2(col_pos + input_ch + 2);

                        col_c = PKHBT(col_b, col_a, 16);
                        col_a = PKHTB(col_b, col_a, 16);
                        row_c = PKHTB(row_a2, row_b2, 16);
                        sum_3 = SMLAD(col_c, row_c, sum_3);

                        row_c = PKHTB(row_a1, row_b1, 16);
                        sum_4 = SMLAD(col_a, row_c, sum_4);

                        row_pos += input_ch << 1;
                        col_pos += input_ch << 1;
                        col_count--;
                    }

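                    /* Handle the last kernel element when kernel_x * kernel_y is odd. */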
                    col_count = (kernel_x * kernel_y) & 0x1;
                    while (col_count)
                    {
                        sum_1 += row_pos[0] * col_pos[0];
                        sum_2 += row_pos[1] * col_pos[1];
                        sum_3 += row_pos[2] * col_pos[2];
                        sum_4 += row_pos[3] * col_pos[3];

                        row_pos += input_ch;
                        col_pos += input_ch;

                        col_count--;
                    }

                    int64_t acc_1 = sum_1;
                    int64_t acc_2 = sum_2;
                    int64_t acc_3 = sum_3;
                    int64_t acc_4 = sum_4;

                    if (bias)
                    {
                        acc_1 += *bias++;
                        acc_2 += *bias++;
                        acc_3 += *bias++;
                        acc_4 += *bias++;
                    }

                    result = arm_nn_requantize_s64(acc_1, output_mult_1, *output_shift++);
                    result = MAX(result, output_activation_min);
                    result = MIN(result, output_activation_max);
                    *output++ = (int16_t)result;

                    result = arm_nn_requantize_s64(acc_2, output_mult_2, *output_shift++);
                    result = MAX(result, output_activation_min);
                    result = MIN(result, output_activation_max);
                    *output++ = (int16_t)result;

                    result = arm_nn_requantize_s64(acc_3, output_mult_3, *output_shift++);
                    result = MAX(result, output_activation_min);
                    result = MIN(result, output_activation_max);
                    *output++ = (int16_t)result;

                    result = arm_nn_requantize_s64(acc_4, output_mult_4, *output_shift++);
                    result = MAX(result, output_activation_min);
                    result = MIN(result, output_activation_max);
                    *output++ = (int16_t)result;

                    row_count--;
                }

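                /* Process the remaining one to three output channels, one at a time. */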
                row_count = output_ch & 0x3;
                while (row_count)
                {
                    int16_t *col_pos = col_buffer_start + row_shift;
                    const int8_t *row_pos = kernel + row_shift;
                    int32_t sum = 0;
                    const uint16_t col_count = (kernel_x * kernel_y);
                    row_shift += 1;

                    for (int i = 0; i < col_count; i++)
                    {
                        sum += row_pos[i * input_ch] * col_pos[i * input_ch];
                    }
                    int64_t acc = sum;
                    if (bias)
                    {
                        acc += *bias++;
                    }
                    result = arm_nn_requantize_s64(acc, REDUCE_MULTIPLIER(*output_mult), *output_shift++);
                    output_mult++;
                    result = MAX(result, output_activation_min);
                    result = MIN(result, output_activation_max);
                    *output++ = (int16_t)result;

                    row_count--;
                }
                // Reset the column buffer pointer for the next output element
                col_buffer = col_buffer_start;
            }
        }

        /* Advance to the next batch */
        input += (input_x * input_y * input_ch);
    }
#endif
#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    return arm_depthwise_conv_s16(ctx,
                                  dw_conv_params,
                                  quant_params,
                                  input_dims,
                                  input,
                                  filter_dims,
                                  kernel,
                                  bias_dims,
                                  bias,
                                  output_dims,
                                  output);
#endif /* ARM_MATH_MVEI | ARM_MATH_DSP */

    /* Return to application */
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of NNConv group
 */