1 /*
2 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_depthwise_conv_s4.c
22 * Description: s4 version of depthwise convolution.
23 *
24 * $Date: 13 February 2024
25 * $Revision: V.1.1.0
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
34 /**
35 * @ingroup Public
36 */
37
38 /**
39 * @addtogroup NNConv
40 * @{
41 */
42
/*
 * Sign-extends the low nibble (bits 3..0) of a packed s4 byte.
 * Uses unsigned masking instead of `(int8_t)(b << 4) >> 4`, because left-shifting a
 * negative int is undefined behaviour in C.
 */
static inline int32_t dw_s4_unpack_low(const int8_t packed)
{
    return (int32_t)(((uint32_t)(uint8_t)packed & 0x0FU) ^ 0x08U) - 8;
}

/* Sign-extends the high nibble (bits 7..4) of a packed s4 byte. */
static inline int32_t dw_s4_unpack_high(const int8_t packed)
{
    return (int32_t)(((uint32_t)(uint8_t)packed >> 4) ^ 0x08U) - 8;
}

/*
 * Reads the s4 element at flat element index elem_idx from a buffer that packs two
 * elements per byte, the even-indexed element in the low nibble.
 */
static inline int32_t dw_s4_read(const int8_t *packed, const int32_t elem_idx)
{
    const int8_t byte = packed[elem_idx >> 1];
    return (elem_idx & 1) ? dw_s4_unpack_high(byte) : dw_s4_unpack_low(byte);
}

/*
 * Computes the half-open range [*start, *end) of kernel taps along one spatial axis whose
 * input coordinate base_idx + dilation * tap falls inside [0, input_size).
 * The range is empty (*end <= *start) when no tap reaches the input.
 */
static inline void dw_s4_kernel_range(const int32_t base_idx,
                                      const int32_t input_size,
                                      const int32_t kernel_size,
                                      const int32_t dilation,
                                      int32_t *start,
                                      int32_t *end)
{
    if (dilation > 1)
    {
        /* Ceiling divisions: first tap at or after coordinate 0, first tap at or after input_size. */
        const int32_t first_inside = (-base_idx + dilation - 1) / dilation;
        const int32_t first_outside = (input_size - base_idx + dilation - 1) / dilation;
        *start = MAX(0, first_inside);
        *end = MIN(kernel_size, first_outside);
    }
    else
    {
        *start = MAX(0, -base_idx);
        *end = MIN(kernel_size, input_size - base_idx);
    }
}

/* Requantizes an accumulator, adds the output offset and clamps it to the activation range. */
static inline int8_t dw_s4_requantize(const int32_t acc,
                                      const int32_t mult,
                                      const int32_t shift,
                                      const int32_t out_offset,
                                      const int32_t act_min,
                                      const int32_t act_max)
{
    int32_t result = arm_nn_requantize(acc, mult, shift) + out_offset;
    result = MAX(result, act_min);
    result = MIN(result, act_max);
    return (int8_t)result;
}

/*
 * Reference s4 depthwise convolution.
 *
 * Layouts, as used by the indexing below:
 *  - input:  [input_batches][input_y][input_x][input_ch], int8.
 *  - kernel: [kernel_y][kernel_x][input_ch * ch_mult] s4 values packed two per byte in that
 *            flattened order, the even-indexed element in the low nibble. When
 *            input_ch * ch_mult is odd, consecutive kernel taps do not start on byte boundaries.
 *  - output: [input_batches][output_y][output_x][input_ch * ch_mult], written sequentially.
 *  - bias, output_mult, output_shift: one entry per output channel; bias may be NULL.
 * Output channel (i_ch * ch_mult + m) is computed from input channel i_ch.
 *
 * output_ch is not read; the number of output channels is taken to be input_ch * ch_mult.
 */
static void depthwise_conv_s4_generic(const int8_t *input,
                                      const int32_t input_batches,
                                      const int32_t input_x,
                                      const int32_t input_y,
                                      const int32_t input_ch,
                                      const int8_t *kernel,
                                      const int32_t output_ch,
                                      const int32_t ch_mult,
                                      const int32_t kernel_x,
                                      const int32_t kernel_y,
                                      const int32_t pad_x,
                                      const int32_t pad_y,
                                      const int32_t stride_x,
                                      const int32_t stride_y,
                                      const int32_t *bias,
                                      int8_t *output,
                                      const int32_t *output_shift,
                                      const int32_t *output_mult,
                                      const int32_t output_x,
                                      const int32_t output_y,
                                      const int32_t output_offset,
                                      const int32_t input_offset,
                                      const int32_t output_activation_min,
                                      const int32_t output_activation_max,
                                      const int32_t dilation_x,
                                      const int32_t dilation_y)
{
    (void)output_ch;

    const int32_t total_out_ch = input_ch * ch_mult;
    /* Bytes per kernel tap; exact only when total_out_ch is even (used by the paired paths only). */
    const int32_t tap_bytes = total_out_ch >> 1;

    /*
     * Output channels (2k, 2k + 1) share one kernel byte. They can be accumulated together when
     * every tap is byte aligned (total_out_ch even) and their input channels are predictable:
     *  - ch_mult == 1 with even input_ch: they read input channels 2k and 2k + 1;
     *  - even ch_mult: both read the same input channel.
     * Everything else goes through the per-element path.
     */
    const int pair_over_input_ch = (ch_mult == 1) && !(input_ch % 2);
    const int pair_over_ch_mult = !(ch_mult % 2);

    int32_t i_out = 0;

    for (int32_t i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int32_t i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            const int32_t base_idx_y = (i_out_y * stride_y) - pad_y;
            int32_t ker_y_start;
            int32_t ker_y_end;
            dw_s4_kernel_range(base_idx_y, input_y, kernel_y, dilation_y, &ker_y_start, &ker_y_end);

            for (int32_t i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int32_t base_idx_x = (i_out_x * stride_x) - pad_x;
                int32_t ker_x_start;
                int32_t ker_x_end;
                dw_s4_kernel_range(base_idx_x, input_x, kernel_x, dilation_x, &ker_x_start, &ker_x_end);

                if (pair_over_input_ch)
                {
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch += 2)
                    {
                        int32_t acc_0 = bias ? bias[i_ch] : 0;
                        int32_t acc_1 = bias ? bias[i_ch + 1] : 0;

                        for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                        {
                            const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                            for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                            {
                                const int32_t idx_x = base_idx_x + dilation_x * i_ker_x;
                                const int32_t in_idx = (idx_y * input_x + idx_x) * input_ch + i_ch;
                                const int8_t packed =
                                    kernel[(i_ker_y * kernel_x + i_ker_x) * tap_bytes + (i_ch >> 1)];

                                acc_0 += (input[in_idx] + input_offset) * dw_s4_unpack_low(packed);
                                acc_1 += (input[in_idx + 1] + input_offset) * dw_s4_unpack_high(packed);
                            }
                        }

                        output[i_out++] = dw_s4_requantize(acc_0,
                                                           output_mult[i_ch],
                                                           output_shift[i_ch],
                                                           output_offset,
                                                           output_activation_min,
                                                           output_activation_max);
                        output[i_out++] = dw_s4_requantize(acc_1,
                                                           output_mult[i_ch + 1],
                                                           output_shift[i_ch + 1],
                                                           output_offset,
                                                           output_activation_min,
                                                           output_activation_max);
                    }
                }
                else if (pair_over_ch_mult)
                {
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch++)
                    {
                        for (int32_t i_mult = 0; i_mult < ch_mult; i_mult += 2)
                        {
                            /* Always even, so this channel and the next share one byte. */
                            const int32_t idx_out_ch = i_ch * ch_mult + i_mult;
                            int32_t acc_0 = bias ? bias[idx_out_ch] : 0;
                            int32_t acc_1 = bias ? bias[idx_out_ch + 1] : 0;

                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    const int32_t idx_x = base_idx_x + dilation_x * i_ker_x;
                                    const int32_t in_val =
                                        input[(idx_y * input_x + idx_x) * input_ch + i_ch] + input_offset;
                                    const int8_t packed =
                                        kernel[(i_ker_y * kernel_x + i_ker_x) * tap_bytes + (idx_out_ch >> 1)];

                                    acc_0 += in_val * dw_s4_unpack_low(packed);
                                    acc_1 += in_val * dw_s4_unpack_high(packed);
                                }
                            }

                            output[i_out++] = dw_s4_requantize(acc_0,
                                                               output_mult[idx_out_ch],
                                                               output_shift[idx_out_ch],
                                                               output_offset,
                                                               output_activation_min,
                                                               output_activation_max);
                            output[i_out++] = dw_s4_requantize(acc_1,
                                                               output_mult[idx_out_ch + 1],
                                                               output_shift[idx_out_ch + 1],
                                                               output_offset,
                                                               output_activation_min,
                                                               output_activation_max);
                        }
                    }
                }
                else
                {
                    /*
                     * Per-element path. Every kernel value is located by its flat element index,
                     * so it stays correct when taps are skipped at padded borders and when a tap's
                     * channels straddle a byte boundary (odd input_ch * ch_mult).
                     */
                    for (int32_t i_ch = 0; i_ch < input_ch; i_ch++)
                    {
                        for (int32_t i_mult = 0; i_mult < ch_mult; i_mult++)
                        {
                            const int32_t idx_out_ch = i_ch * ch_mult + i_mult;
                            int32_t acc = bias ? bias[idx_out_ch] : 0;

                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                const int32_t idx_y = base_idx_y + dilation_y * i_ker_y;
                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    const int32_t idx_x = base_idx_x + dilation_x * i_ker_x;
                                    const int32_t in_idx = (idx_y * input_x + idx_x) * input_ch + i_ch;
                                    const int32_t ker_elem =
                                        (i_ker_y * kernel_x + i_ker_x) * total_out_ch + idx_out_ch;

                                    acc += (input[in_idx] + input_offset) * dw_s4_read(kernel, ker_elem);
                                }
                            }

                            output[i_out++] = dw_s4_requantize(acc,
                                                               output_mult[idx_out_ch],
                                                               output_shift[idx_out_ch],
                                                               output_offset,
                                                               output_activation_min,
                                                               output_activation_max);
                        }
                    }
                }
            }
        }

        /* Advance to the next batch */
        input += (input_x * input_y * input_ch);
    }
}
484
/*
 * Basic s4 depthwise convolution function.
 *
 * Refer to the header file for details.
 * Optimization using DSP extension is not available for the generic case where channel multiplier is > 1.
 *
 */
arm_depthwise_conv_s4(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)492 arm_cmsis_nn_status arm_depthwise_conv_s4(const cmsis_nn_context *ctx,
493 const cmsis_nn_dw_conv_params *dw_conv_params,
494 const cmsis_nn_per_channel_quant_params *quant_params,
495 const cmsis_nn_dims *input_dims,
496 const int8_t *input,
497 const cmsis_nn_dims *filter_dims,
498 const int8_t *kernel,
499 const cmsis_nn_dims *bias_dims,
500 const int32_t *bias,
501 const cmsis_nn_dims *output_dims,
502 int8_t *output)
503 {
504 (void)bias_dims;
505 (void)ctx;
506
507 const int32_t dilation_x = dw_conv_params->dilation.w;
508 const int32_t dilation_y = dw_conv_params->dilation.h;
509 depthwise_conv_s4_generic(input,
510 input_dims->n,
511 input_dims->w,
512 input_dims->h,
513 input_dims->c,
514 kernel,
515 output_dims->c,
516 dw_conv_params->ch_mult,
517 filter_dims->w,
518 filter_dims->h,
519 dw_conv_params->padding.w,
520 dw_conv_params->padding.h,
521 dw_conv_params->stride.w,
522 dw_conv_params->stride.h,
523 bias,
524 output,
525 quant_params->shift,
526 quant_params->multiplier,
527 output_dims->w,
528 output_dims->h,
529 dw_conv_params->output_offset,
530 dw_conv_params->input_offset,
531 dw_conv_params->activation.min,
532 dw_conv_params->activation.max,
533 dilation_x,
534 dilation_y);
535 /* Return to application */
536 return ARM_CMSIS_NN_SUCCESS;
537 }
538
539 /**
540 * @} end of NNConv group
541 */
542