1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_svdf_s8.c
22  * Description:  S8 basic SVDF layer function
23  *
24  * $Date:        24 Sep 2024
25  * $Revision:    V.6.1.1
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 /**
35  * @ingroup Public
36  */
37 
38 /**
39  * @addtogroup SVDF
40  * @{
41  */
42 
43 /*
44  * S8 SVDF layer function for TensorFlow Lite with 8 bit state tensor
45  *
46  * Refer to header file for details.
47  *
48  */
49 
arm_svdf_s8(const cmsis_nn_context * ctx,const cmsis_nn_context * input_ctx,const cmsis_nn_context * output_ctx,const cmsis_nn_svdf_params * svdf_params,const cmsis_nn_per_tensor_quant_params * input_quant_params,const cmsis_nn_per_tensor_quant_params * output_quant_params,const cmsis_nn_dims * input_dims,const int8_t * input_data,const cmsis_nn_dims * state_dims,int8_t * state_data,const cmsis_nn_dims * weights_feature_dims,const int8_t * weights_feature_data,const cmsis_nn_dims * weights_time_dims,const int8_t * weights_time_data,const cmsis_nn_dims * bias_dims,const int32_t * bias_data,const cmsis_nn_dims * output_dims,int8_t * output_data)50 arm_cmsis_nn_status arm_svdf_s8(const cmsis_nn_context *ctx,
51                                 const cmsis_nn_context *input_ctx,
52                                 const cmsis_nn_context *output_ctx,
53                                 const cmsis_nn_svdf_params *svdf_params,
54                                 const cmsis_nn_per_tensor_quant_params *input_quant_params,
55                                 const cmsis_nn_per_tensor_quant_params *output_quant_params,
56                                 const cmsis_nn_dims *input_dims,
57                                 const int8_t *input_data,
58                                 const cmsis_nn_dims *state_dims,
59                                 int8_t *state_data,
60                                 const cmsis_nn_dims *weights_feature_dims,
61                                 const int8_t *weights_feature_data,
62                                 const cmsis_nn_dims *weights_time_dims,
63                                 const int8_t *weights_time_data,
64                                 const cmsis_nn_dims *bias_dims,
65                                 const int32_t *bias_data,
66                                 const cmsis_nn_dims *output_dims,
67                                 int8_t *output_data)
68 {
69     (void)bias_dims;
70     (void)state_dims;
71     (void)output_dims;
72 
73 #if defined(ARM_MATH_MVEI)
74     if (ctx->buf == NULL)
75     {
76         return (ARM_CMSIS_NN_ARG_ERROR);
77     }
78 #endif
79 
80     const int32_t multiplier_in = input_quant_params->multiplier;
81     const int32_t shift_in = input_quant_params->shift;
82     const int32_t multiplier_out = output_quant_params->multiplier;
83     const int32_t shift_2 = output_quant_params->shift;
84     const int32_t zp_in = svdf_params->input_offset;
85     const int32_t zp_out = svdf_params->output_offset;
86     const int32_t in_activation_min = svdf_params->input_activation.min;
87     const int32_t in_activation_max = svdf_params->input_activation.max;
88     const int32_t out_activation_min = svdf_params->output_activation.min;
89     const int32_t out_activation_max = svdf_params->output_activation.max;
90     const int16_t rank = svdf_params->rank;
91 
92     const int32_t input_batches = input_dims->n;
93     const int32_t input_height = input_dims->h;
94     const int32_t feature_batches = weights_feature_dims->n;
95     const int32_t time_batches = weights_time_dims->h;
96     const int32_t unit_count = feature_batches / rank;
97 
98     if (input_ctx->buf == NULL)
99     {
100         return ARM_CMSIS_NN_ARG_ERROR;
101     }
102     int32_t *buffer_a = (int32_t *)input_ctx->buf;
103 
104     if (output_ctx->buf == NULL)
105     {
106         return ARM_CMSIS_NN_ARG_ERROR;
107     }
108     int32_t *buffer_b = (int32_t *)output_ctx->buf;
109 
110     int32_t *kernel_sum_data = (int32_t *)ctx->buf;
111 
112     // Left shift state
113     // Using memcpy on overlapping data is in general undefined behaviour, but since the behaviour of arm_memcpy_s8 is
114     // known it is certain that the data has been copied before it is overwritten in this case.
115 #ifdef ARM_MATH_MVEI
116     arm_memcpy_s8(state_data,
117                   state_data + 1,
118                   (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));
119 #else
120     memmove(state_data,
121             state_data + 1,
122             (size_t)((input_batches * feature_batches * time_batches - 1) * (int32_t)sizeof(int8_t)));
123 #endif
124 
125     // Matrix multiplication input * feature weight
126     for (int i_batch = 0; i_batch < input_batches; i_batch++)
127     {
128         int8_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
129         const int8_t *input = input_data + i_batch * input_height;
130 
131         arm_cmsis_nn_status res = arm_nn_vec_mat_mult_t_s8(input,
132                                                            weights_feature_data,
133                                                            kernel_sum_data,
134                                                            NULL,
135                                                            res_ptr,
136                                                            -zp_in,
137                                                            0,
138                                                            multiplier_in,
139                                                            shift_in,
140                                                            input_height,
141                                                            feature_batches,
142                                                            in_activation_min,
143                                                            in_activation_max,
144                                                            time_batches,
145                                                            0);
146 
147         if (res != ARM_CMSIS_NN_SUCCESS)
148         {
149             return res;
150         }
151     }
152 
153     // Matrix multiplicate time weight * state tensors
154     {
155         int32_t *ptr_a = buffer_a;
156         const int8_t *v2 = state_data;
157         for (int i_batch = 0; i_batch < input_batches; i_batch++)
158         {
159             const int8_t *v1 = weights_time_data;
160 
161             for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
162             {
163                 *ptr_a = 0;
164                 int32_t sum = 0;
165 #if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
166                 // Perform matrix multiplication in blocks of four
167                 int j = 0;
168                 int32_t block_count = time_batches >> 2;
169                 for (int i = 0; i < block_count; i++)
170                 {
171                     j += 4;
172 
173                     int32_t r1_1, r1_2, r2_1, r2_2;
174                     v1 = read_and_pad_reordered(v1, &r1_1, &r1_2);
175                     v2 = read_and_pad_reordered(v2, &r2_1, &r2_2);
176                     sum = SMLAD(r1_1, r2_1, sum);
177                     sum = SMLAD(r1_2, r2_2, sum);
178                 }
179 
180                 // Process the remaining data
181                 for (; j < time_batches; j++)
182                 {
183                     sum += *v1 * *v2;
184                     v1++;
185                     v2++;
186                 }
187 #else
188                 for (int j = 0; j < time_batches; j++)
189                 {
190                     sum += *v1 * *v2;
191                     v1++;
192                     v2++;
193                 }
194 #endif
195 
196                 *ptr_a = sum;
197                 ptr_a++;
198             }
199         }
200     }
201 
202     if (bias_data)
203     {
204         if (unit_count == feature_batches)
205         {
206             for (int i = 0; i < input_batches; i++)
207             {
208                 int32_t *output_temp = buffer_b + i * feature_batches;
209                 const int32_t *ptr_a = buffer_a + i * feature_batches;
210 
211                 const int32_t *bi = bias_data;
212                 for (int j = 0; j < feature_batches; j++)
213                 {
214                     output_temp[j] = ptr_a[j] + bi[j];
215                 }
216             }
217         }
218         else
219         {
220             for (int i_batch = 0; i_batch < input_batches; i_batch++)
221             {
222                 int32_t *output_data_temp = buffer_b + i_batch * unit_count;
223                 int32_t *ptr_a = buffer_a + i_batch * feature_batches;
224 
225                 for (int i = 0; i < unit_count; i++)
226                 {
227                     int32_t sum = bias_data[i];
228                     for (int j = 0; j < rank; j++)
229                     {
230                         sum += *ptr_a;
231                         ptr_a++;
232                     }
233                     output_data_temp[i] = sum;
234                 }
235             }
236         }
237     }
238     else
239     {
240         for (int i_batch = 0; i_batch < input_batches; i_batch++)
241         {
242             int32_t *output_data_temp = buffer_b + i_batch * unit_count;
243             int32_t *ptr_a = buffer_a + i_batch * feature_batches;
244 
245             for (int i = 0; i < unit_count; i++)
246             {
247                 int32_t sum = 0;
248                 for (int j = 0; j < rank; j++)
249                 {
250                     sum += *ptr_a;
251                     ptr_a++;
252                 }
253                 output_data_temp[i] = sum;
254             }
255         }
256     }
257 
258 #if defined(ARM_MATH_MVEI)
259     int32_t num_elements = input_batches * unit_count;
260     const int32_t loop_count = (num_elements + 3) / 4;
261     for (int i_op = 0; i_op < loop_count; i_op++)
262     {
263         mve_pred16_t p = vctp32q((uint32_t)num_elements);
264         int32x4_t op = vldrwq_z_s32(buffer_b, p);
265         op = arm_requantize_mve(op, multiplier_out, shift_2);
266         op = vaddq_n_s32(op, zp_out);
267         const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
268         const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
269         op = vmaxq_s32(op, min_vec);
270         op = vminq_s32(op, max_vec);
271         vstrbq_p_s32(output_data, op, p);
272         output_data += 4;
273         buffer_b += 4;
274         num_elements -= 4;
275     }
276 #else
277     for (int i = 0; i < input_batches * unit_count; i++)
278     {
279         output_data[i] = (int8_t)CLAMP(
280             arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
281     }
282 #endif
283 
284     return (ARM_CMSIS_NN_SUCCESS);
285 }
286 
287 /**
288  * @} end of SVDF group
289  */
290