/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_conv_s4_opt.c
 * Description:  Optimized s4 depthwise separable convolution function for
 *               channel multiplier of 1.
 *
 * $Date:        17 April 2024
 * $Revision:    V.1.1.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup Public
 */

/**
 * @addtogroup NNConv
 * @{
 */

/*
 * Optimized s4 depthwise convolution function with the constraint that in_channel equals out_channel
 *
 * Refer to the prototype header file for details.
 *
 */

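/* Note: s4 weights are packed two per byte; the nibble unpacking below reads
 * the first of each channel pair from the low nibble and the second from the
 * high nibble. */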
arm_cmsis_nn_status arm_depthwise_conv_s4_opt(const cmsis_nn_context *ctx,
                                              const cmsis_nn_dw_conv_params *dw_conv_params,
                                              const cmsis_nn_per_channel_quant_params *quant_params,
                                              const cmsis_nn_dims *input_dims,
                                              const int8_t *input,
                                              const cmsis_nn_dims *filter_dims,
                                              const int8_t *kernel,
                                              const cmsis_nn_dims *bias_dims,
                                              const int32_t *bias,
                                              const cmsis_nn_dims *output_dims,
                                              int8_t *output)
{
    (void)bias_dims;

    const int32_t input_ch = input_dims->c;
    const int32_t output_ch = output_dims->c;

    /* Check depth multiplier is 1 */
    if (input_ch != output_ch)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    if (ctx->buf == NULL)
    {
        return ARM_CMSIS_NN_ARG_ERROR;
    }

    const int32_t input_x = input_dims->w;
    const int32_t input_y = input_dims->h;
    const int32_t kernel_x = filter_dims->w;
    const int32_t kernel_y = filter_dims->h;
    const int32_t pad_x = dw_conv_params->padding.w;
    const int32_t pad_y = dw_conv_params->padding.h;
    const int32_t stride_x = dw_conv_params->stride.w;
    const int32_t stride_y = dw_conv_params->stride.h;
    const int32_t *output_shift = quant_params->shift;
    const int32_t *output_mult = quant_params->multiplier;
    const int32_t output_x = output_dims->w;
    const int32_t output_y = output_dims->h;
    const int32_t output_offset = dw_conv_params->output_offset;
    const int32_t input_offset = dw_conv_params->input_offset;
    const int32_t output_activation_min = dw_conv_params->activation.min;
    const int32_t output_activation_max = dw_conv_params->activation.max;
    int16_t *buffer_a = (int16_t *)ctx->buf;
#ifdef ARM_MATH_MVEI
    /* Generate up to four im2col columns from the input tensor per kernel call */
    int8_t *lhs_buffer = (int8_t *)buffer_a;
    int8_t *out = output;
    int buffer_count = 0;
    const int32_t kernel_size = kernel_x * kernel_y;

    const int32_t ch_loop = (input_ch + (S4_CH_IN_BLOCK_MVE - 1)) / S4_CH_IN_BLOCK_MVE;
    int32_t remaining_ch = output_ch;
    int32_t active_ch = MIN(S4_CH_IN_BLOCK_MVE, remaining_ch);
    remaining_ch -= S4_CH_IN_BLOCK_MVE;
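    /* Channels are processed in blocks of S4_CH_IN_BLOCK_MVE; the last block
     * may be partial, so active_ch tracks how many channels of the current
     * block are valid. */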

    for (int i_ch = 0; i_ch < ch_loop; i_ch++)
    {
        out = output + i_ch * S4_CH_IN_BLOCK_MVE;
        const int8_t *input_slice = input + (i_ch * S4_CH_IN_BLOCK_MVE);

        for (int i_out_y = 0, base_idx_y = -pad_y; i_out_y < output_y; base_idx_y += stride_y, i_out_y++)
        {
            for (int i_out_x = 0, base_idx_x = -pad_x; i_out_x < output_x; base_idx_x += stride_x, i_out_x++)
            {
                for (int i_ker_y = base_idx_y; i_ker_y < base_idx_y + kernel_y; i_ker_y++)
                {
                    for (int i_ker_x = base_idx_x; i_ker_x < base_idx_x + kernel_x; i_ker_x++)
                    {
                        if (i_ker_y < 0 || i_ker_y >= input_y || i_ker_x < 0 || i_ker_x >= input_x)
                        {
                            arm_memset_s8(lhs_buffer, (int8_t)-input_offset, (uint32_t)active_ch);
                        }
                        else
                        {
                            arm_memcpy_s8(lhs_buffer,
                                          input_slice + (i_ker_y * input_x + i_ker_x) * input_ch,
                                          (uint32_t)active_ch);
                        }
                        lhs_buffer += S4_CH_IN_BLOCK_MVE;
                    }
                }
                buffer_count++;

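                /* Four columns are gathered before each call to the vectorized kernel */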
                if (buffer_count == 4)
                {
                    const int32_t block_offset = i_ch * S4_CH_IN_BLOCK_MVE;
                    lhs_buffer = (int8_t *)buffer_a;
                    arm_nn_depthwise_conv_nt_t_s4(lhs_buffer,
                                                  kernel + (block_offset >> 1),
                                                  input_offset,
                                                  active_ch,
                                                  input_ch,
                                                  output_shift + block_offset,
                                                  output_mult + block_offset,
                                                  output_offset,
                                                  output_activation_min,
                                                  output_activation_max,
                                                  kernel_size,
                                                  bias + block_offset,
                                                  out);

                    out += (4 * input_ch);
                    buffer_count = 0;
                }
            }
        }
        /* Handle leftover buffers */
        lhs_buffer = (int8_t *)buffer_a;

        int8_t *out_base = out;
        const uint32x4_t gather_offset = {0, 0, 1, 1};
        const mve_pred16_t lower_nibble_mask = 3855; // 0000111100001111
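        /* Gather offsets {0, 0, 1, 1} load each packed weight byte into two lanes.
           The predicate 0x0F0F selects lanes 0 and 2, where the masked shift pair
           below sign-extends the low nibble; the final right shift by 4 then
           leaves the sign-extended high nibble in lanes 1 and 3. */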
        for (int i_buf = 0; i_buf < buffer_count; i_buf++)
        {
            int32_t loop_count = (active_ch + 3) / 4;
            int32_t num_ch_to_process = active_ch;
            out = out_base + (i_buf * input_ch);
            for (int i_loop_cnt = 0, offset = i_ch * S4_CH_IN_BLOCK_MVE; i_loop_cnt < loop_count;
                 num_ch_to_process -= 4, offset += 4, i_loop_cnt++)
            {
                const int8_t *col_0 = lhs_buffer + (kernel_size * S4_CH_IN_BLOCK_MVE * i_buf) + (i_loop_cnt * 4);
                const int8_t *row_0 = kernel + (offset >> 1);
                int32x4_t out_0 = vdupq_n_s32(0);
                if (bias)
                {
                    out_0 = vldrwq_s32(&bias[offset]);
                }

                if (input_ch % 2)
                {
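                    /* Odd number of channels: kernel rows are not byte aligned, so
                       reads alternate between byte-aligned positions and positions
                       that straddle a byte boundary; row_0 correspondingly advances
                       by floor(input_ch / 2) or ceil(input_ch / 2) bytes per
                       kernel step. */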
                    int get_low_nibble = 1;
                    for (int i_ker = 0; i_ker < kernel_size; i_ker++)
                    {
                        int32x4_t ker_0;
                        if (get_low_nibble)
                        {
                            ker_0 = vldrbq_gather_offset_s32(row_0, gather_offset);
                            ker_0 = vrshlq_m_n_s32(ker_0, 28, lower_nibble_mask);
                            ker_0 = vshrq_m_n_s32(ker_0, ker_0, 24, lower_nibble_mask);

                            ker_0 = vshrq_n_s32(ker_0, 4);
                        }
                        else
                        {
                            int8_t temp[] = {row_0[0] >> 4,
                                             (int8_t)(row_0[1] << 4) >> 4,
                                             row_0[1] >> 4,
                                             (int8_t)(row_0[2] << 4) >> 4};
                            ker_0 = vldrbq_s32(temp);
                        }

                        int32x4_t ip_0 = vldrbq_s32(col_0);
                        ip_0 = vaddq_n_s32(ip_0, input_offset);
                        out_0 += vmulq_s32(ip_0, ker_0);

                        get_low_nibble = !get_low_nibble;
                        col_0 += S4_CH_IN_BLOCK_MVE;
                        row_0 += (input_ch >> 1) + get_low_nibble;
                    }
                }
                else
                {
                    for (int i_ker = 0; i_ker < kernel_size; i_ker++)
                    {
                        int32x4_t ker_0 = vldrbq_gather_offset_s32(row_0, gather_offset);
                        ker_0 = vrshlq_m_n_s32(ker_0, 28, lower_nibble_mask);
                        ker_0 = vshrq_m_n_s32(ker_0, ker_0, 24, lower_nibble_mask);

                        ker_0 = vshrq_n_s32(ker_0, 4);

                        int32x4_t ip_0 = vldrbq_s32(col_0);
                        ip_0 = vaddq_n_s32(ip_0, input_offset);
                        out_0 += vmulq_s32(ip_0, ker_0);

                        col_0 += S4_CH_IN_BLOCK_MVE;
                        row_0 += input_ch >> 1;
                    }
                }

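                /* Requantize with the per-channel multiplier and shift, add the
                   output offset, clamp to the activation range and store the
                   remaining valid channels with a predicated write */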
                const int32x4_t mult = vldrwq_s32(&output_mult[offset]);
                const int32x4_t shift = vldrwq_s32(&output_shift[offset]);

                out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
                out_0 = vaddq_n_s32(out_0, output_offset);
                out_0 = vmaxq_s32(out_0, vdupq_n_s32(output_activation_min));
                out_0 = vminq_s32(out_0, vdupq_n_s32(output_activation_max));
                mve_pred16_t p = vctp32q((uint32_t)num_ch_to_process);
                vstrbq_p_s32(out, out_0, p);

                out += 4;
            }
        }
        buffer_count = 0;

        active_ch = MIN(S4_CH_IN_BLOCK_MVE, remaining_ch);
        remaining_ch -= S4_CH_IN_BLOCK_MVE;
    }
#else
    int16_t *const col_buffer_start = buffer_a;
    int16_t *col_buffer = col_buffer_start;
    const int32_t *const bias_start_pos = bias;
    const int32_t *const out_mult_start_pos = output_mult;
    const int32_t *const out_shift_start_pos = output_shift;
    const uint16_t num_cols = kernel_x * kernel_y;
    uint16_t row_count;
    uint16_t row_shift = 0;
    uint16_t col_shift = 0;

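    /* Scalar/DSP path: for each output pixel, an int16 im2col buffer holds one
       receptive field (num_cols * input_ch values, with the input offset already
       added), which is then multiplied by the unpacked s4 weights. */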
    for (int i_out_y = 0; i_out_y < output_y; i_out_y++)
    {
        const int16_t base_idx_y = (i_out_y * stride_y) - pad_y;
        for (int i_out_x = 0; i_out_x < output_x; i_out_x++)
        {
            const int16_t base_idx_x = (i_out_x * stride_x) - pad_x;

            /* Out of bounds is only handled for the y axis, as it offers a more
               contiguous zero'ing opportunity than the x axis */
            const int ker_y_start = MAX(0, -base_idx_y);
            /* Condition for kernel end dimension: (base_idx_y + ker_y_end) < input_y */
            const int ker_y_end = MIN(kernel_y, input_y - base_idx_y);

            int32_t index = 0;
            if (ker_y_start != 0)
            {
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * ker_y_start * sizeof(int16_t));
                index += (kernel_x * input_ch) * ker_y_start;
            }

            for (int i_ker_y = ker_y_start; i_ker_y < ker_y_end; i_ker_y++)
            {
                const int32_t idx_y = base_idx_y + i_ker_y;

                for (int i_ker_x = 0; i_ker_x < kernel_x; i_ker_x++)
                {
                    const int32_t idx_x = base_idx_x + i_ker_x;
                    if (idx_x < 0 || idx_x >= input_x)
                    {
                        memset(&col_buffer[index], 0, input_ch * sizeof(int16_t));
                    }
                    else
                    {
                        arm_q7_to_q15_with_offset((int8_t *)input + (idx_y * input_x + idx_x) * input_ch,
                                                  &col_buffer[index],
                                                  input_ch,
                                                  (int16_t)input_offset);
                    }
                    index += input_ch;
                }
            }

            const int diff = kernel_y - ker_y_end;
            if (diff != 0)
            {
                memset(&col_buffer[index], 0, (kernel_x * input_ch) * diff * sizeof(int16_t));
            }

            row_count = output_ch / 4;
            row_shift = 0;
            col_shift = 0;
            bias = bias_start_pos;
            output_mult = out_mult_start_pos;
            output_shift = out_shift_start_pos;

            if (output_ch % 2) /* Uneven number of channels */
            {
                int get_low_nibble = 1;

                while (row_count)
                {
                    int32_t sum = 0;
                    int32_t sum_2 = 0;
                    int32_t sum_3 = 0;
                    int32_t sum_4 = 0;
                    if (bias)
                    {
                        sum = *bias++;
                        sum_2 = *bias++;
                        sum_3 = *bias++;
                        sum_4 = *bias++;
                    }

                    uint16_t col_count = num_cols / 2;
                    int16_t *col_pos = col_buffer_start + col_shift;
                    const int8_t *row_pos = kernel + row_shift;

                    row_shift += 2;
                    col_shift += 4;

                    while (col_count)
                    {
#ifdef ARM_MATH_DSP
                        /* The general idea is to read 4 + 4 (input, kernel) pairs and
                           re-arrange them in the right order for use in an SMLAD
                           instruction. One run of this loop produces 4 partial outputs
                           with 8 MACs. */
                        /* Note: variable names can be improved here to align with rows and columns. */
                        int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;

                        /* Read 4 weights */
                        read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
                        read_and_pad_s4_uneven(row_pos + (input_ch >> 1), &ip_a1, &ip_b2);
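                        /* With an odd channel count the second kernel row starts on a
                           half-byte boundary, hence the _uneven unpack variant */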

                        op_a = arm_nn_read_s16x2(col_pos);
                        op_b = arm_nn_read_s16x2(col_pos + input_ch);

                        op_c = PKHBT(op_b, op_a, 16);
                        op_a = PKHTB(op_b, op_a, 16);
                        op_b = PKHBT(ip_b2, ip_a2, 16);
                        sum = SMLAD(op_c, op_b, sum);

                        op_b = PKHBT(ip_b1, ip_a1, 16);

                        sum_2 = SMLAD(op_a, op_b, sum_2);

                        op_a = arm_nn_read_s16x2(col_pos + 2);
                        op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);

                        op_c = PKHBT(op_b, op_a, 16);
                        op_a = PKHTB(op_b, op_a, 16);
                        op_b = PKHTB(ip_a2, ip_b2, 16);
                        sum_3 = SMLAD(op_c, op_b, sum_3);

                        op_b = PKHTB(ip_a1, ip_b1, 16);
                        sum_4 = SMLAD(op_a, op_b, sum_4);

#else
                        int8_t ker0, ker1, ker2, ker3, ker00, ker11;

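                        /* (int8_t)(x << 4) >> 4 sign-extends the low nibble; the
                           arithmetic shift x >> 4 sign-extends the high nibble */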
                        ker00 = row_pos[0];
                        ker11 = row_pos[1];
                        ker0 = (int8_t)(ker00 << 4) >> 4;
                        ker1 = ker00 >> 4;
                        ker2 = (int8_t)(ker11 << 4) >> 4;
                        ker3 = ker11 >> 4;

                        sum += ker0 * col_pos[0];
                        sum_2 += ker1 * col_pos[1];
                        sum_3 += ker2 * col_pos[2];
                        sum_4 += ker3 * col_pos[3];

                        ker11 = row_pos[1 + (input_ch >> 1)];
                        ker0 = row_pos[0 + (input_ch >> 1)] >> 4;
                        ker1 = (int8_t)(ker11 << 4) >> 4;
                        ker2 = ker11 >> 4;
                        ker3 = (int8_t)(row_pos[2 + (input_ch >> 1)] << 4) >> 4;

                        sum += ker0 * col_pos[0 + input_ch];
                        sum_2 += ker1 * col_pos[1 + input_ch];
                        sum_3 += ker2 * col_pos[2 + input_ch];
                        sum_4 += ker3 * col_pos[3 + input_ch];

#endif
                        row_pos += (input_ch);
                        col_pos += input_ch << 1;

                        col_count--;
                    }

                    col_count = num_cols & 0x1;

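                    /* Handle the remaining kernel position when num_cols is odd */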
                    while (col_count)
                    {
                        int8_t ker0, ker1, ker2, ker3, ker00, ker11;

                        ker00 = row_pos[0];
                        ker11 = row_pos[1];

                        ker0 = (int8_t)(ker00 << 4) >> 4;
                        ker1 = ker00 >> 4;

                        ker2 = (int8_t)(ker11 << 4) >> 4;
                        ker3 = ker11 >> 4;

                        sum += ker0 * col_pos[0];
                        sum_2 += ker1 * col_pos[1];
                        sum_3 += ker2 * col_pos[2];
                        sum_4 += ker3 * col_pos[3];

                        row_pos += input_ch >> 1;
                        col_pos += input_ch;

                        col_count--;
                    }

                    sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                    sum += output_offset;
                    sum = MAX(sum, output_activation_min);
                    sum = MIN(sum, output_activation_max);
                    *output++ = (int8_t)sum;

                    sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
                    sum_2 += output_offset;
                    sum_2 = MAX(sum_2, output_activation_min);
                    sum_2 = MIN(sum_2, output_activation_max);
                    *output++ = (int8_t)sum_2;
                    sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
                    sum_3 += output_offset;
                    sum_3 = MAX(sum_3, output_activation_min);
                    sum_3 = MIN(sum_3, output_activation_max);
                    *output++ = (int8_t)sum_3;

                    sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
                    sum_4 += output_offset;
                    sum_4 = MAX(sum_4, output_activation_min);
                    sum_4 = MIN(sum_4, output_activation_max);
                    *output++ = (int8_t)sum_4;

                    row_count--;
                }

                row_count = output_ch & 0x3;

                while (row_count)
                {
                    const int16_t *col_pos = col_buffer_start + col_shift;
                    const int8_t *row_pos = kernel + row_shift;
                    int32_t sum = 0;
                    int col_index = 0;

                    if (bias)
                    {
                        sum = *bias++;
                    }

                    col_shift += 1;

                    for (int i = 0; i < num_cols; i++)
                    {
                        int8_t rhs = row_pos[i * (input_ch >> 1) + col_index];
                        int8_t rhs0;
                        int16_t lhs0 = col_pos[i * input_ch];

                        if (get_low_nibble)
                        {
                            rhs0 = (int8_t)(rhs << 4) >> 4;
                            get_low_nibble = 0;
                        }
                        else
                        {
                            rhs0 = rhs >> 4;
                            get_low_nibble = 1;
                            col_index++;
                        }

                        sum += rhs0 * lhs0;
                    }

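                    /* The next channel starts on the following nibble. The loop above
                       flipped get_low_nibble num_cols times, so flip it once more when
                       num_cols is even. */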
                    if (num_cols % 2 == 0)
                    {
                        get_low_nibble = !get_low_nibble;
                    }

                    sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                    sum += output_offset;
                    sum = MAX(sum, output_activation_min);
                    sum = MIN(sum, output_activation_max);
                    *output++ = (int8_t)sum;

                    row_count--;

                    /* Last row */
                    if (row_count == 1)
                    {
                        row_shift += 1;
                    }
                }
            }
            else /* Even number of channels */
            {
                while (row_count)
                {
                    int32_t sum = 0;
                    int32_t sum_2 = 0;
                    int32_t sum_3 = 0;
                    int32_t sum_4 = 0;
                    if (bias)
                    {
                        sum = *bias++;
                        sum_2 = *bias++;
                        sum_3 = *bias++;
                        sum_4 = *bias++;
                    }

                    uint16_t col_count = num_cols / 2;
                    int16_t *col_pos = col_buffer_start + col_shift;
                    const int8_t *row_pos = kernel + row_shift;

                    row_shift += 2;
                    col_shift += 4;

#ifdef ARM_MATH_DSP
                    while (col_count)
                    {
                        /* The general idea is to read 4 + 4 (input, kernel) pairs and
                           re-arrange them in the right order for use in an SMLAD
                           instruction. One run of this loop produces 4 partial outputs
                           with 8 MACs. */
                        /* Note: variable names can be improved here to align with rows and columns. */
                        int32_t ip_a1, ip_a2, ip_b1, ip_b2, op_a, op_b, op_c;

                        /* Read 4 weights */
                        read_and_pad_s4(row_pos, &ip_a2, &ip_b1);
                        read_and_pad_s4(row_pos + (input_ch >> 1), &ip_b2, &ip_a1);

                        op_a = arm_nn_read_s16x2(col_pos);
                        op_b = arm_nn_read_s16x2(col_pos + input_ch);

                        op_c = PKHBT(op_b, op_a, 16);
                        op_a = PKHTB(op_b, op_a, 16);
                        op_b = PKHBT(ip_b2, ip_a2, 16);
                        sum = SMLAD(op_c, op_b, sum);

                        op_b = PKHBT(ip_b1, ip_a1, 16);

                        sum_2 = SMLAD(op_a, op_b, sum_2);

                        op_a = arm_nn_read_s16x2(col_pos + 2);
                        op_b = arm_nn_read_s16x2(col_pos + input_ch + 2);

                        op_c = PKHBT(op_b, op_a, 16);
                        op_a = PKHTB(op_b, op_a, 16);
                        op_b = PKHTB(ip_a2, ip_b2, 16);
                        sum_3 = SMLAD(op_c, op_b, sum_3);

                        op_b = PKHTB(ip_a1, ip_b1, 16);
                        sum_4 = SMLAD(op_a, op_b, sum_4);

                        row_pos += (input_ch);
                        col_pos += input_ch << 1;

                        col_count--;
                    }

                    col_count = num_cols & 0x1;
#else
                    col_count = num_cols;
#endif
                    while (col_count)
                    {
                        int8_t ker0, ker1, ker2, ker3, ker00, ker11;

                        ker00 = row_pos[0];
                        ker11 = row_pos[1];

                        ker0 = (int8_t)(ker00 << 4) >> 4;
                        ker1 = ker00 >> 4;

                        ker2 = (int8_t)(ker11 << 4) >> 4;
                        ker3 = ker11 >> 4;

                        sum += ker0 * col_pos[0];
                        sum_2 += ker1 * col_pos[1];
                        sum_3 += ker2 * col_pos[2];
                        sum_4 += ker3 * col_pos[3];

                        row_pos += input_ch >> 1;
                        col_pos += input_ch;

                        col_count--;
                    }

                    sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                    sum += output_offset;
                    sum = MAX(sum, output_activation_min);
                    sum = MIN(sum, output_activation_max);
                    *output++ = (int8_t)sum;

                    sum_2 = arm_nn_requantize(sum_2, *output_mult++, *output_shift++);
                    sum_2 += output_offset;
                    sum_2 = MAX(sum_2, output_activation_min);
                    sum_2 = MIN(sum_2, output_activation_max);
                    *output++ = (int8_t)sum_2;
                    sum_3 = arm_nn_requantize(sum_3, *output_mult++, *output_shift++);
                    sum_3 += output_offset;
                    sum_3 = MAX(sum_3, output_activation_min);
                    sum_3 = MIN(sum_3, output_activation_max);
                    *output++ = (int8_t)sum_3;

                    sum_4 = arm_nn_requantize(sum_4, *output_mult++, *output_shift++);
                    sum_4 += output_offset;
                    sum_4 = MAX(sum_4, output_activation_min);
                    sum_4 = MIN(sum_4, output_activation_max);
                    *output++ = (int8_t)sum_4;

                    row_count--;
                }

                if (output_ch & 0x2)
                {
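                    /* Remaining pair of channels: each packed byte holds the first
                       channel's weight in its low nibble and the second channel's
                       in its high nibble */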
                    const int16_t *col_pos = col_buffer_start + col_shift;
                    const int16_t *col_pos_2 = col_buffer_start + col_shift + 1;
                    const int8_t *row_pos = kernel + row_shift;
                    int32_t sum = 0;
                    int32_t sum2 = 0;

                    if (bias)
                    {
                        sum = *bias++;
                        sum2 = *bias++;
                    }

                    for (int i = 0; i < num_cols; i++)
                    {
                        int8_t rhs = row_pos[i * (input_ch >> 1)];

                        int8_t rhs_low = (int8_t)(rhs << 4) >> 4;
                        int8_t rhs_high = rhs >> 4;

                        int16_t lhs0 = col_pos[i * input_ch];
                        int16_t lhs1 = col_pos_2[i * input_ch];

                        sum += rhs_low * lhs0;
                        sum2 += rhs_high * lhs1;
                    }

                    sum = arm_nn_requantize(sum, *output_mult++, *output_shift++);
                    sum += output_offset;
                    sum = MAX(sum, output_activation_min);
                    sum = MIN(sum, output_activation_max);
                    *output++ = (int8_t)sum;
                    sum2 = arm_nn_requantize(sum2, *output_mult++, *output_shift++);
                    sum2 += output_offset;
                    sum2 = MAX(sum2, output_activation_min);
                    sum2 = MIN(sum2, output_activation_max);
                    *output++ = (int8_t)sum2;
                }
            }

            /* Clear counter and pointers */
            col_buffer = col_buffer_start;
        }
    }
#endif // ARM_MATH_MVEI
    /* Return to application */
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of NNConv group
 */