1 
2 /*
3  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  *
7  * Licensed under the Apache License, Version 2.0 (the License); you may
8  * not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
15  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 /* ----------------------------------------------------------------------
21  * Project:      CMSIS NN Library
22  * Title:        arm_pad_s8.c
23  * Description:  Pad a s8 vector
24  *
25  * $Date:        19 Sep 2024
26  * $Revision:    V.1.0.0
27  *
28  * Target :  Arm(R) M-Profile Architecture
29  *
30  * -------------------------------------------------------------------- */
31 
32 #include "arm_nn_types.h"
33 #include "arm_nnfunctions.h"
34 #include "arm_nnsupportfunctions.h"
35 /**
36  *  @ingroup Public
37  */
38 
39 /**
40  * @addtogroup Pad
41  * @{
42  */
43 
/*
 * Basic s8 pad function.
 *
 * Refer to the header file for details.
 *
 */
50 
arm_pad_s8(const int8_t * input,int8_t * output,const int8_t pad_value,const cmsis_nn_dims * input_size,const cmsis_nn_dims * pre_pad,const cmsis_nn_dims * post_pad)51 arm_cmsis_nn_status arm_pad_s8(const int8_t *input,
52                                int8_t *output,
53                                const int8_t pad_value,
54                                const cmsis_nn_dims *input_size,
55                                const cmsis_nn_dims *pre_pad,
56                                const cmsis_nn_dims *post_pad)
57 {
58 
59     const cmsis_nn_dims output_size = {pre_pad->n + input_size->n + post_pad->n,
60                                        pre_pad->h + input_size->h + post_pad->h,
61                                        pre_pad->w + input_size->w + post_pad->w,
62                                        pre_pad->c + input_size->c + post_pad->c};
63 
64     const int32_t batch_block_size = output_size.h * output_size.w * output_size.c;
65     const int32_t row_block_size = output_size.w * output_size.c;
66     const int32_t col_block_size = output_size.c;
67 
68     arm_memset_s8(output, pad_value, batch_block_size * pre_pad->n);
69     output += batch_block_size * pre_pad->n;
70     for (int32_t b = 0; b < input_size->n; b++)
71     {
72 
73         arm_memset_s8(output, pad_value, row_block_size * pre_pad->h);
74         output += row_block_size * pre_pad->h;
75         for (int32_t y = 0; y < input_size->h; y++)
76         {
77 
78             arm_memset_s8(output, pad_value, col_block_size * pre_pad->w);
79             output += col_block_size * pre_pad->w;
80             if (input_size->c == output_size.c)
81             {
82                 arm_memcpy_s8(output, input, input_size->w * input_size->c);
83                 output += input_size->w * input_size->c;
84                 input += input_size->w * input_size->c;
85             }
86             else
87             {
88                 for (int32_t x = 0; x < input_size->w; x++)
89                 {
90 
91                     arm_memset_s8(output, pad_value, pre_pad->c);
92                     output += pre_pad->c;
93 
94                     arm_memcpy_s8(output, input, input_size->c);
95                     output += input_size->c;
96                     input += input_size->c;
97 
98                     arm_memset_s8(output, pad_value, post_pad->c);
99                     output += post_pad->c;
100                 }
101             }
102 
103             arm_memset_s8(output, pad_value, col_block_size * post_pad->w);
104             output += col_block_size * post_pad->w;
105         }
106 
107         arm_memset_s8(output, pad_value, row_block_size * post_pad->h);
108         output += row_block_size * post_pad->h;
109     }
110     arm_memset_s8(output, pad_value, batch_block_size * post_pad->n);
111 
112     return ARM_CMSIS_NN_SUCCESS;
113 }
114 
115 /**
116  * @} end of Pad group
117  */
118