/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_svdf_s8.c
 * Description:  S8 basic SVDF layer function
 *
 * $Date:        24 Sep 2024
 * $Revision:    V.6.1.1
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33
/**
 * @ingroup Public
 */

/**
 * @addtogroup SVDF
 * @{
 */

/*
 * S8 SVDF layer function for TensorFlow Lite with an 8-bit state tensor.
 *
 * Refer to the header file for details.
 *
 */

arm_svdf_s8(const cmsis_nn_context * ctx,const cmsis_nn_context * input_ctx,const cmsis_nn_context * output_ctx,const cmsis_nn_svdf_params * svdf_params,const cmsis_nn_per_tensor_quant_params * input_quant_params,const cmsis_nn_per_tensor_quant_params * output_quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * state_dims,int8_t * state_data,const cmsis_nn_dims * weights_feature_dims,const int8_t * weights_feature_data,const cmsis_nn_dims * weights_time_dims,const int8_t * weights_time_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)50 arm_cmsis_nn_status arm_svdf_s8(const cmsis_nn_context *ctx,
51 const cmsis_nn_context *input_ctx,
52 const cmsis_nn_context *output_ctx,
53 const cmsis_nn_svdf_params *svdf_params,
54 const cmsis_nn_per_tensor_quant_params *input_quant_params,
55 const cmsis_nn_per_tensor_quant_params *output_quant_params,
56 const cmsis_nn_dims *input_dims,
57 const int8_t *input_data,
58 const cmsis_nn_dims *state_dims,
59 int8_t *state_data,
60 const cmsis_nn_dims *weights_feature_dims,
61 const int8_t *weights_feature_data,
62 const cmsis_nn_dims *weights_time_dims,
63 const int8_t *weights_time_data,
64 const cmsis_nn_dims *bias_dims,
65 const int32_t *bias_data,
66 const cmsis_nn_dims *output_dims,
67 int8_t *output_data)
68 {
69 (void)bias_dims;
70 (void)state_dims;
71 (void)output_dims;
72
73 #if defined(ARM_MATH_MVEI)
74 if (ctx->buf == NULL)
75 {
76 return (ARM_CMSIS_NN_ARG_ERROR);
77 }
78 #endif
79
80 const int32_t multiplier_in = input_quant_params->multiplier;
81 const int32_t shift_in = input_quant_params->shift;
82 const int32_t multiplier_out = output_quant_params->multiplier;
83 const int32_t shift_2 = output_quant_params->shift;
84 const int32_t zp_in = svdf_params->input_offset;
85 const int32_t zp_out = svdf_params->output_offset;
86 const int32_t in_activation_min = svdf_params->input_activation.min;
87 const int32_t in_activation_max = svdf_params->input_activation.max;
88 const int32_t out_activation_min = svdf_params->output_activation.min;
89 const int32_t out_activation_max = svdf_params->output_activation.max;
90 const int16_t rank = svdf_params->rank;
91
92 const int32_t input_batches = input_dims->n;
93 const int32_t input_height = input_dims->h;
94 const int32_t feature_batches = weights_feature_dims->n;
95 const int32_t time_batches = weights_time_dims->h;
96 const int32_t unit_count = feature_batches / rank;
97
98 if (input_ctx->buf == NULL)
99 {
100 return ARM_CMSIS_NN_ARG_ERROR;
101 }
102 int32_t *buffer_a = (int32_t *)input_ctx->buf;
103
104 if (output_ctx->buf == NULL)
105 {
106 return ARM_CMSIS_NN_ARG_ERROR;
107 }
108 int32_t *buffer_b = (int32_t *)output_ctx->buf;
109
110 int32_t *kernel_sum_data = (int32_t *)ctx->buf;
111
112 // Left shift state
113 // Using memcpy on overlapping data is in general undefined behaviour, but since the behaviour of arm_memcpy_s8 is
114 // known it is certain that the data has been copied before it is overwritten in this case.
115 #ifdef ARM_MATH_MVEI
116 arm_memcpy_s8(state_data,
117 state_data + 1,
118 (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));
119 #else
120 memmove(state_data,
121 state_data + 1,
122 (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));
123 #endif
124
125 // Matrix multiplication input * feature weight
126 for (int i_batch = 0; i_batch < input_batches; i_batch++)
127 {
128 int8_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
129 const int8_t *input = input_data + i_batch * input_height;
130
131 arm_cmsis_nn_status res = arm_nn_vec_mat_mult_t_s8(input,
132 weights_feature_data,
133 kernel_sum_data,
134 NULL,
135 res_ptr,
136 -zp_in,
137 0,
138 multiplier_in,
139 shift_in,
140 input_height,
141 feature_batches,
142 in_activation_min,
143 in_activation_max,
144 time_batches,
145 0);
146
147 if (res != ARM_CMSIS_NN_SUCCESS)
148 {
149 return res;
150 }
151 }
152
153 // Matrix multiplicate time weight * state tensors
154 {
155 int32_t *ptr_a = buffer_a;
156 const int8_t *v2 = state_data;
157 for (int i_batch = 0; i_batch < input_batches; i_batch++)
158 {
159 const int8_t *v1 = weights_time_data;
160
161 for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
162 {
163 *ptr_a = 0;
164 int32_t sum = 0;
165 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
166 // Perform matrix multiplication in blocks of four
167 int j = 0;
168 int32_t block_count = time_batches >> 2;
169 for (int i = 0; i < block_count; i++)
170 {
171 j += 4;
172
173 int32_t r1_1, r1_2, r2_1, r2_2;
174 v1 = read_and_pad_reordered(v1, &r1_1, &r1_2);
175 v2 = read_and_pad_reordered(v2, &r2_1, &r2_2);
176 sum = SMLAD(r1_1, r2_1, sum);
177 sum = SMLAD(r1_2, r2_2, sum);
178 }
179
180 // Process the remaining data
181 for (; j < time_batches; j++)
182 {
183 sum += *v1 * *v2;
184 v1++;
185 v2++;
186 }
187 #else
188 for (int j = 0; j < time_batches; j++)
189 {
190 sum += *v1 * *v2;
191 v1++;
192 v2++;
193 }
194 #endif
195
196 *ptr_a = sum;
197 ptr_a++;
198 }
199 }
200 }
201
202 if (bias_data)
203 {
204 if (unit_count == feature_batches)
205 {
206 for (int i = 0; i < input_batches; i++)
207 {
208 int32_t *output_temp = buffer_b + i * feature_batches;
209 const int32_t *ptr_a = buffer_a + i * feature_batches;
210
211 const int32_t *bi = bias_data;
212 for (int j = 0; j < feature_batches; j++)
213 {
214 output_temp[j] = ptr_a[j] + bi[j];
215 }
216 }
217 }
218 else
219 {
220 for (int i_batch = 0; i_batch < input_batches; i_batch++)
221 {
222 int32_t *output_data_temp = buffer_b + i_batch * unit_count;
223 int32_t *ptr_a = buffer_a + i_batch * feature_batches;
224
225 for (int i = 0; i < unit_count; i++)
226 {
227 int32_t sum = bias_data[i];
228 for (int j = 0; j < rank; j++)
229 {
230 sum += *ptr_a;
231 ptr_a++;
232 }
233 output_data_temp[i] = sum;
234 }
235 }
236 }
237 }
238 else
239 {
240 for (int i_batch = 0; i_batch < input_batches; i_batch++)
241 {
242 int32_t *output_data_temp = buffer_b + i_batch * unit_count;
243 int32_t *ptr_a = buffer_a + i_batch * feature_batches;
244
245 for (int i = 0; i < unit_count; i++)
246 {
247 int32_t sum = 0;
248 for (int j = 0; j < rank; j++)
249 {
250 sum += *ptr_a;
251 ptr_a++;
252 }
253 output_data_temp[i] = sum;
254 }
255 }
256 }
257
258 #if defined(ARM_MATH_MVEI)
259 int32_t num_elements = input_batches * unit_count;
260 const int32_t loop_count = (num_elements + 3) / 4;
261 for (int i_op = 0; i_op < loop_count; i_op++)
262 {
263 mve_pred16_t p = vctp32q((uint32_t)num_elements);
264 int32x4_t op = vldrwq_z_s32(buffer_b, p);
265 op = arm_requantize_mve(op, multiplier_out, shift_2);
266 op = vaddq_n_s32(op, zp_out);
267 const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
268 const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
269 op = vmaxq_s32(op, min_vec);
270 op = vminq_s32(op, max_vec);
271 vstrbq_p_s32(output_data, op, p);
272 output_data += 4;
273 buffer_b += 4;
274 num_elements -= 4;
275 }
276 #else
277 for (int i = 0; i < input_batches * unit_count; i++)
278 {
279 output_data[i] = (int8_t)CLAMP(
280 arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
281 }
282 #endif
283
284 return (ARM_CMSIS_NN_SUCCESS);
285 }

/**
 * @} end of SVDF group
 */
