1 /*
2  * SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_fast_s16.c
22  * Description:  Optimized s16 depthwise separable convolution function for
23  *               channel multiplier of 1.
24  *
25  * $Date:        19 March 2024
26  * $Revision:    V.1.4.0
27  *
28  * Target :  Arm(R) M-Profile Architecture
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup Public
37  */
38 
39 /**
40  * @addtogroup NNConv
41  * @{
42  */
43 
44 /*
45  * Optimized s16 depthwise convolution function with constraint that in_channel equals out_channel
46  *
47  *  Refer to the prototype in the header file for details.
48  *
49  */
50 
arm_depthwise_conv_fast_s16(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int16_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int64_t * bias,const cmsis_nn_dims * output_dims,int16_t * output)51 arm_cmsis_nn_status arm_depthwise_conv_fast_s16(const cmsis_nn_context *ctx,
52                                                 const cmsis_nn_dw_conv_params *dw_conv_params,
53                                                 const cmsis_nn_per_channel_quant_params *quant_params,
54                                                 const cmsis_nn_dims *input_dims,
55                                                 const int16_t *input,
56                                                 const cmsis_nn_dims *filter_dims,
57                                                 const int8_t *kernel,
58                                                 const cmsis_nn_dims *bias_dims,
59                                                 const int64_t *bias,
60                                                 const cmsis_nn_dims *output_dims,
61                                                 int16_t *output)
62 {
63     const int32_t input_ch = input_dims->c;
64     const int32_t output_ch = output_dims->c;
65 
66     /* Check input constraints input_ch == output_ch */
67     if (input_ch != output_ch)
68     {
69         return ARM_CMSIS_NN_ARG_ERROR;
70     }
71 
72     if (filter_dims->w * filter_dims->h >= MAX_COL_COUNT)
73     {
74         return ARM_CMSIS_NN_ARG_ERROR;
75     }
76 
77     if (ctx->buf == NULL && arm_depthwise_conv_fast_s16_get_buffer_size(input_dims, filter_dims) > 0)
78     {
79         return ARM_CMSIS_NN_ARG_ERROR;
80     }
81 
82 #if defined(ARM_MATH_DSP)
83     (void)bias_dims;
84     const int32_t input_x = input_dims->w;
85     const int32_t input_y = input_dims->h;
86     const int32_t input_batches = input_dims->n;
87     const int32_t kernel_x = filter_dims->w;
88     const int32_t kernel_y = filter_dims->h;
89     const int32_t pad_x = dw_conv_params->padding.w;
90     const int32_t pad_y = dw_conv_params->padding.h;
91     const int32_t stride_x = dw_conv_params->stride.w;
92     const int32_t stride_y = dw_conv_params->stride.h;
93     const int32_t *output_shift = quant_params->shift;
94     const int32_t *output_mult = quant_params->multiplier;
95     const int32_t output_x = output_dims->w;
96     const int32_t output_y = output_dims->h;
97     const int32_t output_activation_min = dw_conv_params->activation.min;
98     const int32_t output_activation_max = dw_conv_params->activation.max;
99     int16_t *buffer_a = (int16_t *)ctx->buf;
100 
101     #if defined(ARM_MATH_MVEI)
102     int16_t *lhs_buffer = buffer_a;
103     int16_t *out = output;
104     int buffer_count = 0;
105     const int32_t kernel_size = kernel_x * kernel_y;
106 
107     for (int i_batch = 0; i_batch < input_batches; i_batch++)
108     {
109         /* This part implements the im2col function */
110         for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
111         {
112             for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
113             {
114                 for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
115                 {
116                     for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
117                     {
118                         if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
119                         {
120                             memset(lhs_buffer, (int16_t)0, (uint32_t)(input_ch * sizeof(int16_t)));
121                         }
122                         else
123                         {
124                             arm_memcpy_q15(lhs_buffer,
125                                            (int16_t *)(input + (i_ker_y * input_x + i_ker_x) * input_ch),
126                                            (uint32_t)(input_ch * sizeof(int16_t)));
127                         }
128                         lhs_buffer += input_ch;
129                     }
130                 }
131                 buffer_count++;
132                 if (buffer_count == 4)
133                 {
134                     lhs_buffer = buffer_a;
135 
136                     out = arm_nn_depthwise_conv_nt_t_s16(lhs_buffer,
137                                                          kernel,
138                                                          input_ch,
139                                                          output_shift,
140                                                          output_mult,
141                                                          output_activation_min,
142                                                          output_activation_max,
143                                                          kernel_size,
144                                                          bias,
145                                                          out);
146                     buffer_count = 0;
147                 }
148             }
149         }
150         input += input_x * input_y * input_ch;
151     }
152 
153     /* Handle left over buffers */
154     lhs_buffer = buffer_a;
155     for (int i_buf = 0; i_buf < buffer_count; i_buf++)
156     {
157         int32_t loop_count = (input_ch + 3) / 4;
158         int32_t num_ch_to_process = input_ch;
159 
160         for (int i_loop_cnt = 0, offset = 0; i_loop_cnt < loop_count; num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
161         {
162             const int8_t *row_0 = kernel + offset;
163             const int16_t *col_0 = lhs_buffer + (kernel_size * input_ch * i_buf) + offset;
164 
165             int32x4_t out_0 = vdupq_n_s32(0);
166 
167             for (int i_ker = 0; i_ker < kernel_size; i_ker++)
168             {
169                 const int32x4_t ker_0 = vldrbq_s32(row_0);
170 
171                 int32x4_t ip_0 = vldrhq_s32(col_0);
172                 out_0 += vmulq_s32(ip_0, ker_0);
173 
174                 col_0 += input_ch;
175                 row_0 += input_ch;
176             }
177 
178             int64_t in_requantize_0 = (int64_t)out_0[0];
179             int64_t in_requantize_1 = (int64_t)out_0[1];
180             int64_t in_requantize_2 = (int64_t)out_0[2];
181             int64_t in_requantize_3 = (int64_t)out_0[3];
182 
183             if (bias)
184             {
185                 in_requantize_0 += bias[offset];
186                 in_requantize_1 += bias[offset + 1];
187                 in_requantize_2 += bias[offset + 2];
188                 in_requantize_3 += bias[offset + 3];
189             }
190 
191             int32_t reduced_multiplier_0 = REDUCE_MULTIPLIER(output_mult[offset]);
192             int32_t reduced_multiplier_1 = REDUCE_MULTIPLIER(output_mult[offset + 1]);
193             int32_t reduced_multiplier_2 = REDUCE_MULTIPLIER(output_mult[offset + 2]);
194             int32_t reduced_multiplier_3 = REDUCE_MULTIPLIER(output_mult[offset + 3]);
195 
196             out_0[0] = arm_nn_requantize_s64(in_requantize_0, reduced_multiplier_0, output_shift[offset]);
197             out_0[1] = arm_nn_requantize_s64(in_requantize_1, reduced_multiplier_1, output_shift[offset + 1]);
198             out_0[2] = arm_nn_requantize_s64(in_requantize_2, reduced_multiplier_2, output_shift[offset + 2]);
199             out_0[3] = arm_nn_requantize_s64(in_requantize_3, reduced_multiplier_3, output_shift[offset + 3]);
200 
201             out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
202             out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));
203 
204             mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
205             vstrhq_p_s32(out, out_0, p);
206 
207             out += 4;
208         }
209 
210         const int tail_ch = input_ch & 0x3;
211         if (tail_ch != 0)
212         {
213             out -= (4 - tail_ch);
214         }
215     }
216 
217     #else // ARM_MATH_DSP
218 
219     /* Run the following code in cores using DSP extension */
220     int16_t *const col_buffer_start = buffer_a;
221     int16_t *col_buffer = col_buffer_start;
222     const int64_t *const bias_start_pos = bias;
223     const int32_t *const out_mult_start_pos = output_mult;
224     const int32_t *const out_shift_start_pos = output_shift;
225     uint16_t row_count;
226     uint16_t row_shift;
227     int32_t result;
228 
229     for (int i_batch = 0; i_batch < input_batches; i_batch++)
230     {
231         for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
232         {
233             const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
234             for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
235             {
236                 const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;
237 
238                 /* Out of bounds is only considered for the y axis as it provides a contiguous zero'ing opportunity than
239                    along the x axis */
240                 const int ker_y_start = MAX(0, -base_idx_y);
241                 /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
242                 const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);
243 
244                 int32_t index = 0;
245                 if (ker_y_start != 0)
246                 {
247                     memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
248                     index += (kernel_x * input_ch) * ker_y_start;
249                 }
250 
251                 for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
252                 {
253                     const int32_t idx_y = base_idx_y + i_ker_y;
254 
255                     for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
256                     {
257                         const int32_t idx_x = base_idx_x + i_ker_x;
258 
259                         if (idx_x < 0 || idx_x >= input_x)
260                         {
261                             memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
262                         }
263                         else
264                         {
265                             arm_memcpy_q15(&col_buffer[index],
266                                            input + (idx_y * input_x + idx_x) * input_ch,
267                                            input_ch * sizeof(int16_t));
268                         }
269                         index += input_ch;
270                     }
271                 }
272 
273                 const int diff = kernel_y - ker_y_end;
274                 if (diff != 0)
275                 {
276                     memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
277                 }
278 
279                 row_count = output_ch / 4;
280                 row_shift = 0;
281                 bias = bias_start_pos;
282                 output_mult = out_mult_start_pos;
283                 output_shift = out_shift_start_pos;
284 
285                 while (row_count)
286                 {
287                     int32_t sum_1 = 0;
288                     int32_t sum_2 = 0;
289                     int32_t sum_3 = 0;
290                     int32_t sum_4 = 0;
291 
292                     int32_t output_mult_1 = REDUCE_MULTIPLIER(output_mult[0]);
293                     int32_t output_mult_2 = REDUCE_MULTIPLIER(output_mult[1]);
294                     int32_t output_mult_3 = REDUCE_MULTIPLIER(output_mult[2]);
295                     int32_t output_mult_4 = REDUCE_MULTIPLIER(output_mult[3]);
296                     output_mult += 4;
297 
298                     uint16_t col_count = (kernel_x * kernel_y) / 2;
299                     int16_t *col_pos = col_buffer_start + row_shift;
300                     const int8_t *row_pos = kernel + row_shift;
301                     row_shift += 4;
302 
303                     while (col_count)
304                     {
305                         /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
306                         use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
307                         int32_t row_a1, row_a2, row_b1, row_b2, col_a, row_c, col_b, col_c;
308 
309                         /* Read 4 weights */
310                         row_b1 = arm_nn_read_s8x4(row_pos);
311                         row_a1 = arm_nn_read_s8x4(row_pos + input_ch);
312                         col_a = arm_nn_read_s16x2(col_pos);
313                         col_b = arm_nn_read_s16x2(col_pos + input_ch);
314 
315                         row_a2 = SXTB16(row_b1);
316                         row_b1 = SXTB16(ROR(row_b1, 8));
317 
318                         row_b2 = SXTB16(row_a1);
319                         row_a1 = SXTB16(ROR(row_a1, 8));
320 
321                         col_c = PKHBT(col_b, col_a, 16);
322                         col_a = PKHTB(col_b, col_a, 16);
323                         row_c = PKHBT(row_b2, row_a2, 16);
324                         sum_1 = SMLAD(col_c, row_c, sum_1);
325 
326                         row_c = PKHBT(row_b1, row_a1, 16);
327                         sum_2 = SMLAD(col_a, row_c, sum_2);
328 
329                         col_a = arm_nn_read_s16x2(col_pos + 2);
330                         col_b = arm_nn_read_s16x2(col_pos + input_ch + 2);
331 
332                         col_c = PKHBT(col_b, col_a, 16);
333                         col_a = PKHTB(col_b, col_a, 16);
334                         row_c = PKHTB(row_a2, row_b2, 16);
335                         sum_3 = SMLAD(col_c, row_c, sum_3);
336 
337                         row_c = PKHTB(row_a1, row_b1, 16);
338                         sum_4 = SMLAD(col_a, row_c, sum_4);
339 
340                         row_pos += input_ch << 1;
341                         col_pos += input_ch << 1;
342                         col_count--;
343                     }
344 
345                     col_count = (kernel_x * kernel_y) & 0x1;
346                     while (col_count)
347                     {
348                         sum_1 += row_pos[0] * col_pos[0];
349                         sum_2 += row_pos[1] * col_pos[1];
350                         sum_3 += row_pos[2] * col_pos[2];
351                         sum_4 += row_pos[3] * col_pos[3];
352 
353                         row_pos += input_ch;
354                         col_pos += input_ch;
355 
356                         col_count--;
357                     }
358 
359                     int64_t acc_1 = sum_1;
360                     int64_t acc_2 = sum_2;
361                     int64_t acc_3 = sum_3;
362                     int64_t acc_4 = sum_4;
363 
364                     if (bias)
365                     {
366                         acc_1 += *bias++;
367                         acc_2 += *bias++;
368                         acc_3 += *bias++;
369                         acc_4 += *bias++;
370                     }
371 
372                     result = arm_nn_requantize_s64(acc_1, output_mult_1, *output_shift++);
373                     result = MAX(result, output_activation_min);
374                     result = MIN(result, output_activation_max);
375                     *output++ = (int16_t)result;
376 
377                     result = arm_nn_requantize_s64(acc_2, output_mult_2, *output_shift++);
378                     result = MAX(result, output_activation_min);
379                     result = MIN(result, output_activation_max);
380                     *output++ = (int16_t)result;
381 
382                     result = arm_nn_requantize_s64(acc_3, output_mult_3, *output_shift++);
383                     result = MAX(result, output_activation_min);
384                     result = MIN(result, output_activation_max);
385                     *output++ = (int16_t)result;
386 
387                     result = arm_nn_requantize_s64(acc_4, output_mult_4, *output_shift++);
388                     result = MAX(result, output_activation_min);
389                     result = MIN(result, output_activation_max);
390                     *output++ = (int16_t)result;
391 
392                     row_count--;
393                 }
394 
395                 row_count = output_ch & 0x3;
396                 while (row_count)
397                 {
398                     int16_t *col_pos = col_buffer_start + row_shift;
399                     const int8_t *row_pos = kernel + row_shift;
400                     int32_t sum = 0;
401                     const uint16_t col_count = (kernel_x * kernel_y);
402                     row_shift += 1;
403 
404                     for (int i = 0; i < col_count; i++)
405                     {
406                         sum += row_pos[i * input_ch] * col_pos[i * input_ch];
407                     }
408                     int64_t acc = sum;
409                     if (bias)
410                     {
411                         acc += *bias++;
412                     }
413                     result = arm_nn_requantize_s64(acc, REDUCE_MULTIPLIER(*output_mult), *output_shift++);
414                     output_mult++;
415                     result = MAX(result, output_activation_min);
416                     result = MIN(result, output_activation_max);
417                     *output++ = (int16_t)result;
418 
419                     row_count--;
420                 }
421                 // clear counter and pointers
422                 col_buffer = col_buffer_start;
423             }
424         }
425 
426         /* Advance to the next batch */
427         input += (input_x * input_y * input_ch);
428     }
429     #endif
430 #else
431     /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
432     return arm_depthwise_conv_s16(ctx,
433                                   dw_conv_params,
434                                   quant_params,
435                                   input_dims,
436                                   input,
437                                   filter_dims,
438                                   kernel,
439                                   bias_dims,
440                                   bias,
441                                   output_dims,
442                                   output);
443 #endif /* ARM_MATH_MVEI | ARM_MATH_DSP */
444 
445     /* Return to application */
446     return ARM_CMSIS_NN_SUCCESS;
447 }
448 
449 /**
450  * @} end of NNConv group
451  */
452