1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_depthwise_conv_s4_opt.c
22  * Description:  Optimized s4 depthwise separable convolution function for
23  *               channel multiplier of 1.
24  *
25  * $Date:        17 April 2024
26  * $Revision:    V.1.1.0
27  *
28  * Target :  Arm(R) M-Profile Architecture
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  *  @ingroup Public
37  */
38 
39 /**
40  * @addtogroup NNConv
41  * @{
42  */
43 
44 /*
45  * Optimized s4 depthwise convolution function with constraint that in_channel equals out_channel
46  *
47  *  Refer prototype header file for details.
48  *
49  */
50 
arm_depthwise_conv_s4_opt(const cmsis_nn_context * ctx,const cmsis_nn_dw_conv_params * dw_conv_params,const cmsis_nn_per_channel_quant_params * quant_params,const cmsis_nn_dims * input_dims,const int8_t * input,const cmsis_nn_dims * filter_dims,const int8_t * kernel,const cmsis_nn_dims * bias_dims,const int32_t * bias,const cmsis_nn_dims * output_dims,int8_t * output)51 arm_cmsis_nn_status arm_depthwise_conv_s4_opt(const cmsis_nn_context *ctx,
52                                               const cmsis_nn_dw_conv_params *dw_conv_params,
53                                               const cmsis_nn_per_channel_quant_params *quant_params,
54                                               const cmsis_nn_dims *input_dims,
55                                               const int8_t *input,
56                                               const cmsis_nn_dims *filter_dims,
57                                               const int8_t *kernel,
58                                               const cmsis_nn_dims *bias_dims,
59                                               const int32_t *bias,
60                                               const cmsis_nn_dims *output_dims,
61                                               int8_t *output)
62 {
63     (void)bias_dims;
64 
65     const int32_t input_ch = input_dims->c;
66     const int32_t output_ch = output_dims->c;
67 
68     /* Check depth multiplier is 1 */
69     if (input_ch != output_ch)
70     {
71         return ARM_CMSIS_NN_ARG_ERROR;
72     }
73 
74     if (ctx->buf == NULL)
75     {
76         return ARM_CMSIS_NN_ARG_ERROR;
77     }
78 
79     const int32_t input_x = input_dims->w;
80     const int32_t input_y = input_dims->h;
81     const int32_t kernel_x = filter_dims->w;
82     const int32_t kernel_y = filter_dims->h;
83     const int32_t pad_x = dw_conv_params->padding.w;
84     const int32_t pad_y = dw_conv_params->padding.h;
85     const int32_t stride_x = dw_conv_params->stride.w;
86     const int32_t stride_y = dw_conv_params->stride.h;
87     const int32_t *output_shift = quant_params->shift;
88     const int32_t *output_mult = quant_params->multiplier;
89     const int32_t output_x = output_dims->w;
90     const int32_t output_y = output_dims->h;
91     const int32_t output_offset = dw_conv_params->output_offset;
92     const int32_t input_offset = dw_conv_params->input_offset;
93     const int32_t output_activation_min = dw_conv_params->activation.min;
94     const int32_t output_activation_max = dw_conv_params->activation.max;
95     int16_t *buffer_a = (int16_t *)ctx->buf;
96 
97 #ifdef ARM_MATH_MVEI
98     /* Generate two columns from the input tensor */
99     int8_t *lhs_buffer = (int8_t *)buffer_a;
100     int8_t *out = output;
101     int buffer_count = 0;
102     const int32_t kernel_size = kernel_x * kernel_y;
103 
104     const int32_t ch_loop = (input_ch + (S4_CH_IN_BLOCK_MVE - 1)) / S4_CH_IN_BLOCK_MVE;
105     int32_t remaining_ch = output_ch;
106     int32_t active_ch = MIN(S4_CH_IN_BLOCK_MVE, remaining_ch);
107     remaining_ch -= S4_CH_IN_BLOCK_MVE;
108 
109     for (int i_ch = 0; i_ch < ch_loop; i_ch++)
110     {
111         out = output + i_ch * S4_CH_IN_BLOCK_MVE;
112         const int8_t *input_slice = input + (i_ch * S4_CH_IN_BLOCK_MVE);
113 
114         for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
115         {
116             for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
117             {
118                 for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
119                 {
120                     for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
121                     {
122                         if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
123                         {
124                             arm_memset_s8(lhs_buffer, (int8_t)-input_offset, (uint32_t)active_ch);
125                         }
126                         else
127                         {
128                             arm_memcpy_s8(lhs_buffer,
129                                           input_slice + (i_ker_y * input_x + i_ker_x) * input_ch,
130                                           (uint32_t)active_ch);
131                         }
132                         lhs_buffer += S4_CH_IN_BLOCK_MVE;
133                     }
134                 }
135                 buffer_count++;
136 
137                 if (buffer_count == 4)
138                 {
139                     const int32_t block_offset = i_ch * S4_CH_IN_BLOCK_MVE;
140                     lhs_buffer = (int8_t *)buffer_a;
141                     arm_nn_depthwise_conv_nt_t_s4(lhs_buffer,
142                                                   kernel + (block_offset >> 1),
143                                                   input_offset,
144                                                   active_ch,
145                                                   input_ch,
146                                                   output_shift + block_offset,
147                                                   output_mult + block_offset,
148                                                   output_offset,
149                                                   output_activation_min,
150                                                   output_activation_max,
151                                                   kernel_size,
152                                                   bias + block_offset,
153                                                   out);
154 
155                     out += (4 * input_ch);
156                     buffer_count = 0;
157                 }
158             }
159         }
160         /* Handle left over buffers */
161         lhs_buffer = (int8_t *)buffer_a;
162 
163         int8_t *out_base = out;
164         const uint32x4_t gather_offset = {0, 0, 1, 1};
165         const mve_pred16_t lower_nibble_mask = 3855; // 0000111100001111
166         for (int i_buf = 0; i_buf < buffer_count; i_buf++)
167         {
168             int32_t loop_count = (active_ch + 3) / 4;
169             int32_t num_ch_to_process = active_ch;
170             out = out_base + (i_buf * input_ch);
171             for (int i_loop_cnt = 0, offset = i_ch * S4_CH_IN_BLOCK_MVE; i_loop_cnt < loop_count;
172                  num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
173             {
174                 const int8_t *col_0 = lhs_buffer + (kernel_size * S4_CH_IN_BLOCK_MVE * i_buf) + (i_loop_cnt * 4);
175                 const int8_t *row_0 = kernel + (offset >> 1);
176                 int32x4_t out_0 = vdupq_n_s32(0);
177                 if (bias)
178                 {
179                     out_0 = vldrwq_s32(&bias[offset]);
180                 }
181 
182                 if (input_ch % 2)
183                 {
184                     int get_low_nibble = 1;
185                     for (int i_ker = 0; i_ker < kernel_size; i_ker++)
186                     {
187                         int32x4_t ker_0;
188                         if (get_low_nibble)
189                         {
190                             ker_0 = vldrbq_gather_offset_s32(row_0, gather_offset);
191                             ker_0 = vrshlq_m_n_s32(ker_0, 28, lower_nibble_mask);
192                             ker_0 = vshrq_m_n_s32(ker_0, ker_0, 24, lower_nibble_mask);
193 
194                             ker_0 = vshrq_n_s32(ker_0, 4);
195                         }
196                         else
197                         {
198                             int8_t temp[] = {row_0[0] >> 4,
199                                              (int8_t)(row_0[1] << 4) >> 4,
200                                              row_0[1] >> 4,
201                                              (int8_t)(row_0[2] << 4) >> 4};
202                             ker_0 = vldrbq_s32(temp);
203                         }
204 
205                         int32x4_t ip_0 = vldrbq_s32(col_0);
206                         ip_0 = vaddq_n_s32(ip_0, input_offset);
207                         out_0 += vmulq_s32(ip_0, ker_0);
208 
209                         get_low_nibble = !get_low_nibble;
210                         col_0 += S4_CH_IN_BLOCK_MVE;
211                         row_0 += (input_ch >> 1) + get_low_nibble;
212                     }
213                 }
214                 else
215                 {
216                     for (int i_ker = 0; i_ker < kernel_size; i_ker++)
217                     {
218                         int32x4_t ker_0 = vldrbq_gather_offset_s32(row_0, gather_offset);
219                         ker_0 = vrshlq_m_n_s32(ker_0, 28, lower_nibble_mask);
220                         ker_0 = vshrq_m_n_s32(ker_0, ker_0, 24, lower_nibble_mask);
221 
222                         ker_0 = vshrq_n_s32(ker_0, 4);
223 
224                         int32x4_t ip_0 = vldrbq_s32(col_0);
225                         ip_0 = vaddq_n_s32(ip_0, input_offset);
226                         out_0 += vmulq_s32(ip_0, ker_0);
227 
228                         col_0 += S4_CH_IN_BLOCK_MVE;
229                         row_0 += input_ch >> 1;
230                     }
231                 }
232 
233                 const int32x4_t mult = vldrwq_s32(&output_mult[offset]);
234                 const int32x4_t shift = vldrwq_s32(&output_shift[offset]);
235 
236                 out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
237                 out_0 = vaddq_n_s32(out_0, output_offset);
238                 out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
239                 out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));
240                 mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
241                 vstrbq_p_s32(out, out_0, p);
242 
243                 out += 4;
244             }
245         }
246         buffer_count = 0;
247 
248         active_ch = MIN(S4_CH_IN_BLOCK_MVE, remaining_ch);
249         remaining_ch -= S4_CH_IN_BLOCK_MVE;
250     }
251 #else
252     int16_t *const col_buffer_start = buffer_a;
253     int16_t *col_buffer = col_buffer_start;
254     const int32_t *const bias_start_pos = bias;
255     const int32_t *const out_mult_start_pos = output_mult;
256     const int32_t *const out_shift_start_pos = output_shift;
257     const uint16_t num_cols = kernel_x * kernel_y;
258     uint16_t row_count;
259     uint16_t row_shift = 0;
260     uint16_t col_shift = 0;
261 
262     for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
263     {
264         const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
265         for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
266         {
267             const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;
268 
269             /* Out of bounds is only considered for the y axis as it provides a contiguous zero'ing opportunity than
270                along the x axis */
271             const int ker_y_start = MAX(0, -base_idx_y);
272             /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
273             const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);
274 
275             int32_t index = 0;
276             if (ker_y_start != 0)
277             {
278                 memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
279                 index += (kernel_x * input_ch) * ker_y_start;
280             }
281 
282             for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
283             {
284                 const int32_t idx_y = base_idx_y + i_ker_y;
285 
286                 for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
287                 {
288                     const int32_t idx_x = base_idx_x + i_ker_x;
289                     if (idx_x < 0 || idx_x >= input_x)
290                     {
291                         memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
292                     }
293                     else
294                     {
295                         arm_q7_to_q15_with_offset((int8_t *)input + (idx_y * input_x + idx_x) * input_ch,
296                                                   &col_buffer[index],
297                                                   input_ch,
298                                                   (int16_t)input_offset);
299                     }
300                     index += input_ch;
301                 }
302             }
303 
304             const int diff = kernel_y - ker_y_end;
305             if (diff != 0)
306             {
307                 memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
308             }
309 
310             row_count = output_ch / 4;
311             row_shift = 0;
312             col_shift = 0;
313             bias = bias_start_pos;
314             output_mult = out_mult_start_pos;
315             output_shift = out_shift_start_pos;
316 
317             if (output_ch % 2) /* Uneven number of channels */
318             {
319                 int get_low_nibble = 1;
320 
321                 while (row_count)
322                 {
323                     int32_t sum = 0;
324                     int32_t sum_2 = 0;
325                     int32_t sum_3 = 0;
326                     int32_t sum_4 = 0;
327                     if (bias)
328                     {
329                         sum = *bias++;
330                         sum_2 = *bias++;
331                         sum_3 = *bias++;
332                         sum_4 = *bias++;
333                     }
334 
335                     uint16_t col_count = num_cols / 2;
336                     int16_t *col_pos = col_buffer_start + col_shift;
337                     const int8_t *row_pos = kernel + row_shift;
338 
339                     row_shift += 2;
340                     col_shift += 4;
341 
342                     while (col_count)
343                     {
344     #ifdef ARM_MATH_DSP
345                         /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
346                            use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
347                         /* Note: variable names can be improved here to align with rows and columns. */
348                         int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
349 
350                         /* Read 4 weights */
351                         read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
352                         read_and_pad_s4_uneven(row_pos + (input_ch >> 1), &ip_a1, &ip_b2);
353 
354                         op_a = arm_nn_read_s16x2(col_pos);
355                         op_b = arm_nn_read_s16x2(col_pos + input_ch);
356 
357                         op_c = PKHBT(op_b, op_a, 16);
358                         op_a = PKHTB(op_b, op_a, 16);
359                         op_b = PKHBT(ip_b2, ip_a2, 16);
360                         sum = SMLAD(op_c, op_b, sum);
361 
362                         op_b = PKHBT(ip_b1, ip_a1, 16);
363 
364                         sum_2 = SMLAD(op_a, op_b, sum_2);
365 
366                         op_a = arm_nn_read_s16x2(col_pos + 2);
367                         op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);
368 
369                         op_c = PKHBT(op_b, op_a, 16);
370                         op_a = PKHTB(op_b, op_a, 16);
371                         op_b = PKHTB(ip_a2, ip_b2, 16);
372                         sum_3 = SMLAD(op_c, op_b, sum_3);
373 
374                         op_b = PKHTB(ip_a1, ip_b1, 16);
375                         sum_4 = SMLAD(op_a, op_b, sum_4);
376 
377     #else
378                         int8_t ker0, ker1, ker2, ker3, ker00, ker11;
379 
380                         ker00 = row_pos[0];
381                         ker11 = row_pos[1];
382                         ker0 = (int8_t)(ker00 << 4) >> 4;
383                         ker1 = ker00 >> 4;
384                         ker2 = (int8_t)(ker11 << 4) >> 4;
385                         ker3 = ker11 >> 4;
386 
387                         sum += ker0 * col_pos[0];
388                         sum_2 += ker1 * col_pos[1];
389                         sum_3 += ker2 * col_pos[2];
390                         sum_4 += ker3 * col_pos[3];
391 
392                         ker11 = row_pos[1 + (input_ch >> 1)];
393                         ker0 = row_pos[0 + (input_ch >> 1)] >> 4;
394                         ker1 = (int8_t)(ker11 << 4) >> 4;
395                         ker2 = ker11 >> 4;
396                         ker3 = (int8_t)(row_pos[2 + (input_ch >> 1)] << 4) >> 4;
397 
398                         sum += ker0 * col_pos[0 + input_ch];
399                         sum_2 += ker1 * col_pos[1 + input_ch];
400                         sum_3 += ker2 * col_pos[2 + input_ch];
401                         sum_4 += ker3 * col_pos[3 + input_ch];
402 
403     #endif
404                         row_pos += (input_ch);
405                         col_pos += input_ch << 1;
406 
407                         col_count--;
408                     }
409 
410                     col_count = num_cols & 0x1;
411 
412                     while (col_count)
413                     {
414                         int8_t ker0, ker1, ker2, ker3, ker00, ker11;
415 
416                         ker00 = row_pos[0];
417                         ker11 = row_pos[1];
418 
419                         ker0 = (int8_t)(ker00 << 4) >> 4;
420                         ker1 = ker00 >> 4;
421 
422                         ker2 = (int8_t)(ker11 << 4) >> 4;
423                         ker3 = ker11 >> 4;
424 
425                         sum += ker0 * col_pos[0];
426                         sum_2 += ker1 * col_pos[1];
427                         sum_3 += ker2 * col_pos[2];
428                         sum_4 += ker3 * col_pos[3];
429 
430                         row_pos += input_ch >> 1;
431                         col_pos += input_ch;
432 
433                         col_count--;
434                     }
435 
436                     sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
437                     sum += output_offset;
438                     sum = MAX(sum, output_activation_min);
439                     sum = MIN(sum, output_activation_max);
440                     *output++ = (int8_t)sum;
441 
442                     sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
443                     sum_2 += output_offset;
444                     sum_2 = MAX(sum_2, output_activation_min);
445                     sum_2 = MIN(sum_2, output_activation_max);
446                     *output++ = (int8_t)sum_2;
447                     sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
448                     sum_3 += output_offset;
449                     sum_3 = MAX(sum_3, output_activation_min);
450                     sum_3 = MIN(sum_3, output_activation_max);
451                     *output++ = (int8_t)sum_3;
452 
453                     sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
454                     sum_4 += output_offset;
455                     sum_4 = MAX(sum_4, output_activation_min);
456                     sum_4 = MIN(sum_4, output_activation_max);
457                     *output++ = (int8_t)sum_4;
458 
459                     row_count--;
460                 }
461 
462                 row_count = output_ch & 0x3;
463 
464                 while (row_count)
465                 {
466                     const int16_t *col_pos = col_buffer_start + col_shift;
467                     const int8_t *row_pos = kernel + row_shift;
468                     int32_t sum = 0;
469                     int col_index = 0;
470 
471                     if (bias)
472                     {
473                         sum = *bias++;
474                     }
475 
476                     col_shift += 1;
477 
478                     for (int i = 0; i < num_cols; i++)
479                     {
480                         int8_t rhs = row_pos[i * (input_ch >> 1) + col_index];
481                         int8_t rhs0;
482                         int16_t lhs0 = col_pos[i * input_ch];
483 
484                         if (get_low_nibble)
485                         {
486                             rhs0 = (int8_t)(rhs << 4) >> 4;
487                             get_low_nibble = 0;
488                         }
489                         else
490                         {
491                             rhs0 = rhs >> 4;
492                             get_low_nibble = 1;
493                             col_index++;
494                         }
495 
496                         sum += rhs0 * lhs0;
497                     }
498 
499                     if (num_cols % 2 == 0)
500                     {
501                         get_low_nibble = !get_low_nibble;
502                     }
503 
504                     sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
505                     sum += output_offset;
506                     sum = MAX(sum, output_activation_min);
507                     sum = MIN(sum, output_activation_max);
508                     *output++ = (int8_t)sum;
509 
510                     row_count--;
511 
512                     /* Last row */
513                     if (row_count == 1)
514                     {
515                         row_shift += 1;
516                     }
517                 }
518             }
519             else /* Even number of channels */
520             {
521                 while (row_count)
522                 {
523                     int32_t sum = 0;
524                     int32_t sum_2 = 0;
525                     int32_t sum_3 = 0;
526                     int32_t sum_4 = 0;
527                     if (bias)
528                     {
529                         sum = *bias++;
530                         sum_2 = *bias++;
531                         sum_3 = *bias++;
532                         sum_4 = *bias++;
533                     }
534 
535                     uint16_t col_count = num_cols / 2;
536                     int16_t *col_pos = col_buffer_start + col_shift;
537                     const int8_t *row_pos = kernel + row_shift;
538 
539                     row_shift += 2;
540                     col_shift += 4;
541 
542     #ifdef ARM_MATH_DSP
543                     while (col_count)
544                     {
545                         /* General idea is to read 4 + 4 (input, kernel) pair and re-arrange them in the right order to
546                            use in a SMLAD instruction . One run of this loop produces 4 partial outputs with 8 MACs. */
547                         /* Note: variable names can be improved here to align with rows and columns. */
548                         int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;
549 
550                         /* Read 4 weights */
551                         read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
552                         read_and_pad_s4(row_pos + (input_ch >> 1), &ip_b2, &ip_a1);
553 
554                         op_a = arm_nn_read_s16x2(col_pos);
555                         op_b = arm_nn_read_s16x2(col_pos + input_ch);
556 
557                         op_c = PKHBT(op_b, op_a, 16);
558                         op_a = PKHTB(op_b, op_a, 16);
559                         op_b = PKHBT(ip_b2, ip_a2, 16);
560                         sum = SMLAD(op_c, op_b, sum);
561 
562                         op_b = PKHBT(ip_b1, ip_a1, 16);
563 
564                         sum_2 = SMLAD(op_a, op_b, sum_2);
565 
566                         op_a = arm_nn_read_s16x2(col_pos + 2);
567                         op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);
568 
569                         op_c = PKHBT(op_b, op_a, 16);
570                         op_a = PKHTB(op_b, op_a, 16);
571                         op_b = PKHTB(ip_a2, ip_b2, 16);
572                         sum_3 = SMLAD(op_c, op_b, sum_3);
573 
574                         op_b = PKHTB(ip_a1, ip_b1, 16);
575                         sum_4 = SMLAD(op_a, op_b, sum_4);
576 
577                         row_pos += (input_ch);
578                         col_pos += input_ch << 1;
579 
580                         col_count--;
581                     }
582 
583                     col_count = num_cols & 0x1;
584     #else
585                     col_count = num_cols;
586     #endif
587                     while (col_count)
588                     {
589                         int8_t ker0, ker1, ker2, ker3, ker00, ker11;
590 
591                         ker00 = row_pos[0];
592                         ker11 = row_pos[1];
593 
594                         ker0 = (int8_t)(ker00 << 4) >> 4;
595                         ker1 = ker00 >> 4;
596 
597                         ker2 = (int8_t)(ker11 << 4) >> 4;
598                         ker3 = ker11 >> 4;
599 
600                         sum += ker0 * col_pos[0];
601                         sum_2 += ker1 * col_pos[1];
602                         sum_3 += ker2 * col_pos[2];
603                         sum_4 += ker3 * col_pos[3];
604 
605                         row_pos += input_ch >> 1;
606                         col_pos += input_ch;
607 
608                         col_count--;
609                     }
610 
611                     sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
612                     sum += output_offset;
613                     sum = MAX(sum, output_activation_min);
614                     sum = MIN(sum, output_activation_max);
615                     *output++ = (int8_t)sum;
616 
617                     sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
618                     sum_2 += output_offset;
619                     sum_2 = MAX(sum_2, output_activation_min);
620                     sum_2 = MIN(sum_2, output_activation_max);
621                     *output++ = (int8_t)sum_2;
622                     sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
623                     sum_3 += output_offset;
624                     sum_3 = MAX(sum_3, output_activation_min);
625                     sum_3 = MIN(sum_3, output_activation_max);
626                     *output++ = (int8_t)sum_3;
627 
628                     sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
629                     sum_4 += output_offset;
630                     sum_4 = MAX(sum_4, output_activation_min);
631                     sum_4 = MIN(sum_4, output_activation_max);
632                     *output++ = (int8_t)sum_4;
633 
634                     row_count--;
635                 }
636 
637                 if (output_ch & 0x2)
638                 {
639                     const int16_t *col_pos = col_buffer_start + col_shift;
640                     const int16_t *col_pos_2 = col_buffer_start + col_shift + 1;
641                     const int8_t *row_pos = kernel + row_shift;
642                     int32_t sum = 0;
643                     int32_t sum2 = 0;
644 
645                     if (bias)
646                     {
647                         sum = *bias++;
648                         sum2 = *bias++;
649                     }
650 
651                     for (int i = 0; i < num_cols; i++)
652                     {
653                         int8_t rhs = row_pos[i * (input_ch >> 1)];
654 
655                         int8_t rhs_low = (int8_t)(rhs << 4) >> 4;
656                         int8_t rhs_high = rhs >> 4;
657 
658                         int16_t lhs0 = col_pos[i * input_ch];
659                         int16_t lhs1 = col_pos_2[i * input_ch];
660 
661                         sum += rhs_low * lhs0;
662                         sum2 += rhs_high * lhs1;
663                     }
664 
665                     sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
666                     sum += output_offset;
667                     sum = MAX(sum, output_activation_min);
668                     sum = MIN(sum, output_activation_max);
669                     *output++ = (int8_t)sum;
670                     sum2 = arm_nn_requantize(sum2, *output_mult++, *output_shift++);
671                     sum2 += output_offset;
672                     sum2 = MAX(sum2, output_activation_min);
673                     sum2 = MIN(sum2, output_activation_max);
674                     *output++ = (int8_t)sum2;
675                 }
676             }
677 
678             /* Clear counter and pointers */
679             col_buffer = col_buffer_start;
680         }
681     }
682 #endif // ARM_MATH_MVEI
683     /* Return to application */
684     return ARM_CMSIS_NN_SUCCESS;
685 }
686 
687 /**
688  * @} end of NNConv group
689  */
690