/*
 * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_separable_conv_HWC_q7_nonsquare.c
 * Description:  Q7 depthwise separable convolution function (non-square shape)
 *
 * $Date:        January 26, 2021
 * $Revision:    V.1.0.2
 *
 * Target Processor: Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupNN
 */

/**
 * @addtogroup NNConv
 * @{
 */

/**
 * @brief Q7 depthwise separable convolution function (non-square shape)
 * @param[in]       Im_in         pointer to input tensor
 * @param[in]       dim_im_in_x   input tensor dimension x
 * @param[in]       dim_im_in_y   input tensor dimension y
 * @param[in]       ch_im_in      number of input tensor channels
 * @param[in]       wt            pointer to kernel weights
 * @param[in]       ch_im_out     number of filters, i.e., output tensor channels
 * @param[in]       dim_kernel_x  filter kernel size x
 * @param[in]       dim_kernel_y  filter kernel size y
 * @param[in]       padding_x     padding size x
 * @param[in]       padding_y     padding size y
 * @param[in]       stride_x      convolution stride x
 * @param[in]       stride_y      convolution stride y
 * @param[in]       bias          pointer to bias
 * @param[in]       bias_shift    amount of left-shift for bias
 * @param[in]       out_shift     amount of right-shift for output
 * @param[in,out]   Im_out        pointer to output tensor
 * @param[in]       dim_im_out_x  output tensor dimension x
 * @param[in]       dim_im_out_y  output tensor dimension y
 * @param[in,out]   bufferA       pointer to buffer space for input
 * @param[in,out]   bufferB       pointer to buffer space for output
 * @return     The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 *
 * This function is the fully optimized version, with one constraint:
 * ch_im_in must be equal to ch_im_out (a requirement of depthwise convolution).
 *
 * An illustrative, non-normative usage sketch is provided after the function body.
 */

arm_status arm_depthwise_separable_conv_HWC_q7_nonsquare(const q7_t *Im_in,
                                                         const uint16_t dim_im_in_x,
                                                         const uint16_t dim_im_in_y,
                                                         const uint16_t ch_im_in,
                                                         const q7_t *wt,
                                                         const uint16_t ch_im_out,
                                                         const uint16_t dim_kernel_x,
                                                         const uint16_t dim_kernel_y,
                                                         const uint16_t padding_x,
                                                         const uint16_t padding_y,
                                                         const uint16_t stride_x,
                                                         const uint16_t stride_y,
                                                         const q7_t *bias,
                                                         const uint16_t bias_shift,
                                                         const uint16_t out_shift,
                                                         q7_t *Im_out,
                                                         const uint16_t dim_im_out_x,
                                                         const uint16_t dim_im_out_y,
                                                         q15_t *bufferA,
                                                         q7_t *bufferB)
{

    (void)bufferB;

#if defined(ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    /*
     * Implementation:
     * There are three nested loops here:
     *   Inner loop: compute each output value, accumulating with MAC instructions
     *   Mid loop:   loop over the output channels
     *   Outer loop: loop over the output positions (x, y)
     */

    int16_t i_out_y, i_out_x;
    int16_t i_ker_y, i_ker_x;
    q7_t *colBuffer = (q7_t *)bufferA;
    q7_t *pBuffer = colBuffer;
    const q7_t *pBias = bias;
    q7_t *pOut = Im_out;
    uint16_t rowCnt;
    uint16_t row_shift;

    /* do some checking here, basically ch_im_in == ch_im_out */
    if (ch_im_in != ch_im_out)
    {
        return ARM_MATH_SIZE_MISMATCH;
    }

    for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
    {
        for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
        {
            /* we first do im2col here */
            for (i_ker_y = i_out_y * stride_y - padding_y; i_ker_y < i_out_y * stride_y - padding_y + dim_kernel_y;
                 i_ker_y++)
            {
                for (i_ker_x = i_out_x * stride_x - padding_x; i_ker_x < i_out_x * stride_x - padding_x + dim_kernel_x;
                     i_ker_x++)
                {
                    if (i_ker_y < 0 || i_ker_y >= dim_im_in_y || i_ker_x < 0 || i_ker_x >= dim_im_in_x)
                    {
                        /* arm_fill_q7(0, pBuffer, ch_im_in); */
                        memset(pBuffer, 0, ch_im_in);
                    }
                    else
                    {
                        /* arm_copy_q7((q7_t *) Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, pBuffer,
                         * ch_im_in); */
                        memcpy(pBuffer, (q7_t *)Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, ch_im_in);
                    }
                    pBuffer += ch_im_in;
                }
            }

            /* we will do the computation here for each channel */
            rowCnt = ch_im_out >> 2;
            row_shift = 0;
            pBias = bias;

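            /*
             * Process four output channels per iteration. Each accumulator starts
             * from the bias shifted left by bias_shift plus a rounding constant,
             * so the final arithmetic right shift by out_shift rounds to nearest
             * before the result is saturated to q7 range.
             */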
            while (rowCnt)
            {
                q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

                uint16_t colCnt = (dim_kernel_x * dim_kernel_y) >> 1;
                q7_t *pB = colBuffer + row_shift;
                const q7_t *pA = wt + row_shift;
                row_shift += 4;

#ifdef USE_INTRINSIC

#ifndef ARM_MATH_BIG_ENDIAN

                while (colCnt)
                {
                    q31_t inA1, inA2, inB1, inB2, opA, opB;

                    inB1 = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    opB = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    inB2 = __PKHTB(opB, inB1, 16);
                    inB1 = __PKHBT(inB1, opB, 16);
                    inA1 = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    opB = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inA2 = __PKHTB(opB, inA1, 16);
                    inA1 = __PKHBT(inA1, opB, 16);
                    opA = __SXTB16(inA1);
                    opB = __SXTB16(inB1);
                    sum = __SMLAD(opA, opB, sum);
                    opA = __SXTB16(__ROR(inA1, 8));
                    opB = __SXTB16(__ROR(inB1, 8));
                    sum2 = __SMLAD(opA, opB, sum2);
                    opA = __SXTB16(inA2);
                    opB = __SXTB16(inB2);
                    sum3 = __SMLAD(opA, opB, sum3);
                    opA = __SXTB16(__ROR(inA2, 8));
                    opB = __SXTB16(__ROR(inB2, 8));
                    sum4 = __SMLAD(opA, opB, sum4);
                    colCnt--;
                }
#else

                while (colCnt)
                {
                    q31_t inA1, inA2, inB1, inB2, opA, opB;

                    inB1 = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    opB = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    inB2 = __PKHBT(opB, inB1, 16);
                    inB1 = __PKHTB(inB1, opB, 16);
                    inA1 = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    opB = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inA2 = __PKHBT(opB, inA1, 16);
                    inA1 = __PKHTB(inA1, opB, 16);
                    opA = __SXTB16(inA1);
                    opB = __SXTB16(inB1);
                    sum2 = __SMLAD(opA, opB, sum2);
                    opA = __SXTB16(__ROR(inA1, 8));
                    opB = __SXTB16(__ROR(inB1, 8));
                    sum = __SMLAD(opA, opB, sum);
                    opA = __SXTB16(inA2);
                    opB = __SXTB16(inB2);
                    sum4 = __SMLAD(opA, opB, sum4);
                    opA = __SXTB16(__ROR(inA2, 8));
                    opB = __SXTB16(__ROR(inB2, 8));
                    sum3 = __SMLAD(opA, opB, sum3);
                    colCnt--;
                }

#endif /* ARM_MATH_BIG_ENDIAN */

#else

#ifndef ARM_MATH_BIG_ENDIAN
                //   r0    r1    r2    r3    r4   r5
                //  inA1, inA2, inB1, inB2, opA, opB
                asm volatile("COL_LOOP:\n"
                             "ldr.w r2, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "ldr.w r5, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "pkhtb r3, r5, r2, ASR #16\n"
                             "pkhbt r2, r2, r5, LSL #16\n"
                             "ldr.w r0, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "ldr.w r5, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "pkhtb r1, r5, r0, ASR #16\n"
                             "pkhbt r0, r0, r5, LSL #16\n"
                             "sxtb16 r4, r0\n"
                             "sxtb16 r5, r2\n"
                             "smlad %[sum], r4, r5, %[sum]\n"
                             "mov.w r4, r0, ror #8\n"
                             "mov.w r5, r2, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum2], r4, r5, %[sum2]\n"
                             "sxtb16 r4, r1\n"
                             "sxtb16 r5, r3\n"
                             "smlad %[sum3], r4, r5, %[sum3]\n"
                             "mov.w r4, r1, ror #8\n"
                             "mov.w r5, r3, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum4], r4, r5, %[sum4]\n"
                             "subs %[colCnt], #1\n"
                             "bne COL_LOOP\n"
                             : [ sum ] "+r"(sum),
                               [ sum2 ] "+r"(sum2),
                               [ sum3 ] "+r"(sum3),
                               [ sum4 ] "+r"(sum4),
                               [ pB ] "+r"(pB),
                               [ pA ] "+r"(pA)
                             : [ colCnt ] "r"(colCnt), [ ch_im_in ] "r"(ch_im_in)
                             : "r0", "r1", "r2", "r3", "r4", "r5");
#else
                //   r0    r1    r2    r3    r4   r5
                //  inA1, inA2, inB1, inB2, opA, opB
                asm volatile("COL_LOOP:\n"
                             "ldr.w r2, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "ldr.w r5, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "pkhbt r3, r5, r2, LSL #16\n"
                             "pkhtb r2, r2, r5, ASR #16\n"
                             "ldr.w r0, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "ldr.w r5, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "pkhbt r1, r5, r0, LSL #16\n"
                             "pkhtb r0, r0, r5, ASR #16\n"
                             "sxtb16 r4, r0\n"
                             "sxtb16 r5, r2\n"
                             "smlad %[sum2], r4, r5, %[sum2]\n"
                             "mov.w r4, r0, ror #8\n"
                             "mov.w r5, r2, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum], r4, r5, %[sum]\n"
                             "sxtb16 r4, r1\n"
                             "sxtb16 r5, r3\n"
                             "smlad %[sum4], r4, r5, %[sum4]\n"
                             "mov.w r4, r1, ror #8\n"
                             "mov.w r5, r3, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum3], r4, r5, %[sum3]\n"
                             "subs %[colCnt], #1\n"
                             "bne COL_LOOP\n"
                             : [ sum ] "+r"(sum),
                               [ sum2 ] "+r"(sum2),
                               [ sum3 ] "+r"(sum3),
                               [ sum4 ] "+r"(sum4),
                               [ pB ] "+r"(pB),
                               [ pA ] "+r"(pA)
                             : [ colCnt ] "r"(colCnt), [ ch_im_in ] "r"(ch_im_in)
                             : "r0", "r1", "r2", "r3", "r4", "r5");
#endif /*ARM_MATH_BIG_ENDIAN */

#endif /* USE_INTRINSIC */

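                /* Handle the leftover kernel position when dim_kernel_x * dim_kernel_y is odd */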
                colCnt = (dim_kernel_x * dim_kernel_y) & 0x1;
                while (colCnt)
                {
                    union arm_nnword inA, inB;
                    inA.word = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inB.word = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    sum += inA.bytes[0] * inB.bytes[0];
                    sum2 += inA.bytes[1] * inB.bytes[1];
                    sum3 += inA.bytes[2] * inB.bytes[2];
                    sum4 += inA.bytes[3] * inB.bytes[3];
                    colCnt--;
                }

                *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum2 >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum3 >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum4 >> out_shift), 8);

                rowCnt--;
            }

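            /* Process the remaining 1..3 channels when ch_im_out is not a multiple of 4 */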
            rowCnt = ch_im_out & 0x3;
            while (rowCnt)
            {
                q7_t *pB = colBuffer + row_shift;
                const q7_t *pA = wt + row_shift;
                q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                uint16_t colCnt = (dim_kernel_x * dim_kernel_y);

                row_shift += 1;

                while (colCnt)
                {
                    q7_t A1 = *pA;
                    q7_t B1 = *pB;
                    pA += ch_im_in;
                    pB += ch_im_in;
                    sum += A1 * B1;

                    colCnt--;
                }
                *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
                rowCnt--;
            }

            // clear counter and pointers
            pBuffer = colBuffer;
        }
    }

#else
    (void)bufferA;

    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    int i_out_y, i_out_x, i_ch_out;
    int i_ker_y, i_ker_x;

    /* do some checking here, basically ch_im_in == ch_im_out */
    if (ch_im_in != ch_im_out)
    {
        return ARM_MATH_SIZE_MISMATCH;
    }

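    /*
     * Reference implementation: direct convolution in HWC layout. For a depthwise
     * convolution each output channel i_ch_out only reads input channel i_ch_out,
     * so the innermost multiply indexes a single channel of both Im_in and wt.
     */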
    for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
    {
        for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
        {
            for (i_ch_out = 0; i_ch_out < ch_im_out; i_ch_out++)
            {
                // for each output
                int conv_out = ((q31_t)(bias[i_ch_out]) << bias_shift) + NN_ROUND(out_shift);
                for (i_ker_y = 0; i_ker_y < dim_kernel_y; i_ker_y++)
                {
                    for (i_ker_x = 0; i_ker_x < dim_kernel_x; i_ker_x++)
                    {
                        int in_row = stride_y * i_out_y + i_ker_y - padding_y;
                        int in_col = stride_x * i_out_x + i_ker_x - padding_x;
                        if (in_row >= 0 && in_col >= 0 && in_row < dim_im_in_y && in_col < dim_im_in_x)
                        {
                            conv_out += Im_in[(in_row * dim_im_in_x + in_col) * ch_im_in + i_ch_out] *
                                        wt[(i_ker_y * dim_kernel_x + i_ker_x) * ch_im_out + i_ch_out];
                        }
                    }
                }
                Im_out[(i_out_y * dim_im_out_x + i_out_x) * ch_im_out + i_ch_out] =
                    (q7_t)__SSAT((conv_out >> out_shift), 8);
            }
        }
    }

#endif /* ARM_MATH_DSP */

    /* Return to application */
    return ARM_MATH_SUCCESS;
}
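
/*
 * Illustrative usage sketch (not part of the library). All dimensions, shift
 * amounts and the scratch-buffer size below are hypothetical examples; size
 * bufferA according to the CMSIS-NN documentation for your kernel size and
 * channel count, and choose bias_shift/out_shift to match your quantization.
 *
 *   #define IN_DIM_X   10
 *   #define IN_DIM_Y    8
 *   #define CH         16        // depthwise: ch_im_in must equal ch_im_out
 *   #define KER_X       3
 *   #define KER_Y       3
 *   #define OUT_DIM_X  10        // with padding 1 and stride 1
 *   #define OUT_DIM_Y   8
 *
 *   static q7_t  input[IN_DIM_Y * IN_DIM_X * CH];
 *   static q7_t  weights[KER_Y * KER_X * CH];
 *   static q7_t  bias[CH];
 *   static q7_t  output[OUT_DIM_Y * OUT_DIM_X * CH];
 *   static q15_t buffer_a[CH * KER_X * KER_Y];   // hypothetical scratch size
 *
 *   arm_status status = arm_depthwise_separable_conv_HWC_q7_nonsquare(
 *       input, IN_DIM_X, IN_DIM_Y, CH,
 *       weights, CH, KER_X, KER_Y,
 *       1, 1,                  // padding_x, padding_y
 *       1, 1,                  // stride_x, stride_y
 *       bias, 0, 7,            // bias_shift, out_shift (example values)
 *       output, OUT_DIM_X, OUT_DIM_Y,
 *       buffer_a, NULL);       // bufferB is unused by this function
 */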

/**
 * @} end of NNConv group
 */