1 /*
2  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_transpose_s8.c
22  * Description:  Transpose a s8 vector
23  *
24  * $Date:        30 October 2024
25  * $Revision:    V.1.0.1
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  *  @ingroup Public
36  */
37 
38 /**
39  * @addtogroup Transpose
40  * @{
41  */
42 
arm_transpose_s8_nhcw(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const int32_t * const in_strides,const int32_t * const out_strides)43 static arm_cmsis_nn_status arm_transpose_s8_nhcw(const int8_t *input,
44                                                  int8_t *const output,
45                                                  const cmsis_nn_dims *const input_dims,
46                                                  const int32_t *const in_strides,
47                                                  const int32_t *const out_strides)
48 {
49     const int32_t n = input_dims->n;
50     const int32_t h = input_dims->h;
51     const int32_t w = input_dims->w;
52     const int32_t c = input_dims->c;
53 
54     const int8_t *input_n = input;
55     int8_t *output_n = output;
56 
57     const uint16_t src_rows = w;
58     const uint16_t src_cols = c;
59 
60 #if defined(ARM_MATH_MVEI)
61     uint16x8_t vec_offsets;
62     uint16x8_t vec_input;
63 
64     vec_offsets = vidupq_u16((uint32_t)0, 1);
65     vec_offsets = vec_offsets * src_cols;
66 #endif
67 
68     for (int32_t i = 0; i < n; i++)
69     {
70         const int8_t *input_h = input_n;
71         int8_t *output_h = output_n;
72 
73         for (int32_t y = 0; y < h; y++)
74         {
75 
76 #if defined(ARM_MATH_MVEI)
77             const uint8_t *input_c = (const uint8_t *)input_h;
78             uint8_t *output_c = (uint8_t *)output_h;
79 
80             for (int32_t z = 0; z < src_cols; z++)
81             {
82                 uint8_t const *input_w = (uint8_t const *)input_c;
83                 uint8_t *output_w = (uint8_t *)output_c;
84 
85                 int32_t block_count = src_rows;
86                 while (block_count > 0)
87                 {
88                     mve_pred16_t p = vctp16q(block_count);
89 
90                     vec_input = vldrbq_gather_offset_z_u16(input_w, vec_offsets, p);
91                     vstrbq_p_u16(output_w, vec_input, p);
92 
93                     input_w = input_w + src_cols * 8;
94                     output_w += 8;
95                     block_count -= 8;
96                 }
97 
98                 input_c++;
99                 output_c += src_rows;
100             }
101 #else
102             const uint8_t *input_w = (const uint8_t *)input_h;
103             uint8_t *output_w = (uint8_t *)output_h;
104 
105             for (int32_t src_row_i = 0; src_row_i < src_rows; src_row_i++)
106             {
107                 output_w = (uint8_t *)output + src_row_i;
108 
109                 for (int32_t x = 0; x < src_cols; x++)
110                 {
111                     *output_w = *input_w++;
112                     output_w += src_rows;
113                 }
114             }
115 #endif
116             input_h += in_strides[1];
117             output_h += out_strides[1];
118         }
119         input_n += in_strides[0];
120         output_n += out_strides[0];
121     }
122 
123     return ARM_CMSIS_NN_SUCCESS;
124 }
125 
arm_transpose_s8_default(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const int32_t * const in_strides,const int32_t * const out_strides)126 static arm_cmsis_nn_status arm_transpose_s8_default(const int8_t *input,
127                                                     int8_t *const output,
128                                                     const cmsis_nn_dims *const input_dims,
129                                                     const int32_t *const in_strides,
130                                                     const int32_t *const out_strides)
131 {
132     const int32_t n = input_dims->n;
133     const int32_t h = input_dims->h;
134     const int32_t w = input_dims->w;
135     const int32_t c = input_dims->c;
136 
137     for (int32_t i = 0; i < n; i++)
138     {
139         for (int32_t y = 0; y < h; y++)
140         {
141             for (int32_t x = 0; x < w; x++)
142             {
143                 for (int32_t z = 0; z < c; z++)
144                 {
145                     const int32_t from_index =
146                         i * in_strides[0] + y * in_strides[1] + x * in_strides[2] + z * in_strides[3];
147 
148                     const int32_t to_index =
149                         i * out_strides[0] + y * out_strides[1] + x * out_strides[2] + z * out_strides[3];
150 
151                     output[to_index] = input[from_index];
152                 }
153             }
154         }
155     }
156 
157     return ARM_CMSIS_NN_SUCCESS;
158 }
159 
160 /*
161  * Basic s8 transpose function.
162  *
163  * Refer to the header file for details.
164  *
165  */
arm_transpose_s8(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const cmsis_nn_dims * const output_dims,const cmsis_nn_transpose_params * const transpose_params)166 arm_cmsis_nn_status arm_transpose_s8(const int8_t *input,
167                                      int8_t *const output,
168                                      const cmsis_nn_dims *const input_dims,
169                                      const cmsis_nn_dims *const output_dims,
170                                      const cmsis_nn_transpose_params *const transpose_params)
171 {
172     int32_t in_strides[4];
173     int32_t out_strides[4] = {0};
174 
175     const uint32_t *const perm = transpose_params->permutations;
176 
177     const int32_t n = input_dims->n;
178     const int32_t h = input_dims->h;
179     const int32_t w = input_dims->w;
180     const int32_t c = input_dims->c;
181 
182     in_strides[0] = h * w * c;
183     in_strides[1] = w * c;
184     in_strides[2] = c;
185     in_strides[3] = 1;
186 
187     if (transpose_params->num_dims == 1)
188     {
189         arm_memcpy_s8(output, input, input_dims->n);
190 
191         return ARM_CMSIS_NN_SUCCESS;
192     }
193     else if (transpose_params->num_dims == 2)
194     {
195         const cmsis_nn_dims smaller_input_dims = {1, 1, n, h};
196 
197         return arm_transpose_s8_nhcw(input, output, &smaller_input_dims, in_strides, out_strides);
198     }
199     else if (transpose_params->num_dims == 3)
200     {
201         const cmsis_nn_dims smaller_input_dims = {1, n, h, w};
202 
203         in_strides[0] = 0;
204         in_strides[1] = h * w;
205         in_strides[2] = w;
206         in_strides[3] = 1;
207 
208         if (perm[0] > 2 || perm[1] > 2 || perm[2] > 2)
209         {
210             return ARM_CMSIS_NN_ARG_ERROR;
211         }
212 
213         out_strides[0] = 0;
214         out_strides[perm[0] + 1] = output_dims->h * output_dims->w;
215         out_strides[perm[1] + 1] = output_dims->w;
216         out_strides[perm[2] + 1] = 1;
217 
218         return arm_transpose_s8_default(input, output, &smaller_input_dims, in_strides, out_strides);
219     }
220 
221     if (perm[0] > 3 || perm[1] > 3 || perm[2] > 3 || perm[3] > 3)
222     {
223         return ARM_CMSIS_NN_ARG_ERROR;
224     }
225 
226     out_strides[perm[0]] = output_dims->h * output_dims->w * output_dims->c;
227     out_strides[perm[1]] = output_dims->w * output_dims->c;
228     out_strides[perm[2]] = output_dims->c;
229     out_strides[perm[3]] = 1;
230 
231 #if defined(ARM_MATH_MVEI)
232     if (perm[0] == 0 && perm[1] == 1)
233     {
234         return arm_transpose_s8_nhcw(input, output, input_dims, in_strides, out_strides);
235     }
236 #endif
237 
238     return arm_transpose_s8_default(input, output, input_dims, in_strides, out_strides);
239 }
240 
241 /**
242  * @} end of Transpose group
243  */
244