1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_s4.c
22  * Description:  s4 version of depthwise convolution.
23  *
24  * $Date:        13 February 2024
25  * $Revision:    V.1.1.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup Public
36  */
37 
38 /**
39  * @addtogroup NNConv
40  * @{
41  */
42 
/**
 * @brief Sign-extend the low 4 bits of a packed kernel byte.
 *
 * Works on the unsigned bit pattern so that no negative value is ever shifted.
 *
 * @param[in] packed  Byte holding two signed 4-bit weights
 * @return            Low nibble as a value in [-8, 7]
 */
static inline int32_t depthwise_conv_s4_low_nibble(const int8_t packed)
{
    const int32_t nibble = (uint8_t)packed & 0x0F;
    return (nibble ^ 0x08) - 0x08;
}

/**
 * @brief Sign-extend the high 4 bits of a packed kernel byte.
 *
 * @param[in] packed  Byte holding two signed 4-bit weights
 * @return            High nibble as a value in [-8, 7]
 */
static inline int32_t depthwise_conv_s4_high_nibble(const int8_t packed)
{
    const int32_t nibble = (uint8_t)packed >> 4;
    return (nibble ^ 0x08) - 0x08;
}

/**
 * @brief Compute the half-open range of kernel taps that land inside the input along one axis.
 *
 * @param[in]  base_idx    Input coordinate addressed by kernel tap 0 (may be negative because of padding)
 * @param[in]  input_len   Input extent along the axis
 * @param[in]  kernel_len  Kernel extent along the axis
 * @param[in]  dilation    Dilation along the axis
 * @param[out] ker_start   First kernel tap to visit
 * @param[out] ker_end     One past the last kernel tap to visit
 */
static inline void depthwise_conv_s4_kernel_range(const int32_t base_idx,
                                                  const int32_t input_len,
                                                  const int32_t kernel_len,
                                                  const int32_t dilation,
                                                  int32_t *ker_start,
                                                  int32_t *ker_end)
{
    if (dilation > 1)
    {
        const int32_t start_max = (-base_idx + dilation - 1) / dilation;
        const int32_t end_min = (input_len - base_idx + dilation - 1) / dilation;
        *ker_start = MAX(0, start_max);
        *ker_end = MIN(kernel_len, end_min);
    }
    else
    {
        *ker_start = MAX(0, -base_idx);
        *ker_end = MIN(kernel_len, input_len - base_idx);
    }
}

/**
 * @brief Requantize an accumulator, add the output offset and clamp it to the activation range.
 *
 * @param[in] acc             Accumulated sum (bias included)
 * @param[in] mult            Multiplier passed to arm_nn_requantize for this output channel
 * @param[in] shift           Shift passed to arm_nn_requantize for this output channel
 * @param[in] output_offset   Offset added after requantization
 * @param[in] activation_min  Lower clamp bound
 * @param[in] activation_max  Upper clamp bound
 * @return                    Value to store in the output tensor
 */
static inline int8_t depthwise_conv_s4_requantize(int32_t acc,
                                                  const int32_t mult,
                                                  const int32_t shift,
                                                  const int32_t output_offset,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max)
{
    acc = arm_nn_requantize(acc, mult, shift);
    acc += output_offset;
    acc = MAX(acc, activation_min);
    acc = MIN(acc, activation_max);
    return (int8_t)acc;
}

/**
 * @brief Generic (non-optimized) depthwise convolution with 4-bit packed weights.
 *
 * The input is read as input[((y * input_x) + x) * input_ch + ch] within one batch and the output is written
 * sequentially, one int8 value per output channel per output pixel. Each kernel byte carries two signed 4-bit
 * weights: the low nibble is consumed before the high nibble.
 *
 * Four code paths are used per output pixel:
 *  - input_ch even, ch_mult == 1     : two neighbouring input channels per iteration (one kernel byte)
 *  - input_ch even, ch_mult odd (>1) : one output channel per iteration
 *  - input_ch even, ch_mult even     : two output channels of the same input channel per iteration
 *  - input_ch odd                    : one output channel per iteration with running nibble tracking
 *
 * @param[in]  input                  Input tensor
 * @param[in]  input_batches          Number of batches
 * @param[in]  input_x                Input width
 * @param[in]  input_y                Input height
 * @param[in]  input_ch               Number of input channels
 * @param[in]  kernel                 Packed 4-bit kernel
 * @param[in]  output_ch              Unused
 * @param[in]  ch_mult                Channel multiplier
 * @param[in]  kernel_x               Kernel width
 * @param[in]  kernel_y               Kernel height
 * @param[in]  pad_x                  Padding along x
 * @param[in]  pad_y                  Padding along y
 * @param[in]  stride_x               Stride along x
 * @param[in]  stride_y               Stride along y
 * @param[in]  bias                   Per output channel bias, may be NULL
 * @param[out] output                 Output tensor
 * @param[in]  output_shift           Per output channel requantization shift
 * @param[in]  output_mult            Per output channel requantization multiplier
 * @param[in]  output_x               Output width
 * @param[in]  output_y               Output height
 * @param[in]  output_offset          Offset added to every requantized result
 * @param[in]  input_offset           Offset added to every input value before multiplication
 * @param[in]  output_activation_min  Lower clamp bound of the output
 * @param[in]  output_activation_max  Upper clamp bound of the output
 * @param[in]  dilation_x             Dilation along x
 * @param[in]  dilation_y             Dilation along y
 */
static void depthwise_conv_s4_generic(const int8_t *input,
                                      const int32_t input_batches,
                                      const int32_t input_x,
                                      const int32_t input_y,
                                      const int32_t input_ch,
                                      const int8_t *kernel,
                                      const int32_t output_ch,
                                      const int32_t ch_mult,
                                      const int32_t kernel_x,
                                      const int32_t kernel_y,
                                      const int32_t pad_x,
                                      const int32_t pad_y,
                                      const int32_t stride_x,
                                      const int32_t stride_y,
                                      const int32_t *bias,
                                      int8_t *output,
                                      const int32_t *output_shift,
                                      const int32_t *output_mult,
                                      const int32_t output_x,
                                      const int32_t output_y,
                                      const int32_t output_offset,
                                      const int32_t input_offset,
                                      const int32_t output_activation_min,
                                      const int32_t output_activation_max,
                                      const int32_t dilation_x,
                                      const int32_t dilation_y)

{
    (void)output_ch;
    int i_out = 0;

    const int32_t kernel_index_offset = input_ch >> 1;
    /* Distance in kernel bytes between two horizontally adjacent kernel taps */
    const int32_t kernel_tap_stride = kernel_index_offset * ch_mult;
    /* Distance in input elements between two horizontally adjacent (dilated) kernel taps */
    const int32_t input_tap_stride = dilation_x * input_ch;
    const int32_t input_batch_size = input_x * input_y * input_ch;
    const int input_ch_is_even = !(input_ch % 2);

    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
        {
            /* The tap window only depends on the output position, so it is computed once per row / pixel
               instead of once per channel. */
            const int32_t base_idx_y = (i_out_y * stride_y) - pad_y;
            int32_t ker_y_start;
            int32_t ker_y_end;
            depthwise_conv_s4_kernel_range(base_idx_y, input_y, kernel_y, dilation_y, &ker_y_start, &ker_y_end);
            const int32_t first_idx_y = base_idx_y + dilation_y * ker_y_start;

            for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
            {
                const int32_t base_idx_x = (i_out_x * stride_x) - pad_x;
                int32_t ker_x_start;
                int32_t ker_x_end;
                depthwise_conv_s4_kernel_range(base_idx_x, input_x, kernel_x, dilation_x, &ker_x_start, &ker_x_end);
                const int32_t first_idx_x = base_idx_x + dilation_x * ker_x_start;

                if (!input_ch_is_even)
                {
                    /* Odd number of input channels: the nibble to read and the extra byte offset are tracked
                       as running state across taps and output channels.
                       NOTE(review): this state advances once per visited tap, and the update after the tap
                       loops uses the full kernel_x * kernel_y count - confirm the behaviour when padding
                       causes taps to be skipped. */
                    int idx_out_ch_s4 = 0;
                    int get_low_nibble = 1;

                    for (int i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
                    {
                        for (int i_ch_mult = 0; i_ch_mult < ch_mult; i_ch_mult++)
                        {
                            const int idx_out_ch = i_ch_mult + i_input_ch * ch_mult;
                            if (idx_out_ch && (idx_out_ch % 2 == 0))
                            {
                                idx_out_ch_s4++;
                            }

                            int32_t kernel_index_offset_uneven = 0;
                            int32_t acc_0 = bias ? bias[idx_out_ch] : 0;

                            int32_t idx_y = first_idx_y;
                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                int32_t idx_0 = (idx_y * input_x + first_idx_x) * input_ch + i_input_ch;
                                int32_t ker_idx_0 = (i_ker_y * kernel_x + ker_x_start) * kernel_tap_stride +
                                    idx_out_ch_s4 + kernel_index_offset_uneven;

                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    int32_t ker_val;

                                    if (get_low_nibble)
                                    {
                                        get_low_nibble = 0;
                                        ker_val = depthwise_conv_s4_low_nibble(kernel[ker_idx_0]);
                                    }
                                    else
                                    {
                                        ker_val = depthwise_conv_s4_high_nibble(kernel[ker_idx_0]);
                                        get_low_nibble = 1;
                                        kernel_index_offset_uneven++;
                                    }

                                    acc_0 += (input[idx_0] + input_offset) * ker_val;
                                    idx_0 += input_tap_stride;
                                    /* One extra byte is skipped every time a high nibble has just been consumed */
                                    ker_idx_0 += kernel_tap_stride + get_low_nibble;
                                }
                                idx_y += dilation_y;
                            }

                            /* The nibble flag is flipped for the next output channel only when the kernel has
                               an even number of taps. */
                            if ((kernel_x * kernel_y) % 2 == 0)
                            {
                                get_low_nibble = !get_low_nibble;
                            }

                            output[i_out++] = depthwise_conv_s4_requantize(acc_0,
                                                                           output_mult[idx_out_ch],
                                                                           output_shift[idx_out_ch],
                                                                           output_offset,
                                                                           output_activation_min,
                                                                           output_activation_max);
                        }
                    }
                }
                else if (ch_mult == 1)
                {
                    /* If ch_mult is 1 we can process 2 outputs at a time by doing 2 input_ch iterations */
                    for (int i_input_ch = 0; i_input_ch < input_ch; i_input_ch += 2)
                    {
                        const int32_t idx_out_ch_s4 = i_input_ch >> 1;
                        int32_t acc_0 = bias ? bias[i_input_ch] : 0;
                        int32_t acc_1 = bias ? bias[i_input_ch + 1] : 0;

                        int32_t idx_y = first_idx_y;
                        for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                        {
                            int32_t idx_0 = (idx_y * input_x + first_idx_x) * input_ch + i_input_ch;
                            /* kernel_tap_stride equals kernel_index_offset here because ch_mult is 1 */
                            int32_t ker_idx_0 =
                                (i_ker_y * kernel_x + ker_x_start) * kernel_tap_stride + idx_out_ch_s4;

                            for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                            {
                                const int8_t packed = kernel[ker_idx_0];

                                acc_0 += (input[idx_0] + input_offset) * depthwise_conv_s4_low_nibble(packed);
                                acc_1 += (input[idx_0 + 1] + input_offset) * depthwise_conv_s4_high_nibble(packed);

                                idx_0 += input_tap_stride;
                                ker_idx_0 += kernel_tap_stride;
                            }
                            idx_y += dilation_y;
                        }

                        output[i_out++] = depthwise_conv_s4_requantize(acc_0,
                                                                       output_mult[i_input_ch],
                                                                       output_shift[i_input_ch],
                                                                       output_offset,
                                                                       output_activation_min,
                                                                       output_activation_max);
                        output[i_out++] = depthwise_conv_s4_requantize(acc_1,
                                                                       output_mult[i_input_ch + 1],
                                                                       output_shift[i_input_ch + 1],
                                                                       output_offset,
                                                                       output_activation_min,
                                                                       output_activation_max);
                    }
                }
                else if (ch_mult % 2)
                {
                    /* If ch_mult is odd and greater than 1, we need to continue to process 1 output at a time */
                    for (int i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
                    {
                        for (int i_ch_mult = 0; i_ch_mult < ch_mult; i_ch_mult++)
                        {
                            /* Output channels are visited in increasing order starting at 0, so the kernel byte
                               and nibble follow directly from the channel index. */
                            const int idx_out_ch = i_ch_mult + i_input_ch * ch_mult;
                            const int32_t idx_out_ch_s4 = idx_out_ch >> 1;
                            const int get_low_nibble = !(idx_out_ch & 1);

                            int32_t acc_0 = bias ? bias[idx_out_ch] : 0;

                            int32_t idx_y = first_idx_y;
                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                int32_t idx_0 = (idx_y * input_x + first_idx_x) * input_ch + i_input_ch;
                                int32_t ker_idx_0 =
                                    (i_ker_y * kernel_x + ker_x_start) * kernel_tap_stride + idx_out_ch_s4;

                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    const int8_t packed = kernel[ker_idx_0];
                                    const int32_t ker_val0 = get_low_nibble ? depthwise_conv_s4_low_nibble(packed)
                                                                            : depthwise_conv_s4_high_nibble(packed);

                                    acc_0 += (input[idx_0] + input_offset) * ker_val0;

                                    idx_0 += input_tap_stride;
                                    ker_idx_0 += kernel_tap_stride;
                                }
                                idx_y += dilation_y;
                            }

                            output[i_out++] = depthwise_conv_s4_requantize(acc_0,
                                                                           output_mult[idx_out_ch],
                                                                           output_shift[idx_out_ch],
                                                                           output_offset,
                                                                           output_activation_min,
                                                                           output_activation_max);
                        }
                    }
                }
                else
                {
                    /* If ch_mult is even then we can do 2 outputs at a time by processing 2 ch_mult iterations */
                    for (int i_input_ch = 0; i_input_ch < input_ch; i_input_ch++)
                    {
                        for (int i_ch_mult = 0; i_ch_mult < ch_mult; i_ch_mult += 2)
                        {
                            /* idx_out_ch is always even here, so both outputs share one kernel byte per tap */
                            const int idx_out_ch = i_ch_mult + i_input_ch * ch_mult;
                            const int32_t idx_out_ch_s4 = idx_out_ch >> 1;

                            int32_t acc_0 = bias ? bias[idx_out_ch] : 0;
                            int32_t acc_1 = bias ? bias[idx_out_ch + 1] : 0;

                            int32_t idx_y = first_idx_y;
                            for (int32_t i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
                            {
                                int32_t idx_0 = (idx_y * input_x + first_idx_x) * input_ch + i_input_ch;
                                int32_t ker_idx_0 =
                                    (i_ker_y * kernel_x + ker_x_start) * kernel_tap_stride + idx_out_ch_s4;

                                for (int32_t i_ker_x = ker_x_start; i_ker_x < ker_x_end; i_ker_x++)
                                {
                                    const int8_t packed = kernel[ker_idx_0];
                                    const int32_t in_val = input[idx_0] + input_offset;

                                    acc_0 += in_val * depthwise_conv_s4_low_nibble(packed);
                                    acc_1 += in_val * depthwise_conv_s4_high_nibble(packed);

                                    idx_0 += input_tap_stride;
                                    ker_idx_0 += kernel_tap_stride;
                                }
                                idx_y += dilation_y;
                            }

                            output[i_out++] = depthwise_conv_s4_requantize(acc_0,
                                                                           output_mult[idx_out_ch],
                                                                           output_shift[idx_out_ch],
                                                                           output_offset,
                                                                           output_activation_min,
                                                                           output_activation_max);
                            output[i_out++] = depthwise_conv_s4_requantize(acc_1,
                                                                           output_mult[idx_out_ch + 1],
                                                                           output_shift[idx_out_ch + 1],
                                                                           output_offset,
                                                                           output_activation_min,
                                                                           output_activation_max);
                        }
                    }
                }
            }
        }

        /* Advance to the next batch */
        input += input_batch_size;
    }
}
484 
485 /*
486  *  Basic s4 depthwise convolution function.
487  *
488  *  Refer header file for details.
489  *  Optimization using DSP extension is not available for the generic case where channel multiplier is > 1.
490  *
491  */
arm_depthwise_conv_s4(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)492 arm_cmsis_nn_status arm_depthwise_conv_s4(const cmsis_nn_context *ctx,
493                                           const cmsis_nn_dw_conv_params *dw_conv_params,
494                                           const cmsis_nn_per_channel_quant_params *quant_params,
495                                           const cmsis_nn_dims *input_dims,
496                                           const int8_t *input,
497                                           const cmsis_nn_dims *filter_dims,
498                                           const int8_t *kernel,
499                                           const cmsis_nn_dims *bias_dims,
500                                           const int32_t *bias,
501                                           const cmsis_nn_dims *output_dims,
502                                           int8_t *output)
503 {
504     (void)bias_dims;
505     (void)ctx;
506 
507     const int32_t dilation_x = dw_conv_params->dilation.w;
508     const int32_t dilation_y = dw_conv_params->dilation.h;
509     depthwise_conv_s4_generic(input,
510                               input_dims->n,
511                               input_dims->w,
512                               input_dims->h,
513                               input_dims->c,
514                               kernel,
515                               output_dims->c,
516                               dw_conv_params->ch_mult,
517                               filter_dims->w,
518                               filter_dims->h,
519                               dw_conv_params->padding.w,
520                               dw_conv_params->padding.h,
521                               dw_conv_params->stride.w,
522                               dw_conv_params->stride.h,
523                               bias,
524                               output,
525                               quant_params->shift,
526                               quant_params->multiplier,
527                               output_dims->w,
528                               output_dims->h,
529                               dw_conv_params->output_offset,
530                               dw_conv_params->input_offset,
531                               dw_conv_params->activation.min,
532                               dw_conv_params->activation.max,
533                               dilation_x,
534                               dilation_y);
535     /* Return to application */
536     return ARM_CMSIS_NN_SUCCESS;
537 }
538 
539 /**
540  * @} end of NNConv group
541  */
542