1 /*
2  * Copyright (C) 2010-2020 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_avgpool_s8.c
22  * Description:  Pooling function implementations
23  *
24  * $Date:        01. March 2021
25  * $Revision:    V.2.0.4
26  *
27  * Target Processor:  Cortex-M CPUs
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
35 
/*
 * Turn per-channel accumulated sums into clamped q7 averages.
 *
 * Each entry of 'buffer' is divided by 'count' with rounding half away from
 * zero (half of 'count' is added to positive sums and subtracted from the
 * others before the truncating division), limited to [act_min, act_max] and
 * stored narrowed to q7_t in 'target'.
 *
 * buffer   Accumulated sums, 'length' entries are read.
 * target   Destination, 'length' entries are written.
 * length   Number of entries to convert; nothing is done when it is <= 0.
 * count    Divisor. It is used directly in a division, so it must be non-zero.
 * act_min  Lower clamp bound.
 * act_max  Upper clamp bound.
 */
static void scale_q31_to_q7_and_clamp(const q31_t *buffer,
                                      q7_t *target,
                                      int32_t length,
                                      const int32_t count,
                                      const int act_min,
                                      const int act_max)
{
    const int rounding = count / 2;
    const q31_t *in = buffer;
    q7_t *out = target;

    for (int32_t remaining = length; remaining > 0; remaining--)
    {
        const q31_t acc = *in++;
        int32_t avg;

        /* Bias towards the sign of the sum so the truncating division rounds
           half away from zero. */
        if (acc > 0)
        {
            avg = (acc + rounding) / count;
        }
        else
        {
            avg = (acc - rounding) / count;
        }

        /* Clamp to the activation range: lower bound first, then upper. */
        if (!(avg > act_min))
        {
            avg = act_min;
        }
        if (!(avg < act_max))
        {
            avg = act_max;
        }

        *out++ = (q7_t)avg;
    }
}
54 #endif
55 
56 /**
57  *  @ingroup groupNN
58  */
59 
60 /**
61  * @addtogroup Pooling
62  * @{
63  */
64 
65 /*
66  * s8 average pooling function
67  *
68  * Refer to header file for details.
69  *
70  */
71 
72 #if defined(ARM_MATH_MVEI)
73 
arm_avgpool_s8(const cmsis_nn_context * ctx,const cmsis_nn_pool_params * pool_params,const cmsis_nn_dims * input_dims,const q7_t * src,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims,q7_t * dst)74 arm_status arm_avgpool_s8(const cmsis_nn_context *ctx,
75                           const cmsis_nn_pool_params *pool_params,
76                           const cmsis_nn_dims *input_dims,
77                           const q7_t *src,
78                           const cmsis_nn_dims *filter_dims,
79                           const cmsis_nn_dims *output_dims,
80                           q7_t *dst)
81 {
82     (void)ctx;
83     const int32_t input_y = input_dims->h;
84     const int32_t input_x = input_dims->w;
85     const int32_t output_y = output_dims->h;
86     const int32_t output_x = output_dims->w;
87     const int32_t stride_y = pool_params->stride.h;
88     const int32_t stride_x = pool_params->stride.w;
89     const int32_t kernel_y = filter_dims->h;
90     const int32_t kernel_x = filter_dims->w;
91     const int32_t pad_y = pool_params->padding.h;
92     const int32_t pad_x = pool_params->padding.w;
93     const int32_t act_min = pool_params->activation.min;
94     const int32_t act_max = pool_params->activation.max;
95     const int32_t ch_src = input_dims->c;
96 
97     int32_t i_x, i_y;
98     int32_t k_x, k_y;
99 
100     for (i_y = 0; i_y < output_y; i_y++)
101     {
102         for (i_x = 0; i_x < output_x; i_x++)
103         {
104 
105             int32_t k_y_start, k_y_end;
106             int32_t k_x_start, k_x_end;
107             int32_t chCnt;
108             const int8_t *pTmp, *pTmpInner;
109             int8_t *pDst;
110 
111             k_y_start = MAX(0, i_y * stride_y - pad_y);
112             k_y_end = MIN(i_y * stride_y - pad_y + kernel_y, input_y);
113 
114             k_x_start = MAX(0, i_x * stride_x - pad_x);
115             k_x_end = MIN(i_x * stride_x - pad_x + kernel_x, input_x);
116 
117             pTmp = src;
118             pDst = &dst[ch_src * (i_x + i_y * output_x)];
119 
120             chCnt = ch_src >> 4;
121             while (chCnt > 0)
122             {
123                 int32x4_t sumV1, sumV2, sumV3, sumV4;
124 
125                 int8x16_t tempV;
126                 int16x8_t tempVLO, tempVHI;
127                 int32x4_t tempVLOLO, tempVLOHI, tempVHILO, tempVHIHI;
128                 int32_t count = 0;
129 
130                 sumV1 = vdupq_n_s32(0);
131                 sumV2 = vdupq_n_s32(0);
132                 sumV3 = vdupq_n_s32(0);
133                 sumV4 = vdupq_n_s32(0);
134 
135                 for (k_y = k_y_start; k_y < k_y_end; k_y++)
136                 {
137                     for (k_x = k_x_start; k_x < k_x_end; k_x++)
138                     {
139                         pTmpInner = pTmp + (ch_src * (k_x + k_y * input_x));
140                         tempV = vldrbq_s8(pTmpInner);
141 
142                         tempVLO = vmovlbq_s8(tempV);
143                         tempVHI = vmovltq_s8(tempV);
144 
145                         tempVLOLO = vmovlbq_s16(tempVLO);
146                         tempVLOHI = vmovltq_s16(tempVLO);
147 
148                         tempVHILO = vmovlbq_s16(tempVHI);
149                         tempVHIHI = vmovltq_s16(tempVHI);
150 
151                         sumV1 = vaddq_s32(sumV1, tempVLOLO);
152                         sumV2 = vaddq_s32(sumV2, tempVLOHI);
153                         sumV3 = vaddq_s32(sumV3, tempVHILO);
154                         sumV4 = vaddq_s32(sumV4, tempVHIHI);
155 
156                         count++;
157                     }
158                 }
159 
160                 // Prevent static code issue DIVIDE_BY_ZERO.
161                 if (count == 0)
162                 {
163                     return ARM_MATH_ARGUMENT_ERROR;
164                 }
165 
166                 sumV1[0] = sumV1[0] > 0 ? (sumV1[0] + count / 2) / count : (sumV1[0] - count / 2) / count;
167                 sumV1[1] = sumV1[1] > 0 ? (sumV1[1] + count / 2) / count : (sumV1[1] - count / 2) / count;
168                 sumV1[2] = sumV1[2] > 0 ? (sumV1[2] + count / 2) / count : (sumV1[2] - count / 2) / count;
169                 sumV1[3] = sumV1[3] > 0 ? (sumV1[3] + count / 2) / count : (sumV1[3] - count / 2) / count;
170 
171                 sumV2[0] = sumV2[0] > 0 ? (sumV2[0] + count / 2) / count : (sumV2[0] - count / 2) / count;
172                 sumV2[1] = sumV2[1] > 0 ? (sumV2[1] + count / 2) / count : (sumV2[1] - count / 2) / count;
173                 sumV2[2] = sumV2[2] > 0 ? (sumV2[2] + count / 2) / count : (sumV2[2] - count / 2) / count;
174                 sumV2[3] = sumV2[3] > 0 ? (sumV2[3] + count / 2) / count : (sumV2[3] - count / 2) / count;
175 
176                 sumV3[0] = sumV3[0] > 0 ? (sumV3[0] + count / 2) / count : (sumV3[0] - count / 2) / count;
177                 sumV3[1] = sumV3[1] > 0 ? (sumV3[1] + count / 2) / count : (sumV3[1] - count / 2) / count;
178                 sumV3[2] = sumV3[2] > 0 ? (sumV3[2] + count / 2) / count : (sumV3[2] - count / 2) / count;
179                 sumV3[3] = sumV3[3] > 0 ? (sumV3[3] + count / 2) / count : (sumV3[3] - count / 2) / count;
180 
181                 sumV4[0] = sumV4[0] > 0 ? (sumV4[0] + count / 2) / count : (sumV4[0] - count / 2) / count;
182                 sumV4[1] = sumV4[1] > 0 ? (sumV4[1] + count / 2) / count : (sumV4[1] - count / 2) / count;
183                 sumV4[2] = sumV4[2] > 0 ? (sumV4[2] + count / 2) / count : (sumV4[2] - count / 2) / count;
184                 sumV4[3] = sumV4[3] > 0 ? (sumV4[3] + count / 2) / count : (sumV4[3] - count / 2) / count;
185 
186                 sumV1 = vmaxq_s32(sumV1, vdupq_n_s32(act_min));
187                 sumV1 = vminq_s32(sumV1, vdupq_n_s32(act_max));
188 
189                 sumV2 = vmaxq_s32(sumV2, vdupq_n_s32(act_min));
190                 sumV2 = vminq_s32(sumV2, vdupq_n_s32(act_max));
191 
192                 sumV3 = vmaxq_s32(sumV3, vdupq_n_s32(act_min));
193                 sumV3 = vminq_s32(sumV3, vdupq_n_s32(act_max));
194 
195                 sumV4 = vmaxq_s32(sumV4, vdupq_n_s32(act_min));
196                 sumV4 = vminq_s32(sumV4, vdupq_n_s32(act_max));
197 
198                 tempVLO = vmovnbq_s32(tempVLO, sumV1);
199                 tempVLO = vmovntq_s32(tempVLO, sumV2);
200 
201                 tempVHI = vmovnbq_s32(tempVHI, sumV3);
202                 tempVHI = vmovntq_s32(tempVHI, sumV4);
203 
204                 tempV = vmovnbq_s16(tempV, tempVLO);
205                 tempV = vmovntq_s16(tempV, tempVHI);
206 
207                 vstrbq_s8(pDst, tempV);
208                 pDst += 16;
209 
210                 chCnt--;
211                 pTmp += 16;
212             }
213 
214             chCnt = ch_src & 0xF;
215             while (chCnt > 0)
216             {
217                 int32_t sum = 0;
218                 int32_t count = 0;
219 
220                 for (k_y = k_y_start; k_y < k_y_end; k_y++)
221                 {
222                     for (k_x = k_x_start; k_x < k_x_end; k_x++)
223                     {
224                         sum += pTmp[ch_src * (k_x + k_y * input_x)];
225                         count++;
226                     }
227                 }
228                 sum = sum > 0 ? (sum + count / 2) / count : (sum - count / 2) / count;
229                 sum = MAX(sum, act_min);
230                 sum = MIN(sum, act_max);
231 
232                 *pDst++ = sum;
233 
234                 chCnt--;
235                 pTmp++;
236             }
237         }
238     }
239     return ARM_MATH_SUCCESS;
240 }
241 
242 #else
arm_avgpool_s8(const cmsis_nn_context * ctx,const cmsis_nn_pool_params * pool_params,const cmsis_nn_dims * input_dims,const q7_t * src,const cmsis_nn_dims * filter_dims,const cmsis_nn_dims * output_dims,q7_t * dst)243 arm_status arm_avgpool_s8(const cmsis_nn_context *ctx,
244                           const cmsis_nn_pool_params *pool_params,
245                           const cmsis_nn_dims *input_dims,
246                           const q7_t *src,
247                           const cmsis_nn_dims *filter_dims,
248                           const cmsis_nn_dims *output_dims,
249                           q7_t *dst)
250 {
251     const int32_t input_y = input_dims->h;
252     const int32_t input_x = input_dims->w;
253     const int32_t output_y = output_dims->h;
254     const int32_t output_x = output_dims->w;
255     const int32_t stride_y = pool_params->stride.h;
256     const int32_t stride_x = pool_params->stride.w;
257     const int32_t kernel_y = filter_dims->h;
258     const int32_t kernel_x = filter_dims->w;
259     const int32_t pad_y = pool_params->padding.h;
260     const int32_t pad_x = pool_params->padding.w;
261     const int32_t act_min = pool_params->activation.min;
262     const int32_t act_max = pool_params->activation.max;
263     const int32_t ch_src = input_dims->c;
264     q31_t *buffer = (q31_t *)ctx->buf;
265 
266 #if defined(ARM_MATH_DSP)
267 
268     /* Run the following code for CPU's with DSP extension
269      */
270     for (int i_y = 0, idx_y = -pad_y; i_y < output_y; idx_y += stride_y, i_y++)
271     {
272         for (int i_x = 0, idx_x = -pad_x; i_x < output_x; idx_x += stride_x, i_x++)
273         {
274             /* Condition for kernel start dimension:
275                       (base_idx_<x,y> + kernel_<x,y>_start) >= 0 */
276             const int32_t kernel_y_start = MAX(0, -idx_y);
277             const int32_t kernel_x_start = MAX(0, -idx_x);
278 
279             /* Condition for kernel end dimension:
280                    (base_idx_<x,y> + kernel_<x,y>_end) < dim_src_<width,height> */
281             const int32_t kernel_y_end = MIN(kernel_y, input_y - idx_y);
282             const int32_t kernel_x_end = MIN(kernel_x, input_x - idx_x);
283 
284             int count = 0;
285 
286             for (int k_y = kernel_y_start; k_y < kernel_y_end; k_y++)
287             {
288                 for (int k_x = kernel_x_start; k_x < kernel_x_end; k_x++)
289                 {
290                     const q7_t *start = src + ch_src * (k_x + idx_x + (k_y + idx_y) * input_x);
291 
292                     if (count == 0)
293                     {
294                         for (int i = 0; i < ch_src; i++)
295                         {
296                             buffer[i] = start[i];
297                         }
298                     }
299                     else
300                     {
301                         for (int i = 0; i < ch_src; i++)
302                         {
303                             buffer[i] = __QADD(start[i], buffer[i]);
304                         }
305                     }
306                     count++;
307                 }
308             }
309 
310             // Prevent static code issue DIVIDE_BY_ZERO.
311             if (count == 0)
312             {
313                 return ARM_MATH_ARGUMENT_ERROR;
314             }
315 
316             scale_q31_to_q7_and_clamp(buffer, dst, ch_src, count, act_min, act_max);
317             dst += ch_src;
318         }
319     }
320 #else
321 
322     /* Reference C code adapted from CMSIS-NN arm_avepool_q7_HWC.
323      */
324     (void)buffer;
325     int16_t i_ch_in, i_x, i_y;
326     int16_t k_x, k_y;
327 
328     for (i_y = 0; i_y < output_y; i_y++)
329     {
330         for (i_x = 0; i_x < output_x; i_x++)
331         {
332             for (i_ch_in = 0; i_ch_in < ch_src; i_ch_in++)
333             {
334                 int sum = 0;
335                 int count = 0;
336                 for (k_y = i_y * stride_y - pad_y; k_y < i_y * stride_y - pad_y + kernel_y; k_y++)
337                 {
338                     for (k_x = i_x * stride_x - pad_x; k_x < i_x * stride_x - pad_x + kernel_x; k_x++)
339                     {
340                         if (k_y >= 0 && k_x >= 0 && k_y < input_y && k_x < input_x)
341                         {
342                             sum += src[i_ch_in + ch_src * (k_x + k_y * input_x)];
343                             count++;
344                         }
345                     }
346                 }
347 
348                 // Prevent static code issue DIVIDE_BY_ZERO.
349                 if (count == 0)
350                 {
351                     return ARM_MATH_ARGUMENT_ERROR;
352                 }
353 
354                 sum = sum > 0 ? (sum + count / 2) / count : (sum - count / 2) / count;
355                 sum = MAX(sum, act_min);
356                 sum = MIN(sum, act_max);
357 
358                 dst[i_ch_in + ch_src * (i_x + i_y * output_x)] = sum;
359             }
360         }
361     }
362 
363 #endif
364     return ARM_MATH_SUCCESS;
365 }
366 
367 #endif /* ARM_MATH_MVEI */
368 
/*
 * Size in bytes of the scratch buffer needed by arm_avgpool_s8.
 *
 * output_x  Unused.
 * ch_src    Channel count; only used when ARM_MATH_DSP is defined and
 *           ARM_MATH_MVEI is not.
 *
 * Returns ch_src * sizeof(int32_t) in the DSP-without-MVE build and 0 in
 * every other build.
 */
int32_t arm_avgpool_s8_get_buffer_size(const int output_x, const int ch_src)
{
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
    /* One 32-bit accumulator per channel. */
    (void)output_x;
    return (int32_t)(ch_src * sizeof(int32_t));
#else
    /* No scratch memory is used in this build. */
    (void)output_x;
    (void)ch_src;
    return 0;
#endif
}
380 /**
381  * @} end of Pooling group
382  */
383