/*
 * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_svdf_s8.c
 * Description:  S8 basic SVDF layer function
 *
 * $Date:        15. April 2021
 * $Revision:    V.1.5.0
 *
 * Target Processor:  Cortex-M processors
 *
 * -------------------------------------------------------------------- */

#include "arm_nnfunctions.h"
#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupNN
 */

/**
 * @addtogroup SVDF
 * @{
 */

/*
 * S8 SVDF layer function for TensorFlow Lite
 *
 * Refer to header file for details.
 *
 */
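/*
 * A minimal calling sketch, for orientation only. Every shape, quantization
 * value and scratch-buffer size below is an illustrative placeholder, not a
 * value mandated by this file; the header file remains the authoritative
 * description of the argument contract.
 *
 *   // 1 batch, input size 3, 8 feature filters, rank 2, 10 time steps
 *   static q31_t scratch_a[1 * 8];  // input_ctx scratch: batches * feature_batches
 *   static q31_t scratch_b[1 * 4];  // output_ctx scratch: batches * (feature_batches / rank)
 *   static q15_t state[1 * 8 * 10]; // batches * feature_batches * time_batches
 *
 *   const cmsis_nn_context input_ctx = {.buf = scratch_a, .size = sizeof(scratch_a)};
 *   const cmsis_nn_context output_ctx = {.buf = scratch_b, .size = sizeof(scratch_b)};
 *   const cmsis_nn_svdf_params svdf_params = {.rank = 2,
 *                                             .input_offset = 0,
 *                                             .output_offset = 0,
 *                                             .input_activation = {.min = -32768, .max = 32767},
 *                                             .output_activation = {.min = -128, .max = 127}};
 *   const cmsis_nn_per_tensor_quant_params in_quant = {.multiplier = 1073741824, .shift = -3};
 *   const cmsis_nn_per_tensor_quant_params out_quant = {.multiplier = 1073741824, .shift = -3};
 *   const cmsis_nn_dims input_dims = {.n = 1, .h = 3};
 *   const cmsis_nn_dims weights_feature_dims = {.n = 8};
 *   const cmsis_nn_dims weights_time_dims = {.h = 10};
 *   // state_dims, bias_dims and output_dims are accepted but not read by this
 *   // implementation; input, weight, bias and output buffers are application-provided.
 *
 *   arm_status status = arm_svdf_s8(&input_ctx, &output_ctx, &svdf_params,
 *                                   &in_quant, &out_quant,
 *                                   &input_dims, input_data,
 *                                   &state_dims, state,
 *                                   &weights_feature_dims, weights_feature,
 *                                   &weights_time_dims, weights_time,
 *                                   &bias_dims, bias,
 *                                   &output_dims, output);
 */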

arm_status arm_svdf_s8(const cmsis_nn_context *input_ctx,
                       const cmsis_nn_context *output_ctx,
                       const cmsis_nn_svdf_params *svdf_params,
                       const cmsis_nn_per_tensor_quant_params *input_quant_params,
                       const cmsis_nn_per_tensor_quant_params *output_quant_params,
                       const cmsis_nn_dims *input_dims,
                       const q7_t *input_data,
                       const cmsis_nn_dims *state_dims,
                       q15_t *state_data,
                       const cmsis_nn_dims *weights_feature_dims,
                       const q7_t *weights_feature_data,
                       const cmsis_nn_dims *weights_time_dims,
                       const q15_t *weights_time_data,
                       const cmsis_nn_dims *bias_dims,
                       const q31_t *bias_data,
                       const cmsis_nn_dims *output_dims,
                       q7_t *output_data)
{
    (void)bias_dims;
    (void)state_dims;
    (void)output_dims;

    const q31_t multiplier_in = input_quant_params->multiplier;
    const q31_t shift_in = input_quant_params->shift;
    const q31_t multiplier_out = output_quant_params->multiplier;
    const q31_t shift_2 = output_quant_params->shift;
    const int32_t zp_in = svdf_params->input_offset;
    const int32_t zp_out = svdf_params->output_offset;
    const int32_t in_activation_min = svdf_params->input_activation.min;
    const int32_t in_activation_max = svdf_params->input_activation.max;
    const int32_t out_activation_min = svdf_params->output_activation.min;
    const int32_t out_activation_max = svdf_params->output_activation.max;
    const int16_t rank = svdf_params->rank;

    const int32_t input_batches = input_dims->n;
    const int32_t input_height = input_dims->h;
    const int32_t feature_batches = weights_feature_dims->n;
    const int32_t time_batches = weights_time_dims->h;
    const int32_t unit_count = feature_batches / rank;

    q31_t *buffer_a = (q31_t *)input_ctx->buf;
    q31_t *buffer_b = (q31_t *)output_ctx->buf;

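    /* Shift the whole state buffer one q15 element to the left: the oldest
     * sample of each time history is discarded and slot (time_batches - 1) of
     * every feature row is freed up for the newest activation computed below. */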
    memmove((q15_t *)state_data,
            (q15_t *)state_data + 1,
            (size_t)(input_batches * feature_batches * time_batches * (int32_t)sizeof(int16_t)));

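    /* For each batch, multiply the new input vector with the feature weights,
     * requantize the products and clamp them to the input activation range,
     * then write the results into the last time slot of every feature row. */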
    for (int i_batch = 0; i_batch < input_batches; i_batch++)
    {
        q15_t *res_ptr = state_data + (time_batches * i_batch * feature_batches) + (time_batches - 1);
        const q7_t *weight = weights_feature_data;
        const q7_t *input = input_data + i_batch * input_height;

        arm_status res = arm_nn_vec_mat_mult_t_svdf_s8(input,
                                                       weight,
                                                       res_ptr,
                                                       -zp_in,
                                                       0,
                                                       time_batches,
                                                       multiplier_in,
                                                       shift_in,
                                                       input_height,
                                                       feature_batches,
                                                       in_activation_min,
                                                       in_activation_max);

        if (res != ARM_MATH_SUCCESS)
        {
            return res;
        }
    }

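    /* Apply the time weights: accumulate the dot product of each feature's
     * q15 state row with weights_time_data into the q31 scratch buffer_a. */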
    {
        q31_t *ptr_a = buffer_a;
        const q15_t *v2 = state_data;
        for (int i_batch = 0; i_batch < input_batches; i_batch++)
        {
            const q15_t *v1 = weights_time_data;

            for (int i_feature_batch = 0; i_feature_batch < feature_batches; i_feature_batch++)
            {
                *ptr_a = 0;
                int32_t sum = 0;
#if defined(ARM_MATH_DSP) && !defined(ARM_MATH_MVEI)
                int j = 0;
                int32_t block_count = time_batches >> 1;
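                /* __SMLAD accumulates two q15 x q15 products per iteration;
                 * an odd trailing element is handled by the tail loop below. */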
                for (int i = 0; i < block_count; i++)
                {
                    j += 2;
                    q31_t r1 = arm_nn_read_q15x2_ia(&v1);
                    q31_t r2 = arm_nn_read_q15x2_ia(&v2);

                    sum = __SMLAD(r1, r2, sum);
                }

                // Process the remaining data
                for (; j < time_batches; j++)
                {
                    sum += *v1 * *v2;
                    v1++;
                    v2++;
                }
#else
                for (int j = 0; j < time_batches; j++)
                {
                    sum += *v1 * *v2;
                    v1++;
                    v2++;
                }
#endif

                *ptr_a = sum;
                ptr_a++;
            }
        }
    }

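    /* Reduce buffer_a over the rank dimension into buffer_b, adding the
     * per-unit bias when one is supplied. With rank 1, unit_count equals
     * feature_batches and the bias is simply added element-wise. */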
    if (bias_data)
    {
        if (unit_count == feature_batches)
        {
            for (int i = 0; i < input_batches; i++)
            {
                q31_t *output_temp = buffer_b + i * feature_batches;
                const q31_t *ptr_a = buffer_a + i * feature_batches;

                const int32_t *bi = bias_data;
                for (int j = 0; j < feature_batches; j++)
                {
                    output_temp[j] = ptr_a[j] + bi[j];
                }
            }
        }
        else
        {
            for (int i_batch = 0; i_batch < input_batches; i_batch++)
            {
                q31_t *output_data_temp = buffer_b + i_batch * unit_count;
                q31_t *ptr_a = buffer_a + i_batch * feature_batches;

                for (int i = 0; i < unit_count; i++)
                {
                    int32_t sum = bias_data[i];
                    for (int j = 0; j < rank; j++)
                    {
                        sum += *ptr_a;
                        ptr_a++;
                    }
                    output_data_temp[i] = sum;
                }
            }
        }
    }
    else
    {
        for (int i_batch = 0; i_batch < input_batches; i_batch++)
        {
            q31_t *output_data_temp = buffer_b + i_batch * unit_count;
            q31_t *ptr_a = buffer_a + i_batch * feature_batches;

            for (int i = 0; i < unit_count; i++)
            {
                int32_t sum = 0;
                for (int j = 0; j < rank; j++)
                {
                    sum += *ptr_a;
                    ptr_a++;
                }
                output_data_temp[i] = sum;
            }
        }
    }

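    /* Requantize the accumulated results to int8, add the output offset and
     * clamp to the output activation range. The MVE path handles four values
     * per iteration, with a tail predicate masking the final partial vector. */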
#if defined(ARM_MATH_MVEI)
    int32_t num_elements = input_batches * unit_count;
    const int32_t loop_count = (num_elements + 3) / 4;
    for (int i_op = 0; i_op < loop_count; i_op++)
    {
        mve_pred16_t p = vctp32q((uint32_t)num_elements);
        int32x4_t op = vldrwq_z_s32(buffer_b, p);
        op = arm_requantize_mve(op, multiplier_out, shift_2);
        op = vaddq_n_s32(op, zp_out);
        const int32x4_t min_vec = vdupq_n_s32((int8_t)out_activation_min);
        const int32x4_t max_vec = vdupq_n_s32((int8_t)out_activation_max);
        op = vmaxq_s32(op, min_vec);
        op = vminq_s32(op, max_vec);
        vstrbq_p_s32(output_data, op, p);
        output_data += 4;
        buffer_b += 4;
        num_elements -= 4;
    }
#else
    for (int i = 0; i < input_batches * unit_count; i++)
    {
        output_data[i] = (q7_t)CLAMP(
            arm_nn_requantize(buffer_b[i], multiplier_out, shift_2) + zp_out, out_activation_max, out_activation_min);
    }
#endif

    return (ARM_MATH_SUCCESS);
}

/**
 * @} end of SVDF group
 */