1 /*
2  * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_s8_opt.c
22  * Description:  Optimized s8 depthwise separable convolution function for
23  *               channel multiplier of 1.
24  *
25  * $Date:        January 26, 2021
26  * $Revision:    V.2.0.3
27  *
28  * Target Processor:  Cortex-M CPUs
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup groupNN
37  */
38 
39 /**
40  * @addtogroup NNConv
41  * @{
42  */
43 
44 /*
45  * Optimized s8 depthwise convolution function with constraint that in_channel equals out_channel
46  *
47  *  Refer prototype header file for details.
48  *
49  */
50 
arm_status arm_depthwise_conv_s8_opt(const cmsis_nn_context *ctx,
                                     const cmsis_nn_dw_conv_params *dw_conv_params,
                                     const cmsis_nn_per_channel_quant_params *quant_params,
                                     const cmsis_nn_dims *input_dims,
                                     const q7_t *input,
                                     const cmsis_nn_dims *filter_dims,
                                     const q7_t *kernel,
                                     const cmsis_nn_dims *bias_dims,
                                     const int32_t *bias,
                                     const cmsis_nn_dims *output_dims,
                                     q7_t *output)
{

    const int32_t input_ch = input_dims->c;
    const int32_t output_ch = output_dims->c;

    /* Check input constraints input_ch == output_ch, i.e. channel multiplier must be 1.
       Other channel multipliers are handled by the generic arm_depthwise_conv_s8(). */
    if (input_ch != output_ch)
    {
        return ARM_MATH_SIZE_MISMATCH;
    }
#ifdef ARM_MATH_DSP
    const int32_t input_x = input_dims->w;
    const int32_t input_y = input_dims->h;
    const int32_t kernel_x = filter_dims->w;
    const int32_t kernel_y = filter_dims->h;
    const int32_t pad_x = dw_conv_params->padding.w;
    const int32_t pad_y = dw_conv_params->padding.h;
    const int32_t stride_x = dw_conv_params->stride.w;
    const int32_t stride_y = dw_conv_params->stride.h;
    /* Per-output-channel requantization parameters. */
    const int32_t *output_shift = quant_params->shift;
    const int32_t *output_mult = quant_params->multiplier;
    const int32_t output_x = output_dims->w;
    const int32_t output_y = output_dims->h;
    const int32_t output_offset = dw_conv_params->output_offset;
    const int32_t input_offset = dw_conv_params->input_offset;
    const int32_t output_activation_min = dw_conv_params->activation.min;
    const int32_t output_activation_max = dw_conv_params->activation.max;
    /* Scratch buffer supplied by the caller; size must match
       arm_depthwise_conv_s8_opt_get_buffer_size(). */
    q15_t *buffer_a = (q15_t *)ctx->buf;

#ifdef ARM_MATH_MVEI
    (void)bias_dims;
    /* Generate two columns from the input tensor */
    q7_t *lhs_buffer = (q7_t *)buffer_a;
    q7_t *out = output;
    int padded = 0;       /* Set when any of the buffered columns contains padding. */
    int buffer_count = 0; /* Number of im2col columns currently staged in lhs_buffer. */
    const int32_t kernel_size = kernel_x * kernel_y;

    /* This part implements the im2col function: each output position is expanded into a
       column of kernel_size * input_ch int8 values. Columns are batched in groups of four
       before being handed to the vectorized kernels below. */
    for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
    {
        for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
        {
            for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
            {
                for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
                {
                    if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
                    {
                        /* Out-of-bounds tap: fill with -input_offset so that adding
                           input_offset later yields zero, i.e. zero padding. */
                        arm_memset_q7(lhs_buffer, (int8_t)-input_offset, (uint32_t)input_ch);
                        padded = 1;
                    }
                    else
                    {
                        arm_memcpy_q7(lhs_buffer, input + (i_ker_y * input_x + i_ker_x) * input_ch, (uint32_t)input_ch);
                    }
                    lhs_buffer += input_ch;
                }
            }
            buffer_count++;

            /* Four columns staged: run the MVE depthwise kernel on the batch.
               The padded variant is used only if any column touched padding. */
            if (buffer_count == 4)
            {
                lhs_buffer = (q7_t *)buffer_a;
                if (padded == 0)
                {
                    out = arm_nn_depthwise_conv_nt_t_s8(lhs_buffer,
                                                        kernel,
                                                        input_offset,
                                                        input_ch,
                                                        output_shift,
                                                        output_mult,
                                                        output_offset,
                                                        output_activation_min,
                                                        output_activation_max,
                                                        kernel_size,
                                                        bias,
                                                        out);
                }
                else
                {
                    out = arm_nn_depthwise_conv_nt_t_padded_s8(lhs_buffer,
                                                               kernel,
                                                               input_offset,
                                                               input_ch,
                                                               output_shift,
                                                               output_mult,
                                                               output_offset,
                                                               output_activation_min,
                                                               output_activation_max,
                                                               kernel_size,
                                                               bias,
                                                               out);
                    padded = 0;
                }
                buffer_count = 0;
            }
        }
    }

    /* Handle left over buffers (fewer than four columns remaining after the loop),
       computing one column at a time, four channels per vector iteration. */
    lhs_buffer = (q7_t *)buffer_a;

    for (int i_buf = 0; i_buf < buffer_count; i_buf++)
    {
        int32_t loop_count = (input_ch + 3) / 4;

        int32_t num_ch_to_process = input_ch;
        for (int i_loop_cnt = 0, offset = 0; i_loop_cnt < loop_count; num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
        {
            const int8_t *col_0 = lhs_buffer + (kernel_size * input_ch * i_buf) + offset;
            const int8_t *row_0 = kernel + offset;
            /* NOTE(review): bias is dereferenced unconditionally here — assumes callers
               always pass a non-NULL bias; confirm against the API contract. */
            int32x4_t out_0 = vldrwq_s32(&bias[offset]);

            for (int i_ker = 0; i_ker < kernel_size; i_ker++)
            {
                const int32x4_t ker_0 = vldrbq_s32(row_0);

                int32x4_t ip_0 = vldrbq_s32(col_0);
                ip_0 = vaddq_n_s32(ip_0, input_offset);
                out_0 += vmulq_s32(ip_0, ker_0);

                col_0 += input_ch;
                row_0 += input_ch;
            }

            const int32x4_t mult = vldrwq_s32(&output_mult[offset]);
            const int32x4_t shift = vldrwq_s32(&output_shift[offset]);

            /* Requantize, add output offset and clamp to the activation range. */
            out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
            out_0 = vaddq_n_s32(out_0, output_offset);
            out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
            out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));
            /* Predicated store: writes only the remaining (<4) channels on the last pass. */
            mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
            vstrbq_p_s32(out, out_0, p);

            out += 4;
        }

        /* The last predicated store advanced out by 4 regardless of how many lanes were
           active; step back over the unwritten lanes. */
        const int tail_ch = input_ch & 0x3;
        if (tail_ch != 0)
        {
            out -= (4 - tail_ch);
        }
    }

#else // ARM_MATH_DSP (MVE not available)
    (void)bias_dims;
    /* Run the following code in cores using DSP extension */
    q15_t *const col_buffer_start = buffer_a;
    q15_t *col_buffer = col_buffer_start;
    /* bias/output_mult/output_shift are re-wound for every output position below,
       so remember their start positions. */
    const int32_t *const bias_start_pos = bias;
    const q31_t *const out_mult_start_pos = output_mult;
    const q31_t *const out_shift_start_pos = output_shift;
    uint16_t row_count;
    uint16_t row_shift;

    for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
    {
        const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
        for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
        {
            const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;

            /* Out of bounds is only considered for the y axis as it provides a contiguous zero'ing opportunity than
               along the x axis */
            const int ker_y_start = MAX(0, -base_idx_y);
            /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
            const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);

            /* im2col for one output position: expand the receptive field into
               col_buffer as q15 with input_offset already added. Padded taps are
               zeroed (offset included), so they contribute nothing to the MACs. */
            int32_t index = 0;
            if (ker_y_start != 0)
            {
                /* Rows entirely above the input: zero in one contiguous memset. */
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(q15_t));
                index += (kernel_x * input_ch) * ker_y_start;
            }

            for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
            {
                const int32_t idx_y = base_idx_y + i_ker_y;

                for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
                {
                    const int32_t idx_x = base_idx_x + i_ker_x;
                    if (idx_x < 0 || idx_x >= input_x)
                    {
                        memset(&col_buffer[index], 0, input_ch * sizeof(q15_t));
                    }
                    else
                    {
                        /* Widen s8 -> s16 and fold in the input offset in one pass. */
                        arm_q7_to_q15_with_offset((q7_t *)input + (idx_y * input_x + idx_x) * input_ch,
                                                  &col_buffer[index],
                                                  input_ch,
                                                  input_offset);
                    }
                    index += input_ch;
                }
            }

            /* Rows entirely below the input: zero in one contiguous memset. */
            const int diff = kernel_y - ker_y_end;
            if (diff != 0)
            {
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(q15_t));
            }

            /* Compute the output channels four at a time; rewind the per-channel
               quantization pointers for this output position. */
            row_count = output_ch / 4;
            row_shift = 0;
            bias = bias_start_pos;
            output_mult = out_mult_start_pos;
            output_shift = out_shift_start_pos;

            while (row_count)
            {
                q31_t sum = *bias++;
                q31_t sum_2 = *bias++;
                q31_t sum_3 = *bias++;
                q31_t sum_4 = *bias++;

                uint16_t col_count = (kernel_x * kernel_y) / 2;
                q15_t *col_pos = col_buffer_start + row_shift;
                const q7_t *row_pos = kernel + row_shift;
                row_shift += 4;

                /* Process two kernel taps per iteration for four channels. */
                while (col_count)
                {
                    /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
                    use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
                    /* Note: variable names can be improved here to align with rows and columns. */
                    q31_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
                    /* Read 4 weights */
                    ip_b1 = arm_nn_read_q7x4(row_pos);
                    ip_a1 = arm_nn_read_q7x4(row_pos + input_ch);
                    op_a = arm_nn_read_q15x2(col_pos);
                    op_b = arm_nn_read_q15x2(col_pos + input_ch);

                    /* Sign-extend the packed s8 weights to two s16 lanes each:
                       even bytes via SXTB16, odd bytes via SXTB16 after an 8-bit rotate. */
                    ip_a2 = __SXTB16(ip_b1);
                    ip_b1 = __SXTB16(__ROR(ip_b1, 8));

                    ip_b2 = __SXTB16(ip_a1);
                    ip_a1 = __SXTB16(__ROR(ip_a1, 8));

                    /* Pair matching input/weight lanes and accumulate with dual 16x16 MACs. */
                    op_c = __PKHBT(op_b, op_a, 16);
                    op_a = __PKHTB(op_b, op_a, 16);
                    op_b = __PKHBT(ip_b2, ip_a2, 16);
                    sum = __SMLAD(op_c, op_b, sum);

                    op_b = __PKHBT(ip_b1, ip_a1, 16);
                    sum_2 = __SMLAD(op_a, op_b, sum_2);

                    op_a = arm_nn_read_q15x2(col_pos + 2);
                    op_b = arm_nn_read_q15x2(col_pos + input_ch + 2);

                    op_c = __PKHBT(op_b, op_a, 16);
                    op_a = __PKHTB(op_b, op_a, 16);
                    op_b = __PKHTB(ip_a2, ip_b2, 16);
                    sum_3 = __SMLAD(op_c, op_b, sum_3);

                    op_b = __PKHTB(ip_a1, ip_b1, 16);
                    sum_4 = __SMLAD(op_a, op_b, sum_4);

                    row_pos += input_ch << 1;
                    col_pos += input_ch << 1;
                    col_count--;
                }

                /* Handle an odd number of kernel taps with scalar MACs. */
                col_count = (kernel_x * kernel_y) & 0x1;
                while (col_count)
                {
                    sum += row_pos[0] * col_pos[0];
                    sum_2 += row_pos[1] * col_pos[1];
                    sum_3 += row_pos[2] * col_pos[2];
                    sum_4 += row_pos[3] * col_pos[3];

                    row_pos += input_ch;
                    col_pos += input_ch;

                    col_count--;
                }
                /* Requantize, add output offset and clamp each of the four results. */
                sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                sum += output_offset;
                sum = MAX(sum, output_activation_min);
                sum = MIN(sum, output_activation_max);
                *output++ = (q7_t)sum;

                sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
                sum_2 += output_offset;
                sum_2 = MAX(sum_2, output_activation_min);
                sum_2 = MIN(sum_2, output_activation_max);
                *output++ = (q7_t)sum_2;
                sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
                sum_3 += output_offset;
                sum_3 = MAX(sum_3, output_activation_min);
                sum_3 = MIN(sum_3, output_activation_max);
                *output++ = (q7_t)sum_3;

                sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
                sum_4 += output_offset;
                sum_4 = MAX(sum_4, output_activation_min);
                sum_4 = MIN(sum_4, output_activation_max);
                *output++ = (q7_t)sum_4;

                row_count--;
            }

            /* Leftover output channels (output_ch not a multiple of 4): scalar path. */
            row_count = output_ch & 0x3;
            while (row_count)
            {
                q15_t *col_pos = col_buffer_start + row_shift;
                const q7_t *row_pos = kernel + row_shift;
                q31_t sum = *bias++;
                const uint16_t col_count = (kernel_x * kernel_y);
                row_shift += 1;

                for (int i = 0; i < col_count; i++)
                {
                    sum += row_pos[i * input_ch] * col_pos[i * input_ch];
                }
                sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                sum += output_offset;
                sum = MAX(sum, output_activation_min);
                sum = MIN(sum, output_activation_max);
                *output++ = (q7_t)sum;

                row_count--;
            }

            // clear counter and pointers
            col_buffer = col_buffer_start;
        }
    }
#endif /* ARM_MATH_MVEI */
#else
    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    return arm_depthwise_conv_s8(ctx,
                                 dw_conv_params,
                                 quant_params,
                                 input_dims,
                                 input,
                                 filter_dims,
                                 kernel,
                                 bias_dims,
                                 bias,
                                 output_dims,
                                 output);
#endif /* ARM_MATH_MVEI | ARM_MATH_DSP */

    /* Return to application */
    return ARM_MATH_SUCCESS;
}
411 
/*
 * Get the required scratch buffer size, in bytes, for arm_depthwise_conv_s8_opt().
 *
 * @param[in] input_dims   Input tensor dimensions; only the channel count (c) is used.
 * @param[in] filter_dims  Filter dimensions; only width (w) and height (h) are used.
 * @return                 Buffer size in bytes. 0 when no scratch buffer is needed
 *                         (pure reference implementation).
 */
int32_t arm_depthwise_conv_s8_opt_get_buffer_size(const cmsis_nn_dims *input_dims, const cmsis_nn_dims *filter_dims)
{
#if defined(ARM_MATH_MVEI)
    /* Four im2col columns of s8 data are staged at once (hence the factor 2 relative to
       the DSP path's single q15 column). The + 4 accounts for out of bounds read of the
       lhs buffers in the *_nt_t_* functions. */
    return (2 * input_dims->c * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int16_t) + 4;
#elif defined(ARM_MATH_DSP)
    /* One im2col column widened to q15. Cast sizeof to int32_t (as in the MVE branch)
       to keep the arithmetic and the return value in signed 32-bit. */
    return (input_dims->c * filter_dims->w * filter_dims->h) * (int32_t)sizeof(int16_t);
#else
    /* Reference implementation (arm_depthwise_conv_s8) needs no scratch buffer. */
    (void)input_dims;
    (void)filter_dims;
    return 0;
#endif
}
425 
426 /**
427  * @} end of NNConv group
428  */
429