1 /*
2 * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_transpose_s8.c
22 * Description: Transpose a s8 vector
23 *
24 * $Date: 30 October 2024
25 * $Revision: V.1.0.1
26 *
27 * Target : Arm(R) M-Profile Architecture
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
34 /**
35 * @ingroup Public
36 */
37
38 /**
39 * @addtogroup Transpose
40 * @{
41 */
42
arm_transpose_s8_nhcw(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const int32_t * const in_strides,const int32_t * const out_strides)43 static arm_cmsis_nn_status arm_transpose_s8_nhcw(const int8_t *input,
44 int8_t *const output,
45 const cmsis_nn_dims *const input_dims,
46 const int32_t *const in_strides,
47 const int32_t *const out_strides)
48 {
49 const int32_t n = input_dims->n;
50 const int32_t h = input_dims->h;
51 const int32_t w = input_dims->w;
52 const int32_t c = input_dims->c;
53
54 const int8_t *input_n = input;
55 int8_t *output_n = output;
56
57 const uint16_t src_rows = w;
58 const uint16_t src_cols = c;
59
60 #if defined(ARM_MATH_MVEI)
61 uint16x8_t vec_offsets;
62 uint16x8_t vec_input;
63
64 vec_offsets = vidupq_u16((uint32_t)0, 1);
65 vec_offsets = vec_offsets * src_cols;
66 #endif
67
68 for (int32_t i = 0; i < n; i++)
69 {
70 const int8_t *input_h = input_n;
71 int8_t *output_h = output_n;
72
73 for (int32_t y = 0; y < h; y++)
74 {
75
76 #if defined(ARM_MATH_MVEI)
77 const uint8_t *input_c = (const uint8_t *)input_h;
78 uint8_t *output_c = (uint8_t *)output_h;
79
80 for (int32_t z = 0; z < src_cols; z++)
81 {
82 uint8_t const *input_w = (uint8_t const *)input_c;
83 uint8_t *output_w = (uint8_t *)output_c;
84
85 int32_t block_count = src_rows;
86 while (block_count > 0)
87 {
88 mve_pred16_t p = vctp16q(block_count);
89
90 vec_input = vldrbq_gather_offset_z_u16(input_w, vec_offsets, p);
91 vstrbq_p_u16(output_w, vec_input, p);
92
93 input_w = input_w + src_cols * 8;
94 output_w += 8;
95 block_count -= 8;
96 }
97
98 input_c++;
99 output_c += src_rows;
100 }
101 #else
102 const uint8_t *input_w = (const uint8_t *)input_h;
103 uint8_t *output_w = (uint8_t *)output_h;
104
105 for (int32_t src_row_i = 0; src_row_i < src_rows; src_row_i++)
106 {
107 output_w = (uint8_t *)output + src_row_i;
108
109 for (int32_t x = 0; x < src_cols; x++)
110 {
111 *output_w = *input_w++;
112 output_w += src_rows;
113 }
114 }
115 #endif
116 input_h += in_strides[1];
117 output_h += out_strides[1];
118 }
119 input_n += in_strides[0];
120 output_n += out_strides[0];
121 }
122
123 return ARM_CMSIS_NN_SUCCESS;
124 }
125
arm_transpose_s8_default(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const int32_t * const in_strides,const int32_t * const out_strides)126 static arm_cmsis_nn_status arm_transpose_s8_default(const int8_t *input,
127 int8_t *const output,
128 const cmsis_nn_dims *const input_dims,
129 const int32_t *const in_strides,
130 const int32_t *const out_strides)
131 {
132 const int32_t n = input_dims->n;
133 const int32_t h = input_dims->h;
134 const int32_t w = input_dims->w;
135 const int32_t c = input_dims->c;
136
137 for (int32_t i = 0; i < n; i++)
138 {
139 for (int32_t y = 0; y < h; y++)
140 {
141 for (int32_t x = 0; x < w; x++)
142 {
143 for (int32_t z = 0; z < c; z++)
144 {
145 const int32_t from_index =
146 i * in_strides[0] + y * in_strides[1] + x * in_strides[2] + z * in_strides[3];
147
148 const int32_t to_index =
149 i * out_strides[0] + y * out_strides[1] + x * out_strides[2] + z * out_strides[3];
150
151 output[to_index] = input[from_index];
152 }
153 }
154 }
155 }
156
157 return ARM_CMSIS_NN_SUCCESS;
158 }
159
160 /*
161 * Basic s8 transpose function.
162 *
163 * Refer header file for details.
164 *
165 */
arm_transpose_s8(const int8_t * input,int8_t * const output,const cmsis_nn_dims * const input_dims,const cmsis_nn_dims * const output_dims,const cmsis_nn_transpose_params * const transpose_params)166 arm_cmsis_nn_status arm_transpose_s8(const int8_t *input,
167 int8_t *const output,
168 const cmsis_nn_dims *const input_dims,
169 const cmsis_nn_dims *const output_dims,
170 const cmsis_nn_transpose_params *const transpose_params)
171 {
172 int32_t in_strides[4];
173 int32_t out_strides[4] = {0};
174
175 const uint32_t *const perm = transpose_params->permutations;
176
177 const int32_t n = input_dims->n;
178 const int32_t h = input_dims->h;
179 const int32_t w = input_dims->w;
180 const int32_t c = input_dims->c;
181
182 in_strides[0] = h * w * c;
183 in_strides[1] = w * c;
184 in_strides[2] = c;
185 in_strides[3] = 1;
186
187 if (transpose_params->num_dims == 1)
188 {
189 arm_memcpy_s8(output, input, input_dims->n);
190
191 return ARM_CMSIS_NN_SUCCESS;
192 }
193 else if (transpose_params->num_dims == 2)
194 {
195 const cmsis_nn_dims smaller_input_dims = {1, 1, n, h};
196
197 return arm_transpose_s8_nhcw(input, output, &smaller_input_dims, in_strides, out_strides);
198 }
199 else if (transpose_params->num_dims == 3)
200 {
201 const cmsis_nn_dims smaller_input_dims = {1, n, h, w};
202
203 in_strides[0] = 0;
204 in_strides[1] = h * w;
205 in_strides[2] = w;
206 in_strides[3] = 1;
207
208 if (perm[0] > 2 || perm[1] > 2 || perm[2] > 2)
209 {
210 return ARM_CMSIS_NN_ARG_ERROR;
211 }
212
213 out_strides[0] = 0;
214 out_strides[perm[0] + 1] = output_dims->h * output_dims->w;
215 out_strides[perm[1] + 1] = output_dims->w;
216 out_strides[perm[2] + 1] = 1;
217
218 return arm_transpose_s8_default(input, output, &smaller_input_dims, in_strides, out_strides);
219 }
220
221 if (perm[0] > 3 || perm[1] > 3 || perm[2] > 3 || perm[3] > 3)
222 {
223 return ARM_CMSIS_NN_ARG_ERROR;
224 }
225
226 out_strides[perm[0]] = output_dims->h * output_dims->w * output_dims->c;
227 out_strides[perm[1]] = output_dims->w * output_dims->c;
228 out_strides[perm[2]] = output_dims->c;
229 out_strides[perm[3]] = 1;
230
231 #if defined(ARM_MATH_MVEI)
232 if (perm[0] == 0 && perm[1] == 1)
233 {
234 return arm_transpose_s8_nhcw(input, output, input_dims, in_strides, out_strides);
235 }
236 #endif
237
238 return arm_transpose_s8_default(input, output, input_dims, in_strides, out_strides);
239 }
240
241 /**
242 * @} end of Transpose group
243 */
244