# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from test_settings import TestSettings

import tensorflow as tf
import numpy as np
import tf_keras as keras


class LSTMSettings(TestSettings):

    def __init__(self,
                 dataset,
                 testtype,
                 regenerate_weights,
                 regenerate_input,
                 regenerate_biases,
                 schema_file,
                 batches=2,
                 time_steps=2,
                 number_inputs=3,
                 number_units=4,
                 time_major=True,
                 randmin=TestSettings.INT8_MIN,
                 randmax=TestSettings.INT8_MAX,
                 generate_bias=True,
                 interpreter="tensorflow"):
        super().__init__(dataset,
                         testtype,
                         regenerate_weights,
                         regenerate_input,
                         regenerate_biases,
                         schema_file,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         False,
                         randmin,
                         randmax,
                         generate_bias=generate_bias,
                         interpreter=interpreter)

        self.batches = batches
        self.time_steps = time_steps
        self.number_units = number_units
        self.number_inputs = number_inputs

        self.kernel_hidden_table_file = self.pregenerated_data_dir + self.testdataset + '/' + 'kernel_hidden.txt'

        self.time_major = time_major

        self.in_activation_max = TestSettings.INT16_MAX
        self.in_activation_min = TestSettings.INT16_MIN

        self.lstm_scales = []

        # Layer indexes. Works with tensorflow 2.10 and 2.11.
        self.output_gate_bias_index = 1
        self.cell_gate_bias_index = 2
        self.forget_gate_bias_index = 3
        self.input_gate_bias_index = 4
        self.recurrent_input_to_output_w_index = 5
        self.recurrent_input_to_cell_w_index = 6
        self.recurrent_input_to_forget_w_index = 7
        self.recurrent_input_to_input_w_index = 8
        self.input_to_output_w_index = 9
        self.input_to_cell_w_index = 10
        self.input_to_forget_w_index = 11
        self.input_to_input_w_index = 12
        self.output_state_index = 13
        self.cell_state_index = 14
        self.input_norm_coeff_index = 15
        self.forget_norm_coeff_index = 16
        self.cell_norm_coeff_index = 17
        self.output_norm_coeff_index = 18
        self.effective_hidden_scale_intermediate_index = 20

    def generate_data(self, input_data=None, weights=None, hidden_weights=None, biases=None) -> None:

        input_dims = [self.batches, self.time_steps, self.number_inputs]
        if input_data is not None:
            input_data = tf.reshape(input_data, input_dims)
        else:
            input_data = self.get_randomized_data(input_dims,
                                                  self.inputs_table_file,
                                                  regenerate=self.regenerate_new_input)

        # This will be the same size when there is no projection.
        number_cells = self.number_units

        # Each LSTM cell has 4 input weights, 4 hidden (recurrent or cell state) weights and 4 biases.
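        # Keras concatenates the per-gate blocks of each matrix along the last
        # axis in the order: input, forget, cell, output.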
        number_w_b = 4

        if weights is not None:
            weights = tf.reshape(weights, [self.number_inputs, number_cells * number_w_b])
        else:
            weights = self.get_randomized_data([self.number_inputs, number_cells * number_w_b],
                                               self.kernel_table_file,
                                               regenerate=self.regenerate_new_weights,
                                               decimals=8,
                                               minrange=-1.0,
                                               maxrange=1.0)

        if hidden_weights is not None:
            hidden_weights = tf.reshape(hidden_weights, [number_cells, number_cells * number_w_b])
        else:
            hidden_weights = self.get_randomized_data([number_cells, number_cells * number_w_b],
                                                      self.kernel_hidden_table_file,
                                                      regenerate=self.regenerate_new_weights,
                                                      decimals=8,
                                                      minrange=-1.0,
                                                      maxrange=1.0)
        if not self.generate_bias:
            biases = [0] * number_cells * number_w_b
        if biases is not None:
            biases = tf.reshape(biases, [number_cells * number_w_b])
        else:
            biases = self.get_randomized_data([number_cells * number_w_b],
                                              self.bias_table_file,
                                              regenerate=self.regenerate_new_bias,
                                              decimals=8,
                                              minrange=-1.0,
                                              maxrange=1.0)

        # Create a Keras-based LSTM model.
        input_layer = keras.layers.Input(shape=(self.time_steps, self.number_inputs),
                                         batch_size=self.batches,
                                         name='input')
        if self.time_major:
            input_layer_transposed = tf.transpose(input_layer, perm=[1, 0, 2])
            lstm_layer = keras.layers.LSTM(units=self.number_units,
                                           time_major=self.time_major,
                                           return_sequences=True)(input_layer_transposed)
        else:
            lstm_layer = keras.layers.LSTM(units=self.number_units,
                                           time_major=self.time_major,
                                           return_sequences=True)(input_layer)
        model = keras.Model(input_layer, lstm_layer, name="LSTM")

        if self.time_major:
            time_major_offset = 1
            shape = (self.time_steps, self.batches, self.number_inputs)
        else:
            time_major_offset = 0
            shape = (self.batches, self.time_steps, self.number_inputs)

        # Write weights and biases to the model.
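        # model.layers[0] is the Input layer. In time-major mode the transpose
        # op becomes a layer of its own, pushing the LSTM layer to index 2;
        # time_major_offset accounts for that.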
167 print("Updating weights", model.layers[1 + time_major_offset].weights[0].name) 168 model.layers[1 + time_major_offset].weights[0].assign(weights) 169 print("Updating hidden weights", model.layers[1 + time_major_offset].weights[1].name) 170 model.layers[1 + time_major_offset].weights[1].assign(hidden_weights) 171 print("Updating bias", model.layers[1 + time_major_offset].weights[2].name) 172 model.layers[1 + time_major_offset].weights[2].assign(biases) 173 174 interpreter = self.convert_and_interpret(model, tf.int8, input_data, dataset_shape=shape) 175 176 all_layers_details = interpreter.get_tensor_details() 177 178 for i in all_layers_details: 179 self.lstm_scales.append(i['quantization_parameters']['scales']) 180 181 input_data_for_index = all_layers_details[0] 182 183 input_gate_bias = all_layers_details[self.input_gate_bias_index + time_major_offset] 184 forget_gate_bias = all_layers_details[self.forget_gate_bias_index + time_major_offset] 185 cell_gate_bias = all_layers_details[self.cell_gate_bias_index + time_major_offset] 186 output_gate_bias = all_layers_details[self.output_gate_bias_index + time_major_offset] 187 188 input_to_input_w = all_layers_details[self.input_to_input_w_index + time_major_offset] 189 input_to_forget_w = all_layers_details[self.input_to_forget_w_index + time_major_offset] 190 input_to_cell_w = all_layers_details[self.input_to_cell_w_index + time_major_offset] 191 input_to_output_w = all_layers_details[self.input_to_output_w_index + time_major_offset] 192 193 recurrent_input_to_input_w = all_layers_details[self.recurrent_input_to_input_w_index + time_major_offset] 194 recurrent_input_to_forget_w = all_layers_details[self.recurrent_input_to_forget_w_index + time_major_offset] 195 recurrent_input_to_cell_w = all_layers_details[self.recurrent_input_to_cell_w_index + time_major_offset] 196 recurrent_input_to_output_w = all_layers_details[self.recurrent_input_to_output_w_index + time_major_offset] 197 198 if self.time_major: 199 time_major_offset = 2 200 201 output_state = all_layers_details[self.output_state_index + time_major_offset] 202 cell_state = all_layers_details[self.cell_state_index + time_major_offset] 203 204 input_norm_coeff = all_layers_details[self.input_norm_coeff_index + time_major_offset] 205 forget_norm_coeff = all_layers_details[self.forget_norm_coeff_index + time_major_offset] 206 cell_norm_coeff = all_layers_details[self.cell_norm_coeff_index + time_major_offset] 207 output_norm_coeff = all_layers_details[self.output_norm_coeff_index + time_major_offset] 208 209 # For scale and zero point. 
        effective_hidden_scale_intermediate = all_layers_details[
            self.effective_hidden_scale_intermediate_index + time_major_offset]

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        actual_input_data = interpreter.get_tensor(input_details[0]["index"])
        if (input_data.numpy().shape != actual_input_data.shape) or \
                not ((input_data.numpy().astype(int) == actual_input_data).all().astype(int)):
            raise RuntimeError("Input data mismatch")

        self.generate_c_array(self.input_data_file_prefix, interpreter.get_tensor(input_data_for_index['index']))
        self.generate_c_array("input_to_input_w", interpreter.get_tensor(input_to_input_w['index']))
        self.generate_c_array("input_to_forget_w", interpreter.get_tensor(input_to_forget_w['index']))
        self.generate_c_array("input_to_cell_w", interpreter.get_tensor(input_to_cell_w['index']))
        self.generate_c_array("input_to_output_w", interpreter.get_tensor(input_to_output_w['index']))
        self.generate_c_array("recurrent_input_to_input_w", interpreter.get_tensor(recurrent_input_to_input_w['index']))
        self.generate_c_array("recurrent_input_to_forget_w",
                              interpreter.get_tensor(recurrent_input_to_forget_w['index']))
        self.generate_c_array("recurrent_input_to_cell_w", interpreter.get_tensor(recurrent_input_to_cell_w['index']))
        self.generate_c_array("recurrent_input_to_output_w",
                              interpreter.get_tensor(recurrent_input_to_output_w['index']))

        # Peephole connections are not supported, so these are nullptrs.
        self.generate_c_array("cell_to_input", [], datatype='int16_t')
        self.generate_c_array("cell_to_forget", [], datatype='int16_t')
        self.generate_c_array("cell_to_output", [], datatype='int16_t')

        self.generate_c_array("input_gate_bias", interpreter.get_tensor(input_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("cell_gate_bias", interpreter.get_tensor(cell_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("forget_gate_bias", interpreter.get_tensor(forget_gate_bias['index']), datatype='int32_t')
        self.generate_c_array("output_gate_bias", interpreter.get_tensor(output_gate_bias['index']), datatype='int32_t')

        # Projection is not supported, so these are nullptrs.
243 self.generate_c_array("projection_weights", []) 244 self.generate_c_array("projection_bias", [], datatype='int32_t') 245 246 self.generate_c_array("output_state", interpreter.get_tensor(output_state['index']), const="") 247 self.generate_c_array("cell_state", interpreter.get_tensor(cell_state['index']), datatype='int16_t', const="") 248 249 self.generate_c_array("input_norm_coeff", interpreter.get_tensor(input_norm_coeff['index'])) 250 self.generate_c_array("forget_norm_coeff", interpreter.get_tensor(forget_norm_coeff['index'])) 251 self.generate_c_array("cell_norm_coeff", interpreter.get_tensor(cell_norm_coeff['index'])) 252 self.generate_c_array("output_norm_coeff", interpreter.get_tensor(output_norm_coeff['index'])) 253 254 input_scale = input_data_for_index['quantization_parameters']['scales'][0] 255 self.data_zp = input_data_for_index['quantization_parameters']['zero_points'][0] 256 cell_scale = cell_state['quantization_parameters']['scales'][0] 257 output_state_scale = output_state['quantization_parameters']['scales'][0] 258 input_zp = input_data_for_index['quantization_parameters']['zero_points'][0] 259 output_zp = output_details[0]['quantization_parameters']['zero_points'][0] 260 output_state_zp = output_state['quantization_parameters']['zero_points'][0] 261 self.hidden_zp = effective_hidden_scale_intermediate['quantization_parameters']['zero_points'][0] 262 self.output_state_offset = output_state_zp 263 264 tmp = math.log(cell_scale) * (1 / math.log(2)) 265 self.cell_state_shift = int(round(tmp)) 266 267 self.calc_scales(input_scale, output_state_scale, cell_scale) 268 269 # Calculate effective biases. 270 input_zp = -input_zp 271 output_zp = -output_zp 272 output_state_zp = -output_state_zp 273 input_to_forget_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_forget_w, forget_gate_bias) 274 recurrent_to_forget_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, 275 recurrent_input_to_forget_w, None, False) 276 input_to_cell_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_cell_w, cell_gate_bias) 277 recurrent_to_cell_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_cell_w, 278 None, False) 279 input_to_output_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_output_w, output_gate_bias) 280 recurrent_to_output_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, 281 recurrent_input_to_output_w, None, False) 282 input_to_input_eff_bias = self.calc_effective_bias(interpreter, input_zp, input_to_input_w, input_gate_bias) 283 284 recurrent_to_input_eff_bias = self.calc_effective_bias(interpreter, output_state_zp, recurrent_input_to_input_w, 285 None, False) 286 287 self.generate_c_array("input_to_input_eff_bias", input_to_input_eff_bias, datatype='int32_t') 288 self.generate_c_array("input_to_forget_eff_bias", input_to_forget_eff_bias, datatype='int32_t') 289 self.generate_c_array("input_to_cell_eff_bias", input_to_cell_eff_bias, datatype='int32_t') 290 self.generate_c_array("input_to_output_eff_bias", input_to_output_eff_bias, datatype='int32_t') 291 self.generate_c_array("recurrent_to_input_eff_bias", recurrent_to_input_eff_bias, datatype='int32_t') 292 self.generate_c_array("recurrent_to_cell_eff_bias", recurrent_to_cell_eff_bias, datatype='int32_t') 293 self.generate_c_array("recurrent_to_forget_eff_bias", recurrent_to_forget_eff_bias, datatype='int32_t') 294 self.generate_c_array("recurrent_to_output_eff_bias", recurrent_to_output_eff_bias, datatype='int32_t') 

        # Generate reference output.
        if self.use_tflite_micro_interpreter:
            interpreter = self.tflite_micro.runtime.Interpreter.from_file(model_path=str(self.model_path_tflite))
            interpreter.set_input(tf.cast(input_data, tf.int8), input_details[0]["index"])
            interpreter.invoke()
            output_data = interpreter.get_output(0)
        else:
            interpreter.invoke()
            output_data = interpreter.get_tensor(output_details[0]["index"])

        self.generate_c_array(self.output_data_file_prefix, output_data, datatype='int8_t')

        self.write_c_config_header()
        self.write_c_header_wrapper()

    def calc_scales(self, input_scale, output_state_scale, cell_scale):
        intermediate_scale = pow(2, -12)

        if self.time_major:
            time_major_offset = 1
        else:
            time_major_offset = 0

        # Gate outputs are quantized to Q0.15, hence the pow(2, -15) factors.
        self.effective_forget_scale = pow(2, -15) / cell_scale * cell_scale
        self.effective_input_scale = pow(2, -15) / cell_scale * pow(2, -15)
        self.effective_hidden_scale = pow(2, -15) / output_state_scale * pow(2, -15)

        self.i2i_effective_scale = input_scale * self.lstm_scales[self.input_to_input_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2f_effective_scale = input_scale * self.lstm_scales[self.input_to_forget_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2c_effective_scale = input_scale * self.lstm_scales[self.input_to_cell_w_index + time_major_offset][0] \
            / intermediate_scale
        self.i2o_effective_scale = input_scale * self.lstm_scales[self.input_to_output_w_index + time_major_offset][0] \
            / intermediate_scale

        self.r2i_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_input_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2f_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_forget_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2c_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_cell_w_index +
                                                                         time_major_offset][0] / intermediate_scale
        self.r2o_effective_scale = output_state_scale * self.lstm_scales[self.recurrent_input_to_output_w_index +
                                                                         time_major_offset][0] / intermediate_scale

    def calc_effective_bias(self, interpreter, zero_point, weight_tensor, bias_tensor, has_bias=True) -> list:
        # effective_bias[i] = bias[i] + zero_point * sum(weights[i][:]), where
        # zero_point has already been negated by the caller.
        weights = interpreter.get_tensor(weight_tensor['index'])
        dims = weight_tensor['shape']
        row = dims[0]
        col = dims[1]

        if has_bias:
            bias_data = interpreter.get_tensor(bias_tensor['index'])
            output = bias_data
        else:
            output = np.zeros((row, ), dtype=np.int32)

        for i_row in range(row):
            row_sum = 0
            for i_col in range(col):
                row_sum = row_sum + weights[i_row][i_col]
            output[i_row] = output[i_row] + row_sum * zero_point

        return output

    def write_c_config_header(self) -> None:
        super().write_c_config_header(write_common_parameters=False)

        filename = self.config_data
        filepath = self.headers_dir + filename
        prefix = self.testdataset.upper()

        with open(filepath, "a") as f:
            f.write("#define {}_BUFFER_SIZE {}\n".format(prefix, self.batches * self.number_units))
            f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
            f.write("#define {}_DST_SIZE {}\n".format(prefix, self.batches * self.time_steps * self.number_units))
            f.write("#define {}_TIME_STEPS {}\n".format(prefix, self.time_steps))
            f.write("#define {}_NUMBER_UNITS {}\n".format(prefix, self.number_units))
f.write("#define {}_NUMBER_INPUTS {}\n".format(prefix, self.number_inputs)) 377 f.write("#define {}_TIME_MAJOR {}\n".format(prefix, int(self.time_major))) 378 f.write("#define {}_IN_ACTIVATION_MIN {}\n".format(prefix, self.in_activation_min)) 379 f.write("#define {}_IN_ACTIVATION_MAX {}\n".format(prefix, self.in_activation_max)) 380 381 (multiplier, shift) = self.quantize_scale(self.i2i_effective_scale) 382 f.write("#define {}_IN_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier)) 383 f.write("#define {}_IN_TO_INPUT_SHIFT {}\n".format(prefix, shift)) 384 (multiplier, shift) = self.quantize_scale(self.i2f_effective_scale) 385 f.write("#define {}_IN_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier)) 386 f.write("#define {}_IN_TO_FORGET_SHIFT {}\n".format(prefix, shift)) 387 (multiplier, shift) = self.quantize_scale(self.i2c_effective_scale) 388 f.write("#define {}_IN_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier)) 389 f.write("#define {}_IN_TO_CELL_SHIFT {}\n".format(prefix, shift)) 390 (multiplier, shift) = self.quantize_scale(self.i2o_effective_scale) 391 f.write("#define {}_IN_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier)) 392 f.write("#define {}_IN_TO_OUTPUT_SHIFT {}\n".format(prefix, shift)) 393 394 (multiplier, shift) = self.quantize_scale(self.r2i_effective_scale) 395 f.write("#define {}_RECURRENT_TO_INPUT_MULTIPLIER {}\n".format(prefix, multiplier)) 396 f.write("#define {}_RECURRENT_TO_INPUT_SHIFT {}\n".format(prefix, shift)) 397 (multiplier, shift) = self.quantize_scale(self.r2f_effective_scale) 398 f.write("#define {}_RECURRENT_TO_FORGET_MULTIPLIER {}\n".format(prefix, multiplier)) 399 f.write("#define {}_RECURRENT_TO_FORGET_SHIFT {}\n".format(prefix, shift)) 400 (multiplier, shift) = self.quantize_scale(self.r2c_effective_scale) 401 f.write("#define {}_RECURRENT_TO_CELL_MULTIPLIER {}\n".format(prefix, multiplier)) 402 f.write("#define {}_RECURRENT_TO_CELL_SHIFT {}\n".format(prefix, shift)) 403 (multiplier, shift) = self.quantize_scale(self.r2o_effective_scale) 404 f.write("#define {}_RECURRENT_TO_OUTPUT_MULTIPLIER {}\n".format(prefix, multiplier)) 405 f.write("#define {}_RECURRENT_TO_OUTPUT_SHIFT {}\n".format(prefix, shift)) 406 407 408 (multiplier, shift) = self.quantize_scale(self.effective_forget_scale) 409 f.write("#define {}_FORGET_MULTIPLIER {}\n".format(prefix, multiplier)) 410 f.write("#define {}_FORGET_SHIFT {}\n".format(prefix, shift)) 411 412 (multiplier, shift) = self.quantize_scale(self.effective_input_scale) 413 f.write("#define {}_INPUT_MULTIPLIER {}\n".format(prefix, multiplier)) 414 f.write("#define {}_INPUT_SHIFT {}\n".format(prefix, shift)) 415 416 (multiplier, shift) = self.quantize_scale(self.effective_hidden_scale) 417 f.write("#define {}_HIDDEN_MULTIPLIER {}\n".format(prefix, multiplier)) 418 f.write("#define {}_HIDDEN_SHIFT {}\n".format(prefix, shift)) 419 420 f.write("#define {}_HIDDEN_OFFSET {}\n".format(prefix, self.hidden_zp)) 421 f.write("#define {}_DATA_OFFSET {}\n".format(prefix, -self.data_zp)) 422 423 f.write("#define {}_OUTPUT_STATE_OFFSET {}\n".format(prefix, self.output_state_offset)) 424 f.write("#define {}_CELL_STATE_SHIFT {}\n".format(prefix, self.cell_state_shift)) 425 426 for i in range(len(self.lstm_scales)): 427 if len(self.lstm_scales[i]) == 0: 428 continue 429 (multiplier, shift) = self.quantize_scale(self.lstm_scales[i][0]) 430 431