/*
 * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project: CMSIS NN Library
 * Title: arm_nn_depthwise_conv_s8_core.c
 * Description: Depthwise convolution on im2col buffers.
 *
 * $Date: 26 October 2022
 * $Revision: V.1.0.5
 *
 * Target Processor: Cortex-M cores
 * -------------------------------------------------------------------- */

30 #include "arm_nnsupportfunctions.h"
31
/*
 * Depthwise convolution on an im2col buffer where the number of input
 * channels equals the number of output channels.
 *
 * Refer to the header file for details.
 *
 */

/**
 * @brief Depthwise convolution core producing two output positions from a
 *        two-column im2col buffer, for the case where the number of input
 *        channels equals the number of output channels.
 *
 * For each channel c and each of the two im2col columns, this accumulates
 *     output_bias[c] + sum_{k < kernel_size} row[k * num_ch + c] * column[k * num_ch + c]
 * then requantizes the sum with arm_requantize_mve_32x4() using the
 * per-channel multiplier/shift, adds out_offset and clamps the result to
 * [activation_min, activation_max].
 *
 * @param[in]  row            Kernel weights: kernel_size groups of num_ch int8
 *                            values (weight k of channel c at row[k * num_ch + c]).
 * @param[in]  col            im2col buffer: two consecutive columns, each of
 *                            kernel_size * num_ch int16 values laid out like row.
 *                            The second column starts at col + kernel_size * num_ch.
 * @param[in]  num_ch         Number of channels.
 * @param[in]  out_shift      Per-channel requantization shifts (num_ch entries).
 * @param[in]  out_mult       Per-channel requantization multipliers (num_ch entries).
 * @param[in]  out_offset     Offset added after requantization.
 * @param[in]  activation_min Lower clamp bound.
 * @param[in]  activation_max Upper clamp bound.
 * @param[in]  kernel_size    Number of kernel elements per channel.
 * @param[in]  output_bias    Per-channel bias (num_ch entries); initial value
 *                            of both accumulators of a channel.
 * @param[out] out            Output: num_ch results for the first column at out,
 *                            followed by num_ch results for the second column.
 *
 * @return out + 2 * num_ch when built for MVE (ARM_MATH_MVEI). Otherwise NULL,
 *         because no implementation is available for other targets.
 */
int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
                                      const int16_t *col,
                                      const uint16_t num_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const uint16_t kernel_size,
                                      const int32_t *const output_bias,
                                      int8_t *out)
{
#if defined(ARM_MATH_MVEI)
    /* Element distance between the first and the second im2col column. */
    const int32_t single_buffer_size = kernel_size * num_ch;
    /* Channels handled four at a time by the vector loop. */
    const int32_t num_ch_vec = num_ch & ~3;

    const int32x4_t act_min = vdupq_n_s32(activation_min);
    const int32x4_t act_max = vdupq_n_s32(activation_max);

    const int32_t *bias = output_bias;
    int8_t *out_tmp = out;

    for (int32_t ch = 0; ch < num_ch_vec; ch += 4)
    {
        int32x4_t out_0 = vldrwq_s32(bias);
        int32x4_t out_1 = out_0;
        bias += 4;

        const int8_t *row_0 = row + ch;
        const int16_t *col_0 = col + ch;
        const int16_t *col_1 = col_0 + single_buffer_size;

        /* Kernel vectors are loaded right before they are used, so no load is
         * ever issued past the last kernel element (or past the end of the
         * second im2col column). */
        int32_t ker_loop = kernel_size / 3;
        while (ker_loop > 0)
        {
            const int32x4_t ker_0 = vldrbq_s32(row_0);
            const int32x4_t ker_1 = vldrbq_s32(row_0 + num_ch);
            const int32x4_t ker_2 = vldrbq_s32(row_0 + 2 * num_ch);
            row_0 += 3 * num_ch;

            int32x4_t ip_0 = vldrhq_s32(col_0);
            int32x4_t ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_0);
            out_1 += vmulq_s32(ip_1, ker_0);

            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_1);
            out_1 += vmulq_s32(ip_1, ker_1);

            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_2);
            out_1 += vmulq_s32(ip_1, ker_2);

            ker_loop--;
        }

        /* Remaining kernel elements not covered by the unroll-by-three loop. */
        ker_loop = kernel_size % 3;
        while (ker_loop > 0)
        {
            const int32x4_t ker_0 = vldrbq_s32(row_0);
            const int32x4_t ip_0 = vldrhq_s32(col_0);
            const int32x4_t ip_1 = vldrhq_s32(col_1);

            out_0 += vmulq_s32(ip_0, ker_0);
            out_1 += vmulq_s32(ip_1, ker_0);

            row_0 += num_ch;
            col_0 += num_ch;
            col_1 += num_ch;
            ker_loop--;
        }

        const int32x4_t mult = vldrwq_s32(out_mult);
        const int32x4_t shift = vldrwq_s32(out_shift);
        out_mult += 4;
        out_shift += 4;

        out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
        out_1 = arm_requantize_mve_32x4(out_1, mult, shift);

        out_0 = vaddq_n_s32(out_0, out_offset);
        out_0 = vmaxq_s32(out_0, act_min);
        out_0 = vminq_s32(out_0, act_max);
        vstrbq_s32(out_tmp, out_0);

        out_1 = vaddq_n_s32(out_1, out_offset);
        out_1 = vmaxq_s32(out_1, act_min);
        out_1 = vminq_s32(out_1, act_max);
        vstrbq_s32(out_tmp + num_ch, out_1);

        out_tmp += 4;
    }

    /* Up to three leftover channels: accumulate in scalar code, then
     * requantize and store with a lane predicate. */
    const int32_t tail_ch = num_ch & 3;
    if (tail_ch != 0)
    {
        int32_t ch_idx = num_ch_vec;
        /* Zero-initialized so the lanes at index >= tail_ch hold defined
         * values. They are processed but never stored (predicated store). */
        int32x4_t col_0_sum = vdupq_n_s32(0);
        int32x4_t col_1_sum = vdupq_n_s32(0);

        for (int32_t i = 0; i < tail_ch; i++, ch_idx++)
        {
            const int16_t *col_pos_0 = col + ch_idx;
            const int16_t *col_pos_1 = col_pos_0 + single_buffer_size;
            const int8_t *row_pos = row + ch_idx;
            int32_t sum_0 = bias[i];
            int32_t sum_1 = bias[i];

            for (int32_t j = 0; j < kernel_size; j++)
            {
                const int8_t row_val = row_pos[j * num_ch];
                sum_0 += row_val * col_pos_0[j * num_ch];
                sum_1 += row_val * col_pos_1[j * num_ch];
            }
            col_0_sum[i] = sum_0;
            col_1_sum[i] = sum_1;
        }

        const mve_pred16_t p = vctp32q((uint32_t)tail_ch);
        const int32x4_t mult = vldrwq_z_s32(out_mult, p);
        const int32x4_t shift = vldrwq_z_s32(out_shift, p);

        col_0_sum = arm_requantize_mve_32x4(col_0_sum, mult, shift);
        col_1_sum = arm_requantize_mve_32x4(col_1_sum, mult, shift);

        col_0_sum = vaddq_n_s32(col_0_sum, out_offset);
        col_0_sum = vmaxq_s32(col_0_sum, act_min);
        col_0_sum = vminq_s32(col_0_sum, act_max);
        vstrbq_p_s32(out_tmp, col_0_sum, p);

        col_1_sum = vaddq_n_s32(col_1_sum, out_offset);
        col_1_sum = vmaxq_s32(col_1_sum, act_min);
        col_1_sum = vminq_s32(col_1_sum, act_max);
        vstrbq_p_s32(out_tmp + num_ch, col_1_sum, p);

        out_tmp += tail_ch;
    }

    /* out_tmp == out + num_ch here; skip past the second column's results. */
    return out_tmp + num_ch;
#else
    (void)row;
    (void)col;
    (void)num_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)kernel_size;
    (void)output_bias;
    (void)out;
    return NULL;
#endif
}
219