/*
 * SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_depthwise_conv_s8_core.c
 * Description:  Depthwise convolution on im2col buffers.
 *
 * $Date:        26 October 2022
 * $Revision:    V.1.0.5
 *
 * Target Processor:  Cortex-M cores
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/*
 * Depthwise convolution on an im2col buffer where the number of input
 * channels equals the number of output channels.
 *
 * Refer to the header file for details.
 */
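
/*
 * Illustrative call, to show the expected buffer shapes (all names and
 * sizes below are hypothetical, not part of the library):
 *
 *   // 3x3 depthwise kernel over 8 channels: kernel_size = 9, num_ch = 8.
 *   // 'col' holds two im2col columns back to back
 *   // (2 * kernel_size * num_ch int16 values), so one call produces two
 *   // adjacent output pixels of num_ch int8 values each and returns a
 *   // pointer just past the second one.
 *   int8_t *out_next = arm_nn_depthwise_conv_s8_core(kernel_s8,
 *                                                    im2col_s16,
 *                                                    8,
 *                                                    per_ch_shift,
 *                                                    per_ch_mult,
 *                                                    out_offset,
 *                                                    act_min,
 *                                                    act_max,
 *                                                    9,
 *                                                    bias_s32,
 *                                                    out_s8);
 */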

int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
                                      const int16_t *col,
                                      const uint16_t num_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const uint16_t kernel_size,
                                      const int32_t *const output_bias,
                                      int8_t *out)
{
#if defined(ARM_MATH_MVEI)
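    /*
     * MVE path: process four channels per outer iteration, accumulating
     * two output pixels (one per im2col column) at the same time.
     */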
    int32_t ch_per_loop = num_ch / 4;

    const int32_t *bias = output_bias;
    int8_t *out_tmp = out;

    int32_t idx = 0;

    while (ch_per_loop > 0)
    {
        int32x4_t ip_0;
        int32x4_t ip_1;
        int32_t ker_loop = kernel_size / 3;
        int32x4_t out_0 = vldrwq_s32(bias);
        int32x4_t out_1 = out_0;
        bias += 4;

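        /* Point at this channel group in the kernel row and in both im2col
         * columns; the second column starts one full buffer
         * (kernel_size * num_ch elements) further on. */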
        const int32_t offset = idx * 4;
        const int8_t *row_0 = row + offset;
        const int16_t *col_0 = col + offset;
        const int16_t *col_1 = col + kernel_size * num_ch + offset;

        int32x4_t ker_0 = vldrbq_s32(row_0);

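        /* Kernel loop unrolled by three. ker_0 is pre-loaded; each
         * iteration fetches the first kernel vector of the next group at
         * its bottom. */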
        while (ker_loop > 0)
        {
            const int8_t *row_1 = row_0 + num_ch;
            const int8_t *row_2 = row_0 + 2 * num_ch;
            const int32x4_t ker_1 = vldrbq_s32(row_1);
            const int32x4_t ker_2 = vldrbq_s32(row_2);

            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_0);
            out_1 += vmulq_s32(ip_1, ker_0);

            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_1);
            out_1 += vmulq_s32(ip_1, ker_1);

            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);
            col_0 += num_ch;
            col_1 += num_ch;

            out_0 += vmulq_s32(ip_0, ker_2);
            out_1 += vmulq_s32(ip_1, ker_2);
            row_0 += 3 * num_ch;

            ker_0 = vldrbq_s32(row_0);
            ker_loop--;
        }

        idx++;
        /* Handle tail kernel elements */
        ker_loop = kernel_size - ((kernel_size / 3) * 3);
        while (ker_loop > 0)
        {
            ip_0 = vldrhq_s32(col_0);
            ip_1 = vldrhq_s32(col_1);

            out_0 += vmulq_s32(ip_0, ker_0);
            out_1 += vmulq_s32(ip_1, ker_0);

            col_0 += num_ch;
            col_1 += num_ch;

            row_0 += num_ch;
            ker_0 = vldrbq_s32(row_0);
            ker_loop--;
        }
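        /* Requantize with this channel group's multiplier and shift, then
         * add the output offset and clamp to the activation range. */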
        const int32x4_t mult = vldrwq_s32(out_mult);
        const int32x4_t shift = vldrwq_s32(out_shift);
        out_mult += 4;
        out_shift += 4;

        out_0 = arm_requantize_mve_32x4(out_0, mult, shift);
        out_1 = arm_requantize_mve_32x4(out_1, mult, shift);

        out_0 = vaddq_n_s32(out_0, out_offset);
        out_0 = vmaxq_s32(out_0, vdupq_n_s32(activation_min));
        out_0 = vminq_s32(out_0, vdupq_n_s32(activation_max));
        /* Narrowing store: first output pixel, four int8 channels */
        vstrbq_s32(out_tmp, out_0);

        out_1 = vaddq_n_s32(out_1, out_offset);
        out_1 = vmaxq_s32(out_1, vdupq_n_s32(activation_min));
        out_1 = vminq_s32(out_1, vdupq_n_s32(activation_max));
        /* The second output pixel lives num_ch bytes further on */
        vstrbq_s32(out_tmp + num_ch, out_1);

        out_tmp += 4;
        ch_per_loop--;
    }

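    /* Process the one to three leftover channels that do not fill a
     * four-lane vector. */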
    int32_t tail_ch = num_ch & 3;
    if (tail_ch != 0)
    {
        int32_t ch_idx = (num_ch & ~3);
        int32x4_t col_0_sum;
        int32x4_t col_1_sum;

        const int32_t single_buffer_size = kernel_size * num_ch;
        for (int i = 0; i < tail_ch; i++)
        {
            const int16_t *col_pos_0 = col + ch_idx;
            const int16_t *col_pos_1 = col_pos_0 + single_buffer_size;

            const int8_t *row_pos = row + ch_idx;
            int32_t sum_0 = bias[i];
            int32_t sum_1 = bias[i];

            for (int j = 0; j < kernel_size; j++)
            {
                const int8_t row_val = row_pos[j * num_ch];
                sum_0 += row_val * col_pos_0[j * num_ch];
                sum_1 += row_val * col_pos_1[j * num_ch];
            }
            col_0_sum[i] = sum_0;
            col_1_sum[i] = sum_1;

            ch_idx++;
        }
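        /* Build a predicate so the vector ops below touch only the
         * tail_ch active lanes. */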
        const mve_pred16_t p = vctp32q((uint32_t)tail_ch);
        const int32x4_t mult = vldrwq_z_s32(out_mult, p);
        const int32x4_t shift = vldrwq_z_s32(out_shift, p);

        col_0_sum = arm_requantize_mve_32x4(col_0_sum, mult, shift);
        col_1_sum = arm_requantize_mve_32x4(col_1_sum, mult, shift);

        col_0_sum = vaddq_n_s32(col_0_sum, out_offset);
        col_0_sum = vmaxq_s32(col_0_sum, vdupq_n_s32(activation_min));
        col_0_sum = vminq_s32(col_0_sum, vdupq_n_s32(activation_max));
        vstrbq_p_s32(out_tmp, col_0_sum, p);

        col_1_sum = vaddq_n_s32(col_1_sum, out_offset);
        col_1_sum = vmaxq_s32(col_1_sum, vdupq_n_s32(activation_min));
        col_1_sum = vminq_s32(col_1_sum, vdupq_n_s32(activation_max));
        vstrbq_p_s32(out_tmp + num_ch, col_1_sum, p);

        out_tmp += tail_ch;
    }

    /* Two adjacent output pixels were written; return the position just
     * past the second one. */
    return out_tmp + num_ch;
#else
    (void)row;
    (void)col;
    (void)num_ch;
    (void)out_shift;
    (void)out_mult;
    (void)out_offset;
    (void)activation_min;
    (void)activation_max;
    (void)kernel_size;
    (void)output_bias;
    (void)out;
    return NULL;
#endif
}