1# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17from test_settings import TestSettings
18
19import tensorflow as tf
20
21
class SVDFSettings(TestSettings):
    """Test-data generator for SVDF operator unit tests.

    Builds a single-operator SVDF tflite model from a JSON template,
    runs it through a TFLite (or TFLite Micro) interpreter, and emits
    the input/weight/state/output tensors plus quantization parameters
    as C header files for the unit tests.
    """

    def __init__(self,
                 dataset,
                 testtype,
                 regenerate_weights,
                 regenerate_input,
                 regenerate_biases,
                 schema_file,
                 batches=2,
                 number_inputs=2,
                 rank=8,
                 memory_size=10,
                 randmin=TestSettings.INT8_MIN,
                 randmax=TestSettings.INT8_MAX,
                 input_size=3,
                 number_units=4,
                 generate_bias=True,
                 int8_time_weights=False,
                 input_scale=0.1,
                 input_zp=0,
                 w_1_scale=0.005,
                 w_1_zp=0,
                 w_2_scale=0.005,
                 w_2_zp=0,
                 bias_scale=0.00002,
                 bias_zp=0,
                 state_scale=0.005,
                 state_zp=0,
                 output_scale=0.1,
                 output_zp=0,
                 interpreter="tensorflow"):
        # The literal 1/False arguments fill the base class's common conv-style
        # dimensions, which SVDF does not use (common parameters are also
        # skipped in write_c_config_header below).
        super().__init__(dataset,
                         testtype,
                         regenerate_weights,
                         regenerate_input,
                         regenerate_biases,
                         schema_file,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         False,
                         randmin,
                         randmax,
                         generate_bias=generate_bias,
                         interpreter=interpreter)
        self.batches = batches
        self.number_units = number_units
        self.input_size = input_size
        self.memory_size = memory_size
        self.rank = rank
        # Feature filters = units * rank, per the SVDF decomposition.
        self.number_filters = self.number_units * self.rank
        self.time_table_file = self.pregenerated_data_dir + self.testdataset + '/' + 'time_data.txt'

        self.number_inputs = number_inputs
        # Total flattened length of the input sequence fed over all invokes.
        self.input_sequence_length = self.number_inputs * self.input_size * self.batches

        # True: int8 time weights / int8 state. False: int16 time weights / int16 state.
        self.int8_time_weights = int8_time_weights

        if self.int8_time_weights:
            self.json_template = "TestCases/Common/svdf_s8_weights_template.json"
            self.in_activation_max = TestSettings.INT8_MAX
            self.in_activation_min = TestSettings.INT8_MIN

        else:
            self.json_template = "TestCases/Common/svdf_template.json"
            self.in_activation_max = TestSettings.INT16_MAX
            self.in_activation_min = TestSettings.INT16_MIN

        # Placeholder values substituted into the JSON model template.
        self.json_replacements = {
            "memory_sizeXnumber_filters": self.memory_size * self.number_filters,
            "batches": self.batches,
            "input_size": self.input_size,
            "number_filters": self.number_filters,
            "memory_size": self.memory_size,
            "number_units": self.number_units,
            "rank_value": self.rank,
            "input_scale": input_scale,
            "input_zp": input_zp,
            "w_1_scale": w_1_scale,
            "w_1_zp": w_1_zp,
            "w_2_scale": w_2_scale,
            "w_2_zp": w_2_zp,
            "bias_scale": bias_scale,
            "bias_zp": bias_zp,
            "state_scale": state_scale,
            "state_zp": state_zp,
            "output_scale": output_scale,
            "output_zp": output_zp
        }

    def calc_multipliers_and_shifts(self, input_scale, weights_1_scale, weights_2_scale, state_scale, output_scale):
        """Derive the fixed-point (multiplier, shift) pairs for both SVDF stages.

        Stage 1 maps input*feature-weights into the state tensor; stage 2 maps
        state*time-weights into the output. Results are stored on self for
        write_c_config_header.
        """
        effective_scale_1 = weights_1_scale * input_scale / state_scale
        effective_scale_2 = state_scale * weights_2_scale / output_scale
        (self.multiplier_in, self.shift_1) = self.quantize_scale(effective_scale_1)
        (self.multiplier_out, self.shift_2) = self.quantize_scale(effective_scale_2)

    def write_c_config_header(self) -> None:
        """Append the SVDF-specific #defines to the generated config header."""
        # SVDF does not use the base class's common (conv-style) parameters.
        super().write_c_config_header(write_common_parameters=False)

        filename = self.config_data
        filepath = self.headers_dir + filename
        prefix = self.testdataset.upper()

        with open(filepath, "a") as f:
            f.write("#define {}_MULTIPLIER_IN {}\n".format(prefix, self.multiplier_in))
            f.write("#define {}_MULTIPLIER_OUT {}\n".format(prefix, self.multiplier_out))
            f.write("#define {}_SHIFT_1 {}\n".format(prefix, self.shift_1))
            f.write("#define {}_SHIFT_2 {}\n".format(prefix, self.shift_2))
            f.write("#define {}_IN_ACTIVATION_MIN {}\n".format(prefix, self.in_activation_min))
            f.write("#define {}_IN_ACTIVATION_MAX {}\n".format(prefix, self.in_activation_max))
            f.write("#define {}_RANK {}\n".format(prefix, self.rank))
            f.write("#define {}_FEATURE_BATCHES {}\n".format(prefix, self.number_filters))
            f.write("#define {}_TIME_BATCHES {}\n".format(prefix, self.memory_size))
            f.write("#define {}_INPUT_SIZE {}\n".format(prefix, self.input_size))
            f.write("#define {}_DST_SIZE {}\n".format(prefix, self.number_units * self.batches))
            f.write("#define {}_OUT_ACTIVATION_MIN {}\n".format(prefix, self.out_activation_min))
            f.write("#define {}_OUT_ACTIVATION_MAX {}\n".format(prefix, self.out_activation_max))
            f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
            f.write("#define {}_INPUT_OFFSET {}\n".format(prefix, self.input_zero_point))
            f.write("#define {}_OUTPUT_OFFSET {}\n".format(prefix, self.output_zero_point))

    def generate_data(self, input_data=None, weights=None, biases=None, time_data=None, state_data=None) -> None:
        """Generate the tflite model, run it, and write all C test headers.

        Any of input_data/weights/biases/time_data may be provided explicitly;
        otherwise randomized data is (re)generated. state_data is currently
        unused (the state tensor is read back from the interpreter).
        """
        if self.int8_time_weights:
            if not self.use_tflite_micro_interpreter:
                print("Warning: interpreter tflite_micro must be used for SVDF int8. Skipping generating headers.")
                return

        # TODO: Make this compatible with newer versions than 2.10.
        # Compare (major, minor) as an integer tuple. The previous float
        # comparison (float("2.9") > 2.10) mis-ordered versions: 2.9 > 2.1
        # numerically, so supported TF 2.2-2.9 were incorrectly skipped,
        # while "2.10" parsed as 2.1.
        tf_version = tuple(int(part) for part in tf.__version__.split('.')[:2])
        if tf_version > (2, 10):
            print("Warning: tensorflow version > 2.10 not supported for SVDF unit tests. Skipping generating headers")
            return

        if input_data is not None:
            input_data = tf.reshape(input_data, [self.input_sequence_length])
        else:
            input_data = self.get_randomized_data([self.input_sequence_length],
                                                  self.inputs_table_file,
                                                  regenerate=self.regenerate_new_input)
        self.generate_c_array("input_sequence", input_data)

        if weights is not None:
            weights_feature_data = tf.reshape(weights, [self.number_filters, self.input_size])
        else:
            weights_feature_data = self.get_randomized_data([self.number_filters, self.input_size],
                                                            self.kernel_table_file,
                                                            regenerate=self.regenerate_new_weights)

        if time_data is not None:
            weights_time_data = tf.reshape(time_data, [self.number_filters, self.memory_size])
        else:
            weights_time_data = self.get_randomized_data([self.number_filters, self.memory_size],
                                                         self.time_table_file,
                                                         regenerate=self.regenerate_new_weights)

        # When bias generation is disabled a zero bias is still emitted so the
        # unit test always has a bias array to link against.
        if not self.generate_bias:
            biases = [0] * self.number_units
        if biases is not None:
            biases = tf.reshape(biases, [self.number_units])
        else:
            biases = self.get_randomized_data([self.number_units],
                                              self.bias_table_file,
                                              regenerate=self.regenerate_new_weights)

        # Generate tflite model
        generated_json = self.generate_json_from_template(weights_feature_data,
                                                          weights_time_data,
                                                          biases,
                                                          self.int8_time_weights)
        self.flatc_generate_tflite(generated_json, self.schema_file)

        # Run TFL interpreter
        interpreter = self.Interpreter(model_path=str(self.model_path_tflite),
                                       experimental_op_resolver_type=self.OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Read back scales and zero points from tflite model.
        # NOTE(review): relies on a fixed tensor ordering in the generated
        # model (input, w1, w2, bias, state, output) — matches the JSON
        # templates used above.
        all_layers_details = interpreter.get_tensor_details()
        input_layer = all_layers_details[0]
        weights_1_layer = all_layers_details[1]
        weights_2_layer = all_layers_details[2]
        bias_layer = all_layers_details[3]
        state_layer = all_layers_details[4]
        output_layer = all_layers_details[5]
        (input_scale, self.input_zero_point) = self.get_scale_and_zp(input_layer)
        (weights_1_scale, zero_point) = self.get_scale_and_zp(weights_1_layer)
        (weights_2_scale, zero_point) = self.get_scale_and_zp(weights_2_layer)
        (bias_scale, zero_point) = self.get_scale_and_zp(bias_layer)
        (state_scale, zero_point) = self.get_scale_and_zp(state_layer)
        (output_scale, self.output_zero_point) = self.get_scale_and_zp(output_layer)

        self.calc_multipliers_and_shifts(input_scale, weights_1_scale, weights_2_scale, state_scale, output_scale)

        # Generate unit test C headers
        self.generate_c_array("weights_feature", interpreter.get_tensor(weights_1_layer['index']))
        self.generate_c_array(self.bias_data_file_prefix, interpreter.get_tensor(bias_layer['index']), "int32_t")

        # Time weights and state use int8 or int16 depending on the variant.
        if self.int8_time_weights:
            self.generate_c_array("weights_time", interpreter.get_tensor(weights_2_layer['index']), datatype='int8_t')
            self.generate_c_array("state", interpreter.get_tensor(state_layer['index']), "int8_t")
        else:
            self.generate_c_array("weights_time", interpreter.get_tensor(weights_2_layer['index']), datatype='int16_t')
            self.generate_c_array("state", interpreter.get_tensor(state_layer['index']), "int16_t")

        if self.use_tflite_micro_interpreter:
            interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite))

        # Generate reference output. The SVDF state persists across invokes,
        # so each slice of the input sequence is fed in order and only the
        # final invoke's output is written as the reference.
        svdf_ref = None
        for i in range(self.number_inputs):
            start = i * self.input_size * self.batches
            end = i * self.input_size * self.batches + self.input_size * self.batches
            input_sequence = input_data[start:end]
            input_sequence = tf.reshape(input_sequence, [self.batches, self.input_size])
            if self.use_tflite_micro_interpreter:
                interpreter.set_input(tf.cast(input_sequence, tf.int8), input_layer["index"])
            else:
                interpreter.set_tensor(input_layer["index"], tf.cast(input_sequence, tf.int8))
            interpreter.invoke()
            if self.use_tflite_micro_interpreter:
                svdf_ref = interpreter.get_output(0)
            else:
                svdf_ref = interpreter.get_tensor(output_layer["index"])
        self.generate_c_array(self.output_data_file_prefix, svdf_ref)

        self.write_c_config_header()
        self.write_c_header_wrapper()

    def get_scale_and_zp(self, layer):
        """Return (scale, zero_point) from a tensor-details dict, taking the
        first entry of each per-tensor quantization parameter array."""
        return (layer['quantization_parameters']['scales'][0], layer['quantization_parameters']['zero_points'][0])
257