/*
 * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_depthwise_separable_conv_HWC_q7_nonsquare.c
 * Description:  Q7 depthwise separable convolution function (non-square shape)
 *
 * $Date:        January 26, 2021
 * $Revision:    V.1.0.2
 *
 * Target Processor:  Cortex-M cores
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 *  @ingroup groupNN
 */

/**
 * @addtogroup NNConv
 * @{
 */

/**
 * @brief Q7 depthwise separable convolution function (non-square shape)
 * @param[in]       Im_in         pointer to input tensor
 * @param[in]       dim_im_in_x   input tensor dimension x
 * @param[in]       dim_im_in_y   input tensor dimension y
 * @param[in]       ch_im_in      number of input tensor channels
 * @param[in]       wt            pointer to kernel weights
 * @param[in]       ch_im_out     number of filters, i.e., output tensor channels
 * @param[in]       dim_kernel_x  filter kernel size x
 * @param[in]       dim_kernel_y  filter kernel size y
 * @param[in]       padding_x     padding sizes x
 * @param[in]       padding_y     padding sizes y
 * @param[in]       stride_x      convolution stride x
 * @param[in]       stride_y      convolution stride y
 * @param[in]       bias          pointer to bias
 * @param[in]       bias_shift    amount of left-shift for bias
 * @param[in]       out_shift     amount of right-shift for output
 * @param[in,out]   Im_out        pointer to output tensor
 * @param[in]       dim_im_out_x  output tensor dimension x
 * @param[in]       dim_im_out_y  output tensor dimension y
 * @param[in,out]   bufferA       pointer to buffer space for input
 * @param[in,out]   bufferB       pointer to buffer space for output
 * @return     The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 *
 * This function is the version with the full list of optimization tricks, but with
 * one constraint:
 *   ch_im_in must be equal to ch_im_out
 *
 */
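
/*
 * Example usage (a minimal sketch, not part of the original source; the tensor
 * and size names below are hypothetical). In this implementation bufferA only
 * needs to hold one im2col patch of ch_im_in * dim_kernel_x * dim_kernel_y q7
 * values, and bufferB is unused, so NULL may be passed for it:
 *
 *   q15_t col_buffer[CH_IN * KERNEL_X * KERNEL_Y];
 *   arm_status status = arm_depthwise_separable_conv_HWC_q7_nonsquare(
 *       input_data, IN_X, IN_Y, CH_IN,
 *       weights, CH_IN, KERNEL_X, KERNEL_Y,
 *       PAD_X, PAD_Y, STRIDE_X, STRIDE_Y,
 *       biases, BIAS_SHIFT, OUT_SHIFT,
 *       output_data, OUT_X, OUT_Y,
 *       col_buffer, NULL);
 */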

arm_status arm_depthwise_separable_conv_HWC_q7_nonsquare(const q7_t *Im_in,
                                                         const uint16_t dim_im_in_x,
                                                         const uint16_t dim_im_in_y,
                                                         const uint16_t ch_im_in,
                                                         const q7_t *wt,
                                                         const uint16_t ch_im_out,
                                                         const uint16_t dim_kernel_x,
                                                         const uint16_t dim_kernel_y,
                                                         const uint16_t padding_x,
                                                         const uint16_t padding_y,
                                                         const uint16_t stride_x,
                                                         const uint16_t stride_y,
                                                         const q7_t *bias,
                                                         const uint16_t bias_shift,
                                                         const uint16_t out_shift,
                                                         q7_t *Im_out,
                                                         const uint16_t dim_im_out_x,
                                                         const uint16_t dim_im_out_y,
                                                         q15_t *bufferA,
                                                         q7_t *bufferB)
{

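    /* bufferB is not used by this implementation; the cast below only silences
     * the unused-parameter warning. */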
    (void)bufferB;

#if defined(ARM_MATH_DSP)
    /* Run the following code for Cortex-M4 and Cortex-M7 */

    /*
     * Implementation:
     * There are 3 nested loops here:
     * Inner loop: calculate each output value with MAC instructions over an accumulator
     * Mid   loop: loop over the different output channels
     * Outer loop: loop over the different output positions (x, y)
     *
     */

    int16_t i_out_y, i_out_x;
    int16_t i_ker_y, i_ker_x;
    q7_t *colBuffer = (q7_t *)bufferA;
    q7_t *pBuffer = colBuffer;
    const q7_t *pBias = bias;
    q7_t *pOut = Im_out;
    uint16_t rowCnt;
    uint16_t row_shift;

    /* depthwise convolution requires ch_im_in == ch_im_out */
    if (ch_im_in != ch_im_out)
    {
        return ARM_MATH_SIZE_MISMATCH;
    }

    for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
    {
        for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
        {
            /* im2col: gather the input patch that contributes to this output pixel into
             * colBuffer, keeping the HWC channel interleaving and zero-padding any
             * positions that fall outside the input image */
            for (i_ker_y = i_out_y * stride_y - padding_y; i_ker_y < i_out_y * stride_y - padding_y + dim_kernel_y;
                 i_ker_y++)
            {
                for (i_ker_x = i_out_x * stride_x - padding_x; i_ker_x < i_out_x * stride_x - padding_x + dim_kernel_x;
                     i_ker_x++)
                {
                    if (i_ker_y < 0 || i_ker_y >= dim_im_in_y || i_ker_x < 0 || i_ker_x >= dim_im_in_x)
                    {
                        /* arm_fill_q7(0, pBuffer, ch_im_in); */
                        memset(pBuffer, 0, ch_im_in);
                    }
                    else
                    {
                        /* arm_copy_q7((q7_t *) Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, pBuffer,
                         * ch_im_in); */
                        memcpy(pBuffer, (q7_t *)Im_in + (i_ker_y * dim_im_in_x + i_ker_x) * ch_im_in, ch_im_in);
                    }
                    pBuffer += ch_im_in;
                }
            }

            /* we will do the computation here for each channel */
            rowCnt = ch_im_out >> 2;
            row_shift = 0;
            pBias = bias;

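            /* Process four output channels per iteration. Each accumulator is seeded
             * with the bias shifted left by bias_shift plus NN_ROUND(out_shift), the
             * rounding term for the final right shift. */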
            while (rowCnt)
            {
                q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum2 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum3 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                q31_t sum4 = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);

                uint16_t colCnt = (dim_kernel_x * dim_kernel_y) >> 1;
                q7_t *pB = colBuffer + row_shift;
                const q7_t *pA = wt + row_shift;
                row_shift += 4;

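                /* Inner loop (intrinsic and inline-assembly variants, for both
                 * endiannesses): each iteration reads two im2col rows of four q7
                 * channels, repacks them into 16-bit lanes with PKHBT/PKHTB and
                 * SXTB16, and accumulates into the four sums with SMLAD dual MACs. */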
#ifdef USE_INTRINSIC

#ifndef ARM_MATH_BIG_ENDIAN

                while (colCnt)
                {
                    q31_t inA1, inA2, inB1, inB2, opA, opB;

                    inB1 = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    opB = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    inB2 = __PKHTB(opB, inB1, 16);
                    inB1 = __PKHBT(inB1, opB, 16);
                    inA1 = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    opB = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inA2 = __PKHTB(opB, inA1, 16);
                    inA1 = __PKHBT(inA1, opB, 16);
                    opA = __SXTB16(inA1);
                    opB = __SXTB16(inB1);
                    sum = __SMLAD(opA, opB, sum);
                    opA = __SXTB16(__ROR(inA1, 8));
                    opB = __SXTB16(__ROR(inB1, 8));
                    sum2 = __SMLAD(opA, opB, sum2);
                    opA = __SXTB16(inA2);
                    opB = __SXTB16(inB2);
                    sum3 = __SMLAD(opA, opB, sum3);
                    opA = __SXTB16(__ROR(inA2, 8));
                    opB = __SXTB16(__ROR(inB2, 8));
                    sum4 = __SMLAD(opA, opB, sum4);
                    colCnt--;
                }
#else

                while (colCnt)
                {
                    q31_t inA1, inA2, inB1, inB2, opA, opB;

                    inB1 = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    opB = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    inB2 = __PKHBT(opB, inB1, 16);
                    inB1 = __PKHTB(inB1, opB, 16);
                    inA1 = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    opB = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inA2 = __PKHBT(opB, inA1, 16);
                    inA1 = __PKHTB(inA1, opB, 16);
                    opA = __SXTB16(inA1);
                    opB = __SXTB16(inB1);
                    sum2 = __SMLAD(opA, opB, sum2);
                    opA = __SXTB16(__ROR(inA1, 8));
                    opB = __SXTB16(__ROR(inB1, 8));
                    sum = __SMLAD(opA, opB, sum);
                    opA = __SXTB16(inA2);
                    opB = __SXTB16(inB2);
                    sum4 = __SMLAD(opA, opB, sum4);
                    opA = __SXTB16(__ROR(inA2, 8));
                    opB = __SXTB16(__ROR(inB2, 8));
                    sum3 = __SMLAD(opA, opB, sum3);
                    colCnt--;
                }

#endif /* ARM_MATH_BIG_ENDIAN */

#else

#ifndef ARM_MATH_BIG_ENDIAN
                //  r0    r1    r2    r3    r4   r5
                // inA1, inA2, inB1, inB2, opA, opB
                asm volatile("COL_LOOP:\n"
                             "ldr.w r2, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "ldr.w r5, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "pkhtb r3, r5, r2, ASR #16\n"
                             "pkhbt r2, r2, r5, LSL #16\n"
                             "ldr.w r0, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "ldr.w r5, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "pkhtb r1, r5, r0, ASR #16\n"
                             "pkhbt r0, r0, r5, LSL #16\n"
                             "sxtb16 r4, r0\n"
                             "sxtb16 r5, r2\n"
                             "smlad %[sum], r4, r5, %[sum]\n"
                             "mov.w r4, r0, ror #8\n"
                             "mov.w r5, r2, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum2], r4, r5, %[sum2]\n"
                             "sxtb16 r4, r1\n"
                             "sxtb16 r5, r3\n"
                             "smlad %[sum3], r4, r5, %[sum3]\n"
                             "mov.w r4, r1, ror #8\n"
                             "mov.w r5, r3, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum4], r4, r5, %[sum4]\n"
                             "subs %[colCnt], #1\n"
                             "bne COL_LOOP\n"
                             : [ sum ] "+r"(sum),
                               [ sum2 ] "+r"(sum2),
                               [ sum3 ] "+r"(sum3),
                               [ sum4 ] "+r"(sum4),
                               [ pB ] "+r"(pB),
                               [ pA ] "+r"(pA)
                             : [ colCnt ] "r"(colCnt), [ ch_im_in ] "r"(ch_im_in)
                             : "r0", "r1", "r2", "r3", "r4", "r5");
#else
                //  r0    r1    r2    r3    r4   r5
                // inA1, inA2, inB1, inB2, opA, opB
                asm volatile("COL_LOOP:\n"
                             "ldr.w r2, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "ldr.w r5, [%[pB], #0]\n"
                             "add.w %[pB], %[pB], %[ch_im_in]\n"
                             "pkhbt r3, r5, r2, LSL #16\n"
                             "pkhtb r2, r2, r5, ASR #16\n"
                             "ldr.w r0, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "ldr.w r5, [%[pA], #0]\n"
                             "add.w %[pA], %[pA], %[ch_im_in]\n"
                             "pkhbt r1, r5, r0, LSL #16\n"
                             "pkhtb r0, r0, r5, ASR #16\n"
                             "sxtb16 r4, r0\n"
                             "sxtb16 r5, r2\n"
                             "smlad %[sum2], r4, r5, %[sum2]\n"
                             "mov.w r4, r0, ror #8\n"
                             "mov.w r5, r2, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum], r4, r5, %[sum]\n"
                             "sxtb16 r4, r1\n"
                             "sxtb16 r5, r3\n"
                             "smlad %[sum4], r4, r5, %[sum4]\n"
                             "mov.w r4, r1, ror #8\n"
                             "mov.w r5, r3, ror #8\n"
                             "sxtb16 r4, r4\n"
                             "sxtb16 r5, r5\n"
                             "smlad %[sum3], r4, r5, %[sum3]\n"
                             "subs %[colCnt], #1\n"
                             "bne COL_LOOP\n"
                             : [ sum ] "+r"(sum),
                               [ sum2 ] "+r"(sum2),
                               [ sum3 ] "+r"(sum3),
                               [ sum4 ] "+r"(sum4),
                               [ pB ] "+r"(pB),
                               [ pA ] "+r"(pA)
                             : [ colCnt ] "r"(colCnt), [ ch_im_in ] "r"(ch_im_in)
                             : "r0", "r1", "r2", "r3", "r4", "r5");
#endif /* ARM_MATH_BIG_ENDIAN */

#endif /* USE_INTRINSIC */

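                /* Handle the left-over im2col column when dim_kernel_x * dim_kernel_y
                 * is odd: one more q7x4 read per operand, accumulated byte by byte. */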
                colCnt = (dim_kernel_x * dim_kernel_y) & 0x1;
                while (colCnt)
                {
                    union arm_nnword inA, inB;
                    inA.word = arm_nn_read_q7x4(pA);
                    pA += ch_im_in;
                    inB.word = arm_nn_read_q7x4(pB);
                    pB += ch_im_in;
                    sum += inA.bytes[0] * inB.bytes[0];
                    sum2 += inA.bytes[1] * inB.bytes[1];
                    sum3 += inA.bytes[2] * inB.bytes[2];
                    sum4 += inA.bytes[3] * inB.bytes[3];
                    colCnt--;
                }

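                /* Requantize: arithmetic right shift by out_shift, then saturate to the
                 * 8-bit q7 range. */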
                *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum2 >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum3 >> out_shift), 8);
                *pOut++ = (q7_t)__SSAT((sum4 >> out_shift), 8);

                rowCnt--;
            }

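            /* Process the remaining 1..3 channels that did not fill a group of four,
             * one channel at a time. */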
            rowCnt = ch_im_out & 0x3;
            while (rowCnt)
            {
                q7_t *pB = colBuffer + row_shift;
                const q7_t *pA = wt + row_shift;
                q31_t sum = ((q31_t)(*pBias++) << bias_shift) + NN_ROUND(out_shift);
                uint16_t colCnt = (dim_kernel_x * dim_kernel_y);

                row_shift += 1;

                while (colCnt)
                {
                    q7_t A1 = *pA;
                    q7_t B1 = *pB;
                    pA += ch_im_in;
                    pB += ch_im_in;
                    sum += A1 * B1;

                    colCnt--;
                }
                *pOut++ = (q7_t)__SSAT((sum >> out_shift), 8);
                rowCnt--;
            }

            // reset the im2col buffer pointer for the next output pixel
            pBuffer = colBuffer;
        }
    }

#else
    (void)bufferA;

    /* Run the following code as reference implementation for Cortex-M0 and Cortex-M3 */
    int i_out_y, i_out_x, i_ch_out;
    int i_ker_y, i_ker_x;

    /* depthwise convolution requires ch_im_in == ch_im_out */
    if (ch_im_in != ch_im_out)
    {
        return ARM_MATH_SIZE_MISMATCH;
    }

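    /* HWC layout: input element (y, x, c) lives at
     * Im_in[(y * dim_im_in_x + x) * ch_im_in + c], and the depthwise weight for
     * channel c at kernel position (ky, kx) lives at
     * wt[(ky * dim_kernel_x + kx) * ch_im_out + c]. */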
    for (i_out_y = 0; i_out_y < dim_im_out_y; i_out_y++)
    {
        for (i_out_x = 0; i_out_x < dim_im_out_x; i_out_x++)
        {
            for (i_ch_out = 0; i_ch_out < ch_im_out; i_ch_out++)
            {
                // for each output
                int conv_out = ((q31_t)(bias[i_ch_out]) << bias_shift) + NN_ROUND(out_shift);
                for (i_ker_y = 0; i_ker_y < dim_kernel_y; i_ker_y++)
                {
                    for (i_ker_x = 0; i_ker_x < dim_kernel_x; i_ker_x++)
                    {
                        int in_row = stride_y * i_out_y + i_ker_y - padding_y;
                        int in_col = stride_x * i_out_x + i_ker_x - padding_x;
                        if (in_row >= 0 && in_col >= 0 && in_row < dim_im_in_y && in_col < dim_im_in_x)
                        {
                            conv_out += Im_in[(in_row * dim_im_in_x + in_col) * ch_im_in + i_ch_out] *
                                wt[(i_ker_y * dim_kernel_x + i_ker_x) * ch_im_out + i_ch_out];
                        }
                    }
                }
                Im_out[(i_out_y * dim_im_out_x + i_out_x) * ch_im_out + i_ch_out] =
                    (q7_t)__SSAT((conv_out >> out_shift), 8);
            }
        }
    }

#endif /* ARM_MATH_DSP */

    /* Return to application */
    return ARM_MATH_SUCCESS;
}

/**
 * @} end of NNConv group
 */