# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import Lib.op_lstm
import Lib.op_conv
import Lib.op_fully_connected
import Lib.op_utils
import tensorflow as tf
import numpy as np
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import OpResolverType
import pathlib
import subprocess
import sys
import math
import tf_keras as keras

# Optional runtime interpreters
try:
    import tflite_micro
    tflite_micro_imported = True
except ModuleNotFoundError:
    print("WARNING: tflite_micro not installed, skipping tests using this interpreter.")
    tflite_micro_imported = False

try:
    from tflite_runtime.interpreter import Interpreter as TfliteRuntimeInterpreter
    from tflite_runtime.interpreter import OpResolverType as TfliteRuntimeOpResolverType

    tflite_runtime_imported = True
except ModuleNotFoundError:
    print("WARNING: tflite_runtime not installed, skipping tests using this interpreter.")
    tflite_runtime_imported = False

# Parameter keys that are consumed by the generator itself and are therefore NOT
# emitted as #defines in config_data.h.
_CONFIG_EXCLUDED_KEYS = frozenset({
    "suite_name", "name", "input_data_type", "op_type", "weights_data_type", "bias_data_type",
    "shift_and_mult_data_type", "interpreter", "tflite_generator", "json_template", "groups", "generate_bias",
    "bias_min", "bias_max", "weights_min", "weights_max", "bias_zp", "w_zp", "input_zp", "output_zp", "w_scale",
    "bias_scale", "input_scale", "output_scale"
})


def _to_literal(val):
    """Return "true"/"false" for booleans (the spelling used by both C and JSON), otherwise val unchanged."""
    if isinstance(val, bool):
        return "true" if val else "false"
    return val


def generate(params, args, fpaths):
    """ Create a test with given parameters

    Builds a .tflite model (from a keras model or a json template), runs it on a random input tensor with the
    interpreter selected by params["interpreter"] and writes config_data.h, test_data.h and one header per tensor.

    Args:
        params: Test parameters. Updated in place with the op-specific parameters and the quantized
            "<name>_multiplier"/"<name>_shift" values.
        args: Command line arguments; args.schema is the flatbuffer schema used by the json generator.
        fpaths: Dict of paths. Must contain "json_template_folder" when the json generator is used; the
            test-specific paths ("data_folder", "tflite", "config_data", "test_data", one per tensor) are
            added in place.

    Raises:
        ValueError: If the tflite generator or interpreter named in params is unknown.
    """

    # Check if test is valid, skip otherwise
    if (params["interpreter"] == "tflite_runtime") and (not tflite_runtime_imported):
        print("Skipping...")
        return
    if (params["interpreter"] == "tflite_micro") and (not tflite_micro_imported):
        print("Skipping...")
        return

    op_type = get_op_type(params["op_type"])
    shapes = op_type.get_shapes(params)

    # Create test related fpaths
    fpaths["data_folder"] = pathlib.Path("TestCases") / "TestData" / params["name"]
    fpaths["tflite"] = fpaths["data_folder"] / f"{params['name']}.tflite"
    fpaths["config_data"] = fpaths["data_folder"] / "config_data.h"
    fpaths["test_data"] = fpaths["data_folder"] / "test_data.h"

    # Generate reference data
    if params["tflite_generator"] == "keras":
        keras_model = op_type.generate_keras_model(shapes, params)
        convert_keras_to_tflite(fpaths["tflite"],
                                keras_model,
                                quantize=True,
                                dtype=params["input_data_type"],
                                bias_dtype=params["bias_data_type"],
                                shape=shapes["representational_dataset"])
        data = op_type.generate_data_tflite(fpaths["tflite"], params)

    elif params["tflite_generator"] == "json":
        data = op_type.generate_data_json(shapes, params)
        json_template_fpath = fpaths["json_template_folder"] / f"{params['json_template']}"
        json_output_fpath = fpaths["data_folder"] / f"{params['name']}.json"
        replacements = {**params, **data.params, **data.scales}
        convert_json_to_tflite(json_template_fpath, json_output_fpath, data.tensors, replacements, args.schema)

    else:
        raise ValueError(f"Invalid tflite generator in {params['name']}")

    params.update(data.params)

    # Quantize scales into fixed-point multiplier/shift pairs
    for name, scale in data.effective_scales.items():
        mult, shift = quantize_scale(scale)
        params[name + "_multiplier"] = mult
        params[name + "_shift"] = shift

    # Run reference model; explicit input_min/input_max override the dtype range
    minval = params["input_min"] if "input_min" in params else Lib.op_utils.get_dtype_min(params["input_data_type"])
    maxval = params["input_max"] if "input_max" in params else Lib.op_utils.get_dtype_max(params["input_data_type"])

    dtype = Lib.op_utils.get_tf_dtype(params["input_data_type"])
    input_tensor = Lib.op_utils.generate_tf_tensor(shapes["input"], minval, maxval, decimals=0, datatype=dtype)
    data.tensors["input"] = input_tensor.numpy()

    if params["interpreter"] == "tensorflow":
        data.tensors["output"] = invoke_tflite(fpaths["tflite"], input_tensor)
    elif params["interpreter"] == "tflite_runtime":
        data.tensors["output"] = invoke_tflite_runtime(fpaths["tflite"], input_tensor)
    elif params["interpreter"] == "tflite_micro":
        data.tensors["output"] = invoke_tflite_micro(fpaths["tflite"], input_tensor)
    else:
        raise ValueError(f"Invalid interpreter in {params['name']}")

    # Write data
    header = get_header(params["tflite_generator"], params["interpreter"])

    config_params = {key: val for key, val in params.items() if key not in _CONFIG_EXCLUDED_KEYS}
    write_config(fpaths["config_data"], config_params, params["name"], fpaths["test_data"], header)

    for name, tensor in data.tensors.items():
        dtype = Lib.op_utils.get_dtype(name, params)
        fpaths[name] = fpaths["data_folder"] / f"{name}.h"
        # The reference output is clamped to the fused activation range when one is given
        if name == "output" and "out_activation_min" in params and "out_activation_max" in params:
            tensor = np.clip(tensor, params["out_activation_min"], params["out_activation_max"])
        write_c_array(tensor, fpaths[name], dtype, params["name"], name, fpaths["test_data"], header)

        if name in data.aliases:
            append_alias_to_c_array_file(fpaths[name], dtype, params["name"], name, data.aliases[name])


def get_op_type(op_type_string):
    """Map an op type name from the test parameters to its generator class.

    Raises:
        ValueError: If op_type_string is not one of "lstm", "conv" or "fully_connected".
    """
    op_types = {
        "lstm": Lib.op_lstm.Op_lstm,
        "conv": Lib.op_conv.Op_conv,
        "fully_connected": Lib.op_fully_connected.Op_fully_connected,
    }
    try:
        return op_types[op_type_string]
    except KeyError:
        raise ValueError(f"Unknown op type '{op_type_string}'") from None


def convert_keras_to_tflite(output_fpath, keras_model, quantize, dtype, bias_dtype, shape):
    """ Convert a model generated with keras to tflite-format

    Args:
        output_fpath: pathlib.Path of the .tflite file to write; parent folders are created.
        keras_model: Keras model to convert (compiled here before conversion).
        quantize: If True, apply full-integer post-training quantization with a random representative dataset.
        dtype: C type name of the inference input/output ("int8_t" selects int8 builtins, anything else the
            int16-activation/int8-weight op set).
        bias_dtype: C type name of the bias; "int32_t" requests int32 bias in the 16x8 mode.
        shape: Shape of each random sample in the representative dataset.
    """
    keras_model.compile(loss=keras.losses.categorical_crossentropy,
                        optimizer=keras.optimizers.Adam(),
                        metrics=['accuracy'])
    n_inputs = len(keras_model.inputs)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if quantize:

        def representative_dataset():
            # NOTE(review): yields n_inputs samples, each a single-element list; verify this is intended for
            # multi-input models.
            for _ in range(n_inputs):
                data = np.random.rand(*shape)
                yield [data.astype(np.float32)]

        converter.representative_dataset = representative_dataset
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.inference_input_type = Lib.op_utils.get_tf_dtype(dtype)
        converter.inference_output_type = Lib.op_utils.get_tf_dtype(dtype)

        if dtype == "int8_t":
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        else:
            if bias_dtype == "int32_t":
                converter._experimental_full_integer_quantization_bias_type = tf.int32
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            ]

    tflite_model = converter.convert()

    output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with output_fpath.open("wb") as f:
        f.write(tflite_model)


def _invoke_interpreter(interpreter, input_tensor):
    """Feed input_tensor to the first input of a tf.lite-style interpreter and return its first output, flattened."""
    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_index, input_tensor)
    interpreter.invoke()
    output_index = interpreter.get_output_details()[0]["index"]
    return interpreter.get_tensor(output_index).flatten()


def invoke_tflite(tflite_path, input_tensor):
    """Run the model with tensorflow's interpreter using the reference (BUILTIN_REF) kernels."""
    interpreter = Interpreter(str(tflite_path), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
    return _invoke_interpreter(interpreter, input_tensor)


def invoke_tflite_runtime(tflite_path, input_tensor):
    """Run the model with the tflite_runtime interpreter using the reference (BUILTIN_REF) kernels."""
    interpreter = TfliteRuntimeInterpreter(str(tflite_path),
                                           experimental_op_resolver_type=TfliteRuntimeOpResolverType.BUILTIN_REF)
    return _invoke_interpreter(interpreter, input_tensor)


def invoke_tflite_micro(tflite_path, input_tensor):
    """Run the model with the tflite_micro Python interpreter and return output 0, flattened."""
    interpreter = tflite_micro.runtime.Interpreter.from_file(model_path=str(tflite_path))
    interpreter.set_input(input_tensor, 0)
    interpreter.invoke()
    data = interpreter.get_output(0)

    return data.flatten()


def write_config(config_fpath, params, prefix, test_data_fpath, header):
    """Write params as `#define <PREFIX>_<KEY> <value>` lines (name upper-cased, bools as true/false).

    Also (re)creates test_data_fpath so that it starts with an include of config_fpath.
    """
    with config_fpath.open("w+") as f:
        f.write(header)
        f.write("#pragma once\n")

        for key, val in params.items():
            f.write("#define " + f"{prefix}_{key} ".upper() + f"{_to_literal(val)}\n")

    with test_data_fpath.open("w") as f:
        f.write(f'#include "{config_fpath.name}"\n')


def write_c_array(data, fname, dtype, prefix, tensor_name, test_data_fpath, header):
    """Write a numpy array as a `const <dtype> <prefix>_<tensor_name>[]` C array header.

    Empty (or None) data is written as a NULL pointer instead. The header is appended as an include to
    test_data_fpath and both files are formatted with clang-format.
    """
    size = 0 if data is None else data.size

    # Check that the data looks reasonable. Skipped for empty/single-element tensors: np.unique() returns no
    # counts for empty input (max() would raise) and a single value cannot be "repeating".
    if size > 1:
        values, counts = np.unique(data, return_counts=True)
        if len(values) < size / 2 or max(counts) > size / 2:
            print(f"WARNING: {fname} has repeating values, is this intended?")
    if size > 500:
        print(f"WARNING: {fname} has more than 500 values, is this intended?")

    with fname.open("w+") as f:
        f.write(header)
        f.write("#pragma once\n")
        f.write("#include <stdint.h>\n\n")
        if size > 0:
            # One source line per row of the innermost dimension.
            row_len = data.shape[-1] if data.ndim > 0 else 1
            format_width = len(str(data.max())) + 1
            flat = data.flatten()
            last = len(flat) - 1
            f.write(f"const {dtype} {prefix}_{tensor_name}[{len(flat)}] = \n" + "{")

            for i, value in enumerate(flat):
                if i % row_len == 0:
                    f.write("\n")
                terminator = ", " if i < last else "\n};\n"
                f.write(f"{value: {format_width}n}{terminator}")

        else:
            f.write(f"const {dtype} *const {prefix}_{tensor_name} = NULL;\n")

    with test_data_fpath.open("a") as f:
        f.write(f'#include "{fname.name}"\n')

    format_output_file(fname)
    format_output_file(test_data_fpath)


def append_alias_to_c_array_file(fname, dtype, prefix, tensor_name, alias_name):
    """Append a `const <dtype> *const <prefix>_<alias_name>` pointer aliasing the array in fname."""
    with fname.open("a") as f:
        f.write(f"\nconst {dtype} *const {prefix}_{alias_name} = {prefix}_{tensor_name};\n")


def format_output_file(file):
    """Format a generated file in place with clang-format-12; exit the process if clang-format fails.

    Raises:
        RuntimeError: If clang-format could not be started.
    """
    CLANG_FORMAT = 'clang-format-12 -i'  # For formatting generated headers.
    command_list = CLANG_FORMAT.split(' ')
    command_list.append(file)
    try:
        process = subprocess.run(command_list)
    except Exception as e:
        raise RuntimeError(f"{e} from: {command_list = }") from e
    if process.returncode != 0:
        print(f"ERROR: {command_list = }")
        sys.exit(1)


def generate_test_from_template(name, test_functions_fpath, template_fpath, unity_fpath):
    """Append a test function (template with "template" replaced by name) and its Unity wrapper."""
    template = template_fpath.read_text()
    template = template.replace("template", name)

    with test_functions_fpath.open("a") as f:
        f.write(f'#include "../TestData/{name}/test_data.h"\n\n')
        f.write(template)

    with unity_fpath.open("a") as f:
        f.write("void test_" + name + "(void) { " + name + "(); }\n")


def convert_json_to_tflite(json_template_fpath, json_output_fpath, tensors, params, schema_path):
    """ Convert a model in json-format to tflite-format

    Every template line containing a params key as a token has the last occurrence of that key replaced by its
    value. Every line containing a tensors key is replaced by that tensor's bytes (little endian, one per line).
    The resulting json is compiled to a .tflite binary with flatc next to json_output_fpath.

    Raises:
        RuntimeError: If flatc could not be started.
    """

    # Generate json with values from template
    # This way minimizes string searching/ copying
    json_output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with json_template_fpath.open("r") as template:
        with json_output_fpath.open("w+") as output:
            for line in template:
                line_list = line.replace(",", "").split()
                replaced = False
                for key, val in params.items():
                    if key in line_list:
                        # To be able to handle cases like "variable_name" : variable_name
                        # make sure to only replace the last occurence per line
                        new_line = str(_to_literal(val)).join(line.rsplit(key, 1))
                        output.write(new_line)
                        replaced = True
                        break

                for key in tensors:
                    if key in line:
                        dtype = Lib.op_utils.get_dtype(key, params)
                        dtype_len = Lib.op_utils.get_dtype_len(dtype)
                        np_dtype = Lib.op_utils.get_np_dtype(dtype)

                        # Tensors are stored byte-wise in schema
                        weights_in_bytes = []
                        for weight in tensors[key].flatten():
                            weights_in_bytes.extend(int(np_dtype(weight)).to_bytes(dtype_len, 'little'))

                        # An empty tensor yields an empty byte list; nothing to write then.
                        if weights_in_bytes:
                            for byte in weights_in_bytes[:-1]:
                                output.write(f" {byte},\n")
                            output.write(f" {weights_in_bytes[-1]}\n")

                        replaced = True
                        break

                if not replaced:
                    output.write(line)

    # Generate tflite from json. Build the argument list directly so paths containing spaces survive.
    command_list = [
        "flatc", "-o", str(json_output_fpath.parent), "-c", "-b",
        str(schema_path), str(json_output_fpath)
    ]
    command = " ".join(command_list)
    try:
        process = subprocess.run(command_list, env={'PATH': os.environ.get('PATH', os.defpath)})
    except Exception as e:
        raise RuntimeError(f"{e} from: {command = }. Did you install flatc?") from e
    if process.returncode != 0:
        print(f"ERROR: {command = }")
        sys.exit(1)


def quantize_scale(scale):
    """Convert a real scale into a Q31 multiplier and a power-of-two shift.

    Returns (multiplier, shift) with scale ~= multiplier * 2**(shift - 31), where multiplier fits in int32.
    """
    significand, shift = math.frexp(scale)
    significand_q31 = round(significand * (1 << 31))
    # A significand just below 1.0 can round up to 2**31, which does not fit in int32: halve it and
    # compensate in the shift (mirrors TFLite's QuantizeMultiplier).
    if significand_q31 == (1 << 31):
        significand_q31 //= 2
        shift += 1
    # Scales too small to be represented with a shift of at least -31 are flushed to zero (as in TFLite).
    if shift < -31:
        shift = 0
        significand_q31 = 0
    return significand_q31, shift


def get_header(generator, interpreter):
    """Build the "// Generated by ..." comment header naming the generator and interpreter versions.

    Raises:
        ValueError: If generator or interpreter is unknown.
        RuntimeError: If flatc could not be started (json generator).
    """
    header = f"// Generated by {os.path.basename(sys.argv[0])}"
    if generator == "keras":
        header += f" using tensorflow version {tf.__version__} (Keras version {keras.__version__}).\n"
    elif generator == "json":
        command = "flatc --version"
        command_list = command.split()
        try:
            process = subprocess.run(command_list,
                                     capture_output=True,
                                     env={'PATH': os.environ.get('PATH', os.defpath)})
        except Exception as e:
            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?") from e
        if process.returncode != 0:
            print(f"ERROR: {command = }")
            sys.exit(1)
        flatc_version = process.stdout.decode().strip()
        header += f" using {flatc_version}\n"
    else:
        raise ValueError(f"Unknown tflite generator '{generator}'")

    if interpreter == "tensorflow":
        version = tf.__version__
        revision = tf.__git_version__
        header += f"// Interpreter from tensorflow version {version} and revision {revision}.\n"
    elif interpreter == "tflite_runtime":
        import tflite_runtime as tfl_runtime

        version = tfl_runtime.__version__
        revision = tfl_runtime.__git_version__
        header += f"// Interpreter from tflite_runtime version {version} and revision {revision}.\n"
    elif interpreter == "tflite_micro":
        version = tflite_micro.__version__
        header += f"// Interpreter from tflite_micro runtime version {version}.\n"
    else:
        raise ValueError(f"Unknown interpreter '{interpreter}'")

    return header