1 /*
2  * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "unity.h"
20 #include <arm_nnfunctions.h>
21 
22 #include "../TestData/svdf_int8/test_data.h"
23 #include "../TestData/svdf_int8_2/test_data.h"
24 #include "../Utils/validate.h"
25 
26 #define REPEAT_NUM (1)
27 
svdf_int8_arm_svdf_s8(void)28 void svdf_int8_arm_svdf_s8(void)
29 {
30     const int32_t output_ref_size = SVDF_INT8_DST_SIZE;
31     const int8_t *output_ref = svdf_int8_output_ref;
32     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
33     cmsis_nn_context input_ctx;
34     cmsis_nn_context output_ctx;
35     cmsis_nn_svdf_params svdf_int8_params;
36     cmsis_nn_dims input_dims;
37     cmsis_nn_dims weights_feature_dims;
38     cmsis_nn_dims weights_time_dims;
39     cmsis_nn_dims state_dims;
40     cmsis_nn_dims output_dims;
41     cmsis_nn_dims bias_dims;
42     cmsis_nn_per_tensor_quant_params input_quant_params;
43     cmsis_nn_per_tensor_quant_params output_quant_params;
44     int8_t output_data[SVDF_INT8_DST_SIZE] = {1};
45     const int8_t *weights_feature_data = svdf_int8_weights_feature;
46     const int8_t *weights_time_data = svdf_int8_weights_time;
47 
48     input_dims.n = SVDF_INT8_INPUT_BATCHES;
49     input_dims.h = SVDF_INT8_INPUT_SIZE;
50     weights_feature_dims.n = SVDF_INT8_FEATURE_BATCHES;
51     weights_time_dims.h = SVDF_INT8_TIME_BATCHES;
52 
53     input_quant_params.multiplier = SVDF_INT8_MULTIPLIER_IN;
54     input_quant_params.shift = SVDF_INT8_SHIFT_1;
55     output_quant_params.multiplier = SVDF_INT8_MULTIPLIER_OUT;
56     output_quant_params.shift = SVDF_INT8_SHIFT_2;
57 
58     svdf_int8_params.input_activation.min = SVDF_INT8_IN_ACTIVATION_MIN;
59     svdf_int8_params.input_activation.max = SVDF_INT8_IN_ACTIVATION_MAX;
60     svdf_int8_params.output_activation.min = SVDF_INT8_OUT_ACTIVATION_MIN;
61     svdf_int8_params.output_activation.max = SVDF_INT8_OUT_ACTIVATION_MAX;
62     svdf_int8_params.input_offset = SVDF_INT8_INPUT_OFFSET;
63     svdf_int8_params.output_offset = SVDF_INT8_OUTPUT_OFFSET;
64     svdf_int8_params.rank = SVDF_INT8_RANK;
65 
66     const int input_round_size = SVDF_INT8_INPUT_BATCHES * SVDF_INT8_INPUT_SIZE;
67     const int number_inputs = sizeof(svdf_int8_input_sequence) / input_round_size;
68     const int32_t number_units = SVDF_INT8_FEATURE_BATCHES / SVDF_INT8_RANK;
69     const int scratch_size = SVDF_INT8_INPUT_BATCHES * SVDF_INT8_FEATURE_BATCHES * sizeof(int32_t);
70     const int scratch_size_out = SVDF_INT8_INPUT_BATCHES * number_units * sizeof(int32_t);
71 
72     cmsis_nn_context ctx;
73     const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims);
74     ctx.buf = malloc(buf_size);
75     ctx.size = buf_size;
76 
77 #if defined(ARM_MATH_MVEI)
78     int32_t *kernel_sum_buf = ctx.buf;
79     arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
80 #endif
81 
82     // + SVDF_INT8_TIME_BATCHES additional bytes to make sure it is not overwritten
83     const int state_data_size = sizeof(svdf_int8_state) + SVDF_INT8_TIME_BATCHES;
84     const int8_t initial_data = 66;
85 
86     input_ctx.buf = malloc(scratch_size);
87     output_ctx.buf = malloc(scratch_size_out);
88 
89     int8_t *input_data = malloc(input_round_size);
90     int8_t *state_data = malloc(state_data_size);
91 
92     memset(state_data, initial_data, state_data_size);
93 
94     for (int i = 0; i < REPEAT_NUM; i++)
95     {
96         memcpy(state_data, svdf_int8_state, sizeof(svdf_int8_state));
97         for (int j = 0; j < number_inputs; j++)
98         {
99             memcpy(input_data, svdf_int8_input_sequence + j * input_round_size, input_round_size);
100             arm_cmsis_nn_status result = arm_svdf_s8(&ctx,
101                                                      &input_ctx,
102                                                      &output_ctx,
103                                                      &svdf_int8_params,
104                                                      &input_quant_params,
105                                                      &output_quant_params,
106                                                      &input_dims,
107                                                      input_data,
108                                                      &state_dims,
109                                                      state_data,
110                                                      &weights_feature_dims,
111                                                      weights_feature_data,
112                                                      &weights_time_dims,
113                                                      weights_time_data,
114                                                      &bias_dims,
115                                                      svdf_int8_biases,
116                                                      &output_dims,
117                                                      output_data);
118             TEST_ASSERT_EQUAL(expected, result);
119         }
120 
121         TEST_ASSERT_TRUE(validate(output_data, output_ref, output_ref_size));
122     }
123 
124     if (ctx.buf)
125     {
126         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
127         memset(ctx.buf, 0, buf_size);
128         free(ctx.buf);
129     }
130 
131     // Make sure state data is not written outside boundary
132     for (int i = sizeof(svdf_int8_state); i < state_data_size; i++)
133     {
134         TEST_ASSERT_EQUAL(state_data[i], initial_data);
135     }
136 
137     free(state_data);
138     free(input_data);
139     free(input_ctx.buf);
140     free(output_ctx.buf);
141 }
142 
svdf_int8_2_arm_svdf_s8(void)143 void svdf_int8_2_arm_svdf_s8(void)
144 {
145     const int32_t output_ref_size = SVDF_INT8_2_DST_SIZE;
146     const int8_t *output_ref = svdf_int8_2_output_ref;
147     const arm_cmsis_nn_status expected = ARM_CMSIS_NN_SUCCESS;
148     cmsis_nn_context input_ctx;
149     cmsis_nn_context output_ctx;
150     cmsis_nn_svdf_params svdf_int8_2_params;
151     cmsis_nn_dims input_dims;
152     cmsis_nn_dims weights_feature_dims;
153     cmsis_nn_dims weights_time_dims;
154     cmsis_nn_dims state_dims;
155     cmsis_nn_dims output_dims;
156     cmsis_nn_dims bias_dims;
157     cmsis_nn_per_tensor_quant_params input_quant_params;
158     cmsis_nn_per_tensor_quant_params output_quant_params;
159     int8_t output_data[SVDF_INT8_2_DST_SIZE] = {1};
160     const int8_t *weights_feature_data = svdf_int8_2_weights_feature;
161     const int8_t *weights_time_data = svdf_int8_2_weights_time;
162 
163     input_dims.n = SVDF_INT8_2_INPUT_BATCHES;
164     input_dims.h = SVDF_INT8_2_INPUT_SIZE;
165     weights_feature_dims.n = SVDF_INT8_2_FEATURE_BATCHES;
166     weights_time_dims.h = SVDF_INT8_2_TIME_BATCHES;
167 
168     input_quant_params.multiplier = SVDF_INT8_2_MULTIPLIER_IN;
169     input_quant_params.shift = SVDF_INT8_2_SHIFT_1;
170     output_quant_params.multiplier = SVDF_INT8_2_MULTIPLIER_OUT;
171     output_quant_params.shift = SVDF_INT8_2_SHIFT_2;
172 
173     svdf_int8_2_params.input_activation.min = SVDF_INT8_2_IN_ACTIVATION_MIN;
174     svdf_int8_2_params.input_activation.max = SVDF_INT8_2_IN_ACTIVATION_MAX;
175     svdf_int8_2_params.output_activation.min = SVDF_INT8_2_OUT_ACTIVATION_MIN;
176     svdf_int8_2_params.output_activation.max = SVDF_INT8_2_OUT_ACTIVATION_MAX;
177     svdf_int8_2_params.input_offset = SVDF_INT8_2_INPUT_OFFSET;
178     svdf_int8_2_params.output_offset = SVDF_INT8_2_OUTPUT_OFFSET;
179     svdf_int8_2_params.rank = SVDF_INT8_2_RANK;
180 
181     const int input_round_size = SVDF_INT8_2_INPUT_BATCHES * SVDF_INT8_2_INPUT_SIZE;
182     const int number_inputs = sizeof(svdf_int8_2_input_sequence) / input_round_size;
183     const int32_t number_units = SVDF_INT8_2_FEATURE_BATCHES / SVDF_INT8_2_RANK;
184     const int scratch_size = SVDF_INT8_2_INPUT_BATCHES * SVDF_INT8_2_FEATURE_BATCHES * sizeof(int32_t);
185     const int scratch_size_out = SVDF_INT8_2_INPUT_BATCHES * number_units * sizeof(int32_t);
186 
187     cmsis_nn_context ctx;
188     const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims);
189     ctx.buf = malloc(buf_size);
190     ctx.size = buf_size;
191 
192 #if defined(ARM_MATH_MVEI)
193     int32_t *kernel_sum_buf = ctx.buf;
194     arm_vector_sum_s8(kernel_sum_buf, input_dims.h, weights_feature_dims.n, weights_feature_data, 1, NULL);
195 #endif
196 
197     const int state_data_size = sizeof(svdf_int8_2_state);
198 
199     input_ctx.buf = malloc(scratch_size);
200     output_ctx.buf = malloc(scratch_size_out);
201 
202     int8_t *input_data = malloc(input_round_size);
203     int8_t *state_data = malloc(state_data_size);
204 
205     for (int i = 0; i < REPEAT_NUM; i++)
206     {
207         memcpy(state_data, svdf_int8_2_state, sizeof(svdf_int8_2_state));
208         for (int j = 0; j < number_inputs; j++)
209         {
210             memcpy(input_data, svdf_int8_2_input_sequence + j * input_round_size, input_round_size);
211             arm_cmsis_nn_status result = arm_svdf_s8(&ctx,
212                                                      &input_ctx,
213                                                      &output_ctx,
214                                                      &svdf_int8_2_params,
215                                                      &input_quant_params,
216                                                      &output_quant_params,
217                                                      &input_dims,
218                                                      input_data,
219                                                      &state_dims,
220                                                      state_data,
221                                                      &weights_feature_dims,
222                                                      weights_feature_data,
223                                                      &weights_time_dims,
224                                                      weights_time_data,
225                                                      &bias_dims,
226                                                      svdf_int8_2_biases,
227                                                      &output_dims,
228                                                      output_data);
229             TEST_ASSERT_EQUAL(expected, result);
230         }
231 
232         TEST_ASSERT_TRUE(validate(output_data, output_ref, output_ref_size));
233     }
234 
235     if (ctx.buf)
236     {
237         // The caller is responsible to clear the scratch buffers for security reasons if applicable.
238         memset(ctx.buf, 0, buf_size);
239         free(ctx.buf);
240     }
241 
242     free(state_data);
243     free(input_data);
244     free(input_ctx.buf);
245     free(output_ctx.buf);
246 }
247