1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import tensorflow as tf
19import math
20import numpy as np
21
22from tensorflow.lite.python.interpreter import Interpreter
23from tensorflow.lite.python.interpreter import OpResolverType
24import tf_keras as keras
25
26class Op_lstm(Lib.op_utils.Op_type):
27
28    def get_shapes(params):
29        shapes = {}
30        if params["time_major"] and params["tflite_generator"] == "json":
31            shapes["input"] = (params["time_steps"], params["batch_size"], params["input_size"])
32        else:
33            shapes["input"] = (params["batch_size"], params["time_steps"], params["input_size"])
34
35        shapes["input_weights"] = (params["input_size"], params["hidden_size"])
36        shapes["all_input_weights"] = (params["input_size"], params["hidden_size"] * 4)
37
38        shapes["hidden_weights"] = (params["hidden_size"], params["hidden_size"])
39        shapes["all_hidden_weights"] = (params["hidden_size"], params["hidden_size"] * 4)
40
41        shapes["bias"] = (1, params["hidden_size"])
42        shapes["all_bias"] = (params["hidden_size"] * 4, )
43
44        shapes["representational_dataset"] = (params["batch_size"], params["time_steps"], params["input_size"])
45        return shapes
46
47    def generate_keras_model(shapes, params):
48        input_layer = keras.layers.Input(shape=(params["time_steps"], params["input_size"]),
49                                            batch_size=params["batch_size"],
50                                            name='input')
51
52        # NOTE: use_bias = False results in an unrolled lstm operator, so it is not really supported in TFLM
53        if params["time_major"]:
54            time_major_offset = 1
55            input_layer_transposed = tf.transpose(input_layer, perm=[1, 0, 2])
56            lstm_layer = keras.layers.LSTM(units=params["hidden_size"],
57                                              time_major=params["time_major"],
58                                              return_sequences=True)(input_layer_transposed)
59        else:
60            time_major_offset = 0
61            lstm_layer = keras.layers.LSTM(units=params["hidden_size"],
62                                              time_major=params["time_major"],
63                                              return_sequences=True)(input_layer)
64
65        model = keras.Model(input_layer, lstm_layer, name="LSTM")
66
67        input_weights = Lib.op_utils.generate_tf_tensor(shapes["all_input_weights"], -1, 1, decimals=8)
68        model.layers[1 + time_major_offset].weights[0].assign(input_weights)
69
70        hidden_weights = Lib.op_utils.generate_tf_tensor(shapes["all_hidden_weights"], -1, 1, decimals=8)
71        model.layers[1 + time_major_offset].weights[1].assign(hidden_weights)
72
73        biases = Lib.op_utils.generate_tf_tensor(shapes["all_bias"], -1, 1, decimals=8) * 0
74        model.layers[1 + time_major_offset].weights[2].assign(biases)
75
76        return model
77
78    def generate_data_tflite(tflite_fname, params):
79        tensors = {}
80        effective_scales = {}
81        scales = {}
82        generated_params = {}
83
84        interpreter = Interpreter(str(tflite_fname), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
85        interpreter.allocate_tensors()
86        tensor_details = interpreter.get_tensor_details()
87
88        if params["time_major"]:
89            time_major_offset = 1
90        else:
91            time_major_offset = 0
92
93        input_state = tensor_details[0]
94        scales["input_scale"] = input_state['quantization_parameters']['scales'][0]
95        cell_state = tensor_details[14 + time_major_offset * 2]
96        scales["cell_scale"] = cell_state['quantization_parameters']['scales'][0]
97        output_state = tensor_details[13 + time_major_offset * 2]
98        scales["output_scale"] = output_state['quantization_parameters']['scales'][0]
99
100        tmp = math.log(scales["cell_scale"]) * (1 / math.log(2))
101        generated_params["cell_scale_power"] = int(round(tmp))
102
103        effective_scales["forget_to_cell"] = pow(2, -15) * scales["cell_scale"] / scales["cell_scale"]
104        effective_scales["input_to_cell"] = pow(2, -15) * pow(2, -15) / scales["cell_scale"]
105        effective_scales["output"] = pow(2, -15) * pow(2, -15) / scales["output_scale"]
106
107        #Help-function to read tensors and scales from tflite-model and calculating corresponding effective scale
108        def calc_scale(name, input_scale, tensor_index):
109            detail = tensor_details[tensor_index + time_major_offset]
110            tensors[name + "_weights"] = interpreter.get_tensor(detail["index"]).flatten()
111            scales[name + "_scale"] = detail['quantization_parameters']['scales'][0]
112            effective_scales[name] = input_scale * scales[name + "_scale"] / pow(2, -12)
113
114        calc_scale("output_gate_hidden", scales["output_scale"], 5)
115        calc_scale("cell_gate_hidden", scales["output_scale"], 6)
116        calc_scale("forget_gate_hidden", scales["output_scale"], 7)
117        calc_scale("input_gate_hidden", scales["output_scale"], 8)
118
119        calc_scale("output_gate_input", scales["input_scale"], 9)
120        calc_scale("cell_gate_input", scales["input_scale"], 10)
121        calc_scale("forget_gate_input", scales["input_scale"], 11)
122        calc_scale("input_gate_input", scales["input_scale"], 12)
123
124        tensors["output_gate_bias"] = interpreter.get_tensor(1 + time_major_offset).flatten()
125        tensors["cell_gate_bias"] = interpreter.get_tensor(2 + time_major_offset).flatten()
126        tensors["forget_gate_bias"] = interpreter.get_tensor(3 + time_major_offset).flatten()
127        tensors["input_gate_bias"] = interpreter.get_tensor(4 + time_major_offset).flatten()
128
129        generated_params["input_zero_point"] = -input_state['quantization_parameters']['zero_points'][0]
130        generated_params["output_zero_point"] = tensor_details[20 + time_major_offset *
131                                                               2]['quantization_parameters']['zero_points'][0]
132        generated_params["cell_clip"] = Lib.op_utils.get_dtype_max("int16_t")
133
134        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales)
135
136    def generate_data_json(shapes, params):
137        tensors = {}
138        scales = {}
139        effective_scales = {}
140        generated_params = {}
141
142        maxval = 0.001
143        minval = 0.0001
144
145        scales["input_scale"] = np.round(np.random.rand(1) * (maxval - minval) + minval, 6)[0]
146        scales["cell_scale"] = np.round(np.random.rand(1) * (maxval - minval) + maxval, 6)[0]
147        scales["output_scale"] = np.round(np.random.rand(1) * (maxval - minval) + minval, 6)[0]
148
149        tmp = math.log(scales["cell_scale"]) * (1 / math.log(2))
150        generated_params["cell_scale_power"] = int(round(tmp))
151
152        effective_scales["forget_to_cell"] = pow(2, -15) * scales["cell_scale"] / scales["cell_scale"]
153        effective_scales["input_to_cell"] = pow(2, -15) * pow(2, -15) / scales["cell_scale"]
154        effective_scales["output"] = pow(2, -15) * pow(2, -15) / scales["output_scale"]
155
156        #Help-function to generate a scale, and calculating corresponding effective scale
157        def create_scales(name, input_scale1):
158            scales[name + "_scale"] = np.round(np.random.rand(1) * (maxval - minval) + minval, 6)[0]
159            effective_scales[name] = input_scale1 * scales[name + "_scale"] / pow(2, -12)
160
161        create_scales("output_gate_hidden", scales["output_scale"])
162        create_scales("cell_gate_hidden", scales["output_scale"])
163        create_scales("forget_gate_hidden", scales["output_scale"])
164        create_scales("input_gate_hidden", scales["output_scale"])
165
166        create_scales("output_gate_input", scales["input_scale"])
167        create_scales("cell_gate_input", scales["input_scale"])
168        create_scales("forget_gate_input", scales["input_scale"])
169        create_scales("input_gate_input", scales["input_scale"])
170
171        minval = Lib.op_utils.get_dtype_min(params["weights_data_type"])
172        maxval = Lib.op_utils.get_dtype_max(params["weights_data_type"])
173        tensors["input_gate_hidden_weights"] = np.random.randint(minval, maxval, size=shapes["hidden_weights"])
174        tensors["forget_gate_hidden_weights"] = np.random.randint(minval, maxval, size=shapes["hidden_weights"])
175        tensors["cell_gate_hidden_weights"] = np.random.randint(minval, maxval, size=shapes["hidden_weights"])
176        tensors["output_gate_hidden_weights"] = np.random.randint(minval, maxval, size=shapes["hidden_weights"])
177        tensors["input_gate_input_weights"] = np.random.randint(minval, maxval, size=shapes["input_weights"])
178        tensors["forget_gate_input_weights"] = np.random.randint(minval, maxval, size=shapes["input_weights"])
179        tensors["cell_gate_input_weights"] = np.random.randint(minval, maxval, size=shapes["input_weights"])
180        tensors["output_gate_input_weights"] = np.random.randint(minval, maxval, size=shapes["input_weights"])
181
182        maxval = Lib.op_utils.get_dtype_max(params["input_data_type"])
183        minval = 0 # Negative weights are not supported in test generation
184        tensors["input_gate_bias"] = np.random.randint(minval, maxval, size=shapes["bias"])
185        tensors["forget_gate_bias"] = np.random.randint(minval, maxval, size=shapes["bias"])
186        tensors["cell_gate_bias"] = np.random.randint(minval, maxval, size=shapes["bias"])
187        tensors["output_gate_bias"] = np.random.randint(minval, maxval, size=shapes["bias"])
188
189        minval = Lib.op_utils.get_dtype_min(params["input_data_type"])
190        maxval = Lib.op_utils.get_dtype_max(params["input_data_type"])
191        generated_params["output_zero_point"] = 0
192        generated_params["input_zero_point"] = 0
193        generated_params["cell_clip"] = Lib.op_utils.get_dtype_max("int16_t")
194
195        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales)
196