# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from test_settings import TestSettings

import tensorflow as tf
import numpy as np
import tf_keras as keras


class LSTMSettings(TestSettings):
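    """Generate reference test data for a quantized LSTM layer.

    A Keras LSTM model is built, quantized through the TFLite converter
    (int8 activations with an int16 cell state), and its tensors, effective
    biases and scale decompositions are written out as C arrays and defines.
    """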

    def __init__(self,
                 dataset,
                 testtype,
                 regenerate_weights,
                 regenerate_input,
                 regenerate_biases,
                 schema_file,
                 batches=2,
                 time_steps=2,
                 number_inputs=3,
                 number_units=4,
                 time_major=True,
                 randmin=TestSettings.INT8_MIN,
                 randmax=TestSettings.INT8_MAX,
                 generate_bias=True,
                 interpreter="tensorflow"):
        super().__init__(dataset,
                         testtype,
                         regenerate_weights,
                         regenerate_input,
                         regenerate_biases,
                         schema_file,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         False,
                         randmin,
                         randmax,
                         generate_bias=generate_bias,
                         interpreter=interpreter)

        self.batches = batches
        self.time_steps = time_steps
        self.number_units = number_units
        self.number_inputs = number_inputs

        self.kernel_hidden_table_file = self.pregenerated_data_dir + self.testdataset + '/' + 'kernel_hidden.txt'

        self.time_major = time_major

        self.in_activation_max = TestSettings.INT16_MAX
        self.in_activation_min = TestSettings.INT16_MIN

        self.lstm_scales = []

        # Layer indexes. Works with TensorFlow 2.10 and 2.11.
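        # The indexes are positions in interpreter.get_tensor_details(). When
        # time_major is set, the extra transpose operator shifts every index
        # by one; generate_data() compensates with time_major_offset.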
        self.output_gate_bias_index = 1
        self.cell_gate_bias_index = 2
        self.forget_gate_bias_index = 3
        self.input_gate_bias_index = 4
        self.recurrent_input_to_output_w_index = 5
        self.recurrent_input_to_cell_w_index = 6
        self.recurrent_input_to_forget_w_index = 7
        self.recurrent_input_to_input_w_index = 8
        self.input_to_output_w_index = 9
        self.input_to_cell_w_index = 10
        self.input_to_forget_w_index = 11
        self.input_to_input_w_index = 12
        self.output_state_index = 13
        self.cell_state_index = 14
        self.input_norm_coeff_index = 15
        self.forget_norm_coeff_index = 16
        self.cell_norm_coeff_index = 17
        self.output_norm_coeff_index = 18
        self.effective_hidden_scale_intermediate_index = 20

    def generate_data(self, input_data=None, weights=None, hidden_weights=None, biases=None) -> None:
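        """Build the reference model, quantize it and dump all test vectors as C arrays."""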

        input_dims = [self.batches, self.time_steps, self.number_inputs]
        if input_data is not None:
            input_data = tf.reshape(input_data, input_dims)
        else:
            input_data = self.get_randomized_data(input_dims,
                                                  self.inputs_table_file,
                                                  regenerate=self.regenerate_new_input)

        # Without a projection layer, the number of cells equals the number of units.
        number_cells = self.number_units

        # An LSTM layer has 4 sets of input weights, 4 sets of recurrent (hidden state) weights
        # and 4 biases, one per gate.
        number_w_b = 4
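
        # tf_keras packs the per-gate weights along the last axis in gate order
        # (input, forget, cell, output), so the kernels are generated directly
        # in their packed [rows, 4 * number_cells] shape.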
        if weights is not None:
            weights = tf.reshape(weights, [self.number_inputs, number_cells * number_w_b])
        else:
            weights = self.get_randomized_data([self.number_inputs, number_cells * number_w_b],
                                               self.kernel_table_file,
                                               regenerate=self.regenerate_new_weights,
                                               decimals=8,
                                               minrange=-1.0,
                                               maxrange=1.0)

        if hidden_weights is not None:
            hidden_weights = tf.reshape(hidden_weights, [number_cells, number_cells * number_w_b])
        else:
            hidden_weights = self.get_randomized_data([number_cells, number_cells * number_w_b],
                                                      self.kernel_hidden_table_file,
                                                      regenerate=self.regenerate_new_weights,
                                                      decimals=8,
                                                      minrange=-1.0,
                                                      maxrange=1.0)
        if not self.generate_bias:
            biases = [0] * number_cells * number_w_b
        if biases is not None:
            biases = tf.reshape(biases, [number_cells * number_w_b])
        else:
            biases = self.get_randomized_data([number_cells * number_w_b],
                                              self.bias_table_file,
                                              regenerate=self.regenerate_new_bias,
                                              decimals=8,
                                              minrange=-1.0,
                                              maxrange=1.0)

        # Create a Keras-based LSTM model.
        input_layer = keras.layers.Input(shape=(self.time_steps, self.number_inputs),
                                         batch_size=self.batches,
                                         name='input')
        if self.time_major:
            input_layer_transposed = tf.transpose(input_layer, perm=[1, 0, 2])
            lstm_layer = keras.layers.LSTM(units=self.number_units,
                                           time_major=self.time_major,
                                           return_sequences=True)(input_layer_transposed)
        else:
            lstm_layer = keras.layers.LSTM(units=self.number_units,
                                           time_major=self.time_major,
                                           return_sequences=True)(input_layer)
        model = keras.Model(input_layer, lstm_layer, name="LSTM")

        if self.time_major:
            time_major_offset = 1
            shape = (self.time_steps, self.batches, self.number_inputs)
        else:
            time_major_offset = 0
            shape = (self.batches, self.time_steps, self.number_inputs)

        # Write the weights and biases to the model.
        print("Updating weights", model.layers[1 + time_major_offset].weights[0].name)
        model.layers[1 + time_major_offset].weights[0].assign(weights)
        print("Updating hidden weights", model.layers[1 + time_major_offset].weights[1].name)
        model.layers[1 + time_major_offset].weights[1].assign(hidden_weights)
        print("Updating bias", model.layers[1 + time_major_offset].weights[2].name)
        model.layers[1 + time_major_offset].weights[2].assign(biases)

        interpreter = self.convert_and_interpret(model, tf.int8, input_data, dataset_shape=shape)

        all_layers_details = interpreter.get_tensor_details()

        for i in all_layers_details:
            self.lstm_scales.append(i['quantization_parameters']['scales'])

        input_data_for_index = all_layers_details[0]

        input_gate_bias = all_layers_details[self.input_gate_bias_index + time_major_offset]
        forget_gate_bias = all_layers_details[self.forget_gate_bias_index + time_major_offset]
        cell_gate_bias = all_layers_details[self.cell_gate_bias_index + time_major_offset]
        output_gate_bias = all_layers_details[self.output_gate_bias_index + time_major_offset]

        input_to_input_w = all_layers_details[self.input_to_input_w_index + time_major_offset]
        input_to_forget_w = all_layers_details[self.input_to_forget_w_index + time_major_offset]
        input_to_cell_w = all_layers_details[self.input_to_cell_w_index + time_major_offset]
        input_to_output_w = all_layers_details[self.input_to_output_w_index + time_major_offset]

        recurrent_input_to_input_w = all_layers_details[self.recurrent_input_to_input_w_index + time_major_offset]
        recurrent_input_to_forget_w = all_layers_details[self.recurrent_input_to_forget_w_index + time_major_offset]
        recurrent_input_to_cell_w = all_layers_details[self.recurrent_input_to_cell_w_index + time_major_offset]
        recurrent_input_to_output_w = all_layers_details[self.recurrent_input_to_output_w_index + time_major_offset]

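        # In the time-major graph the state and norm tensors come after two
        # extra transpose-related tensors instead of one, so the offset grows
        # to 2 from here on (based on the observed tensor order in TF 2.10/2.11).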
        if self.time_major:
            time_major_offset = 2

        output_state = all_layers_details[self.output_state_index + time_major_offset]
        cell_state = all_layers_details[self.cell_state_index + time_major_offset]

        input_norm_coeff = all_layers_details[self.input_norm_coeff_index + time_major_offset]
        forget_norm_coeff = all_layers_details[self.forget_norm_coeff_index + time_major_offset]
        cell_norm_coeff = all_layers_details[self.cell_norm_coeff_index + time_major_offset]
        output_norm_coeff = all_layers_details[self.output_norm_coeff_index + time_major_offset]

        # Used below for its scale and zero point.
        effective_hidden_scale_intermediate = all_layers_details[
            self.effective_hidden_scale_intermediate_index + time_major_offset]

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        actual_input_data = interpreter.get_tensor(input_details[0]["index"])
        if (input_data.numpy().shape != actual_input_data.shape) or \
           not ((input_data.numpy().astype(int) == actual_input_data).all().astype(int)):
            raise RuntimeError("Input data mismatch")

        self.generate_c_array(self.input_data_file_prefix, interpreter.get_tensor(input_data_for_index['index']))
        self.generate_c_array("input_to_input_w", interpreter.get_tensor(input_to_input_w['index']))
        self.generate_c_array("input_to_forget_w", interpreter.get_tensor(input_to_forget_w['index']))
        self.generate_c_array("input_to_cell_w", interpreter.get_tensor(input_to_cell_w['index']))
        self.generate_c_array("input_to_output_w", interpreter.get_tensor(input_to_output_w['index']))
        self.generate_c_array("recurrent_input_to_input_w",
                              interpreter.get_tensor(recurrent_input_to_input_w['index']))
        self.generate_c_array("recurrent_input_to_forget_w",
                              interpreter.get_tensor(recurrent_input_to_forget_w['index']))
        self.generate_c_array("recurrent_input_to_cell_w", interpreter.get_tensor(recurrent_input_to_cell_w['index']))
        self.generate_c_array("recurrent_input_to_output_w",
                              interpreter.get_tensor(recurrent_input_to_output_w['index']))

        # Peephole connections are not supported, so these are nullptrs.
        self.generate_c_array("cell_to_input", [], datatype='int16_t')
        self.generate_c_array("cell_to_forget", [], datatype='int16_t')
        self.generate_c_array("cell_to_output", [], datatype='int16_t')

        self.generate_c_array("input_gate_bias", interpreter.get_tensor(input_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("cell_gate_bias", interpreter.get_tensor(cell_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("forget_gate_bias", interpreter.get_tensor(forget_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("output_gate_bias", interpreter.get_tensor(output_gate_bias['index']), datatype='int32_t')

        # Projection is not supported, so these are nullptrs.
        self.generate_c_array("projection_weights", [])
        self.generate_c_array("projection_bias", [], datatype='int32_t')

        self.generate_c_array("output_state", interpreter.get_tensor(output_state['index']), const="")
        self.generate_c_array("cell_state", interpreter.get_tensor(cell_state['index']), datatype='int16_t', const="")

        self.generate_c_array("input_norm_coeff", interpreter.get_tensor(input_norm_coeff['index']))
        self.generate_c_array("forget_norm_coeff", interpreter.get_tensor(forget_norm_coeff['index']))
        self.generate_c_array("cell_norm_coeff", interpreter.get_tensor(cell_norm_coeff['index']))
        self.generate_c_array("output_norm_coeff", interpreter.get_tensor(output_norm_coeff['index']))

        input_scale = input_data_for_index['quantization_parameters']['scales'][0]
        self.data_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
        cell_scale = cell_state['quantization_parameters']['scales'][0]
        output_state_scale = output_state['quantization_parameters']['scales'][0]
        input_zp = input_data_for_index['quantization_parameters']['zero_points'][0]
        output_zp = output_details[0]['quantization_parameters']['zero_points'][0]
        output_state_zp = output_state['quantization_parameters']['zero_points'][0]
        self.hidden_zp = effective_hidden_scale_intermediate['quantization_parameters']['zero_points'][0]
        self.output_state_offset = output_state_zp

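        # cell_scale is expected to be a power of two for the int16 cell state,
        # so the shift is simply log2(cell_scale); e.g. cell_scale = 2^-11
        # gives cell_state_shift = -11.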
        self.cell_state_shift = int(round(math.log2(cell_scale)))

        self.calc_scales(input_scale, output_state_scale, cell_scale)

        # Calculate effective biases.
        input_zp = -input_zp
        output_zp = -output_zp
        output_state_zp = -output_state_zp
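        # calc_effective_bias() folds the zero-point part of the quantized dot
        # product into the bias: (x - zp) . w = x . w - zp * sum(w), hence the
        # negated zero points above. The recurrent paths have no bias tensor,
        # so only the zero-point term is generated for them (has_bias=False).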
        input_to_forget_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_forget_w, forget_gate_bias)
        recurrent_to_forget_eff_bias = self.calc_effective_bias(interpreter, output_state_zp,
                                                                recurrent_input_to_forget_w, None, False)
        input_to_cell_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_cell_w, cell_gate_bias)
        recurrent_to_cell_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_cell_w,
                                                              None, False)
        input_to_output_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_output_w, output_gate_bias)
        recurrent_to_output_eff_bias = self.calc_effective_bias(interpreter, output_state_zp,
                                                                recurrent_input_to_output_w, None, False)
        input_to_input_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_input_w, input_gate_bias)
        recurrent_to_input_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_input_w,
                                                               None, False)

        self.generate_c_array("input_to_input_eff_bias", input_to_input_eff_bias, datatype='int32_t')
        self.generate_c_array("input_to_forget_eff_bias", input_to_forget_eff_bias, datatype='int32_t')
        self.generate_c_array("input_to_cell_eff_bias", input_to_cell_eff_bias, datatype='int32_t')
        self.generate_c_array("input_to_output_eff_bias", input_to_output_eff_bias, datatype='int32_t')
        self.generate_c_array("recurrent_to_input_eff_bias", recurrent_to_input_eff_bias, datatype='int32_t')
        self.generate_c_array("recurrent_to_cell_eff_bias", recurrent_to_cell_eff_bias, datatype='int32_t')
        self.generate_c_array("recurrent_to_forget_eff_bias", recurrent_to_forget_eff_bias, datatype='int32_t')
        self.generate_c_array("recurrent_to_output_eff_bias", recurrent_to_output_eff_bias, datatype='int32_t')

        # Generate the reference output.
        if self.use_tflite_micro_interpreter:
            interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite))
            interpreter.set_input(tf.cast(input_data, tf.int8), input_details[0]["index"])
            interpreter.invoke()
            output_data = interpreter.get_output(0)
        else:
            interpreter.invoke()
            output_data = interpreter.get_tensor(output_details[0]["index"])

        self.generate_c_array(self.output_data_file_prefix, output_data, datatype='int8_t')

        self.write_c_config_header()
        self.write_c_header_wrapper()

    def calc_scales(self, input_scale, output_state_scale, cell_scale):
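        # 2^-12 is the fixed scale of the intermediate gate results and 2^-15
        # that of the sigmoid/tanh outputs (assumed to match the TFLite
        # integer LSTM kernel).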
        intermediate_scale = pow(2, -12)

        if self.time_major:
            time_major_offset = 1
        else:
            time_major_offset = 0

        self.effective_forget_scale = pow(2, -15) / cell_scale * cell_scale
        self.effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15)
        self.effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15)
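
        # Each multiply-accumulate path gets the effective scale
        # scale_a * scale_b / scale_out: below, the input or output-state scale
        # is combined with the per-gate weight scale and requantized to the
        # intermediate scale.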
        self.i2i_effective_scale = input_scale * self.lstm_scales[self.input_to_input_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2f_effective_scale = input_scale * self.lstm_scales[self.input_to_forget_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2c_effective_scale = input_scale * self.lstm_scales[self.input_to_cell_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2o_effective_scale = input_scale * self.lstm_scales[self.input_to_output_w_index + time_major_offset][0] \
            / intermediate_scale

        self.r2i_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_input_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2f_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_forget_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2c_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_cell_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2o_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_output_w_index +
                                                                         time_major_offset][0] / intermediate_scale

    def calc_effective_bias(self, interpreter, zero_point, weight_tensor, bias_tensor, has_bias=True) -> np.ndarray:
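        """Return bias + zero_point * row_sum(weights) per output channel.

        When has_bias is False a zero bias is assumed, so only the zero-point
        contribution is returned.
        """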

        weights = interpreter.get_tensor(weight_tensor['index'])
        dims = weight_tensor['shape']
        row = dims[0]
        col = dims[1]

        if has_bias:
            bias_data = interpreter.get_tensor(bias_tensor['index'])
            output = bias_data
        else:
            output = np.zeros((row, ), dtype=np.int32)

        for i_row in range(row):
            row_sum = 0
            for i_col in range(col):
                row_sum = row_sum + weights[i_row][i_col]
            output[i_row] = output[i_row] + row_sum * zero_point
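        # An equivalent vectorized form would be
        # output = output + weights.sum(axis=1) * zero_point.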

        return output

    def write_c_config_header(self) -> None:
        super().write_c_config_header(write_common_parameters=False)

        filename = self.config_data
        filepath = self.headers_dir + filename
        prefix = self.testdataset.upper()

        with open(filepath, "a") as f:
            f.write("#define {}_BUFFER_SIZE {}\n".format(prefix, self.batches * self.number_units))
            f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
            f.write("#define {}_DST_SIZE {}\n".format(prefix, self.batches * self.time_steps * self.number_units))
            f.write("#define {}_TIME_STEPS {}\n".format(prefix, self.time_steps))
            f.write("#define {}_NUMBER_UNITS {}\n".format(prefix, self.number_units))
            f.write("#define {}_NUMBER_INPUTS {}\n".format(prefix, self.number_inputs))
            f.write("#define {}_TIME_MAJOR {}\n".format(prefix, int(self.time_major)))
            f.write("#define {}_IN_ACTIVATION_MIN {}\n".format(prefix, self.in_activation_min))
            f.write("#define {}_IN_ACTIVATION_MAX {}\n".format(prefix, self.in_activation_max))
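
            # quantize_scale() decomposes each float scale into a fixed-point
            # multiplier and a power-of-two shift such that
            # scale ~= multiplier * 2^shift, with the multiplier in Q31 format.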
            (multiplier, shift) = self.quantize_scale(self.i2i_effective_scale)
            f.write("#define {}_IN_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_IN_TO_INPUT_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.i2f_effective_scale)
            f.write("#define {}_IN_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_IN_TO_FORGET_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.i2c_effective_scale)
            f.write("#define {}_IN_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_IN_TO_CELL_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.i2o_effective_scale)
            f.write("#define {}_IN_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_IN_TO_OUTPUT_SHIFT {}\n".format(prefix, shift))

            (multiplier, shift) = self.quantize_scale(self.r2i_effective_scale)
            f.write("#define {}_RECURRENT_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_RECURRENT_TO_INPUT_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.r2f_effective_scale)
            f.write("#define {}_RECURRENT_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_RECURRENT_TO_FORGET_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.r2c_effective_scale)
            f.write("#define {}_RECURRENT_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_RECURRENT_TO_CELL_SHIFT {}\n".format(prefix, shift))
            (multiplier, shift) = self.quantize_scale(self.r2o_effective_scale)
            f.write("#define {}_RECURRENT_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_RECURRENT_TO_OUTPUT_SHIFT {}\n".format(prefix, shift))

            (multiplier, shift) = self.quantize_scale(self.effective_forget_scale)
            f.write("#define {}_FORGET_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_FORGET_SHIFT {}\n".format(prefix, shift))

            (multiplier, shift) = self.quantize_scale(self.effective_input_scale)
            f.write("#define {}_INPUT_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_INPUT_SHIFT {}\n".format(prefix, shift))

            (multiplier, shift) = self.quantize_scale(self.effective_hidden_scale)
            f.write("#define {}_HIDDEN_MULTIPLIER {}\n".format(prefix, multiplier))
            f.write("#define {}_HIDDEN_SHIFT {}\n".format(prefix, shift))

            f.write("#define {}_HIDDEN_OFFSET {}\n".format(prefix, self.hidden_zp))
            f.write("#define {}_DATA_OFFSET {}\n".format(prefix, -self.data_zp))

            f.write("#define {}_OUTPUT_STATE_OFFSET {}\n".format(prefix, self.output_state_offset))
            f.write("#define {}_CELL_STATE_SHIFT {}\n".format(prefix, self.cell_state_shift))

            for i in range(len(self.lstm_scales)):
                if len(self.lstm_scales[i]) == 0:
                    continue
                (multiplier, shift) = self.quantize_scale(self.lstm_scales[i][0])