# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import Lib.op_utils  # dtype/tensor helpers used throughout this module
import Lib.op_lstm
import Lib.op_conv
import Lib.op_batch_matmul
import Lib.op_fully_connected
import Lib.op_pooling
import Lib.op_pad
import Lib.op_maximum_minimum
import Lib.op_transpose
import tensorflow as tf
import numpy as np
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import OpResolverType
import pathlib
import subprocess
import sys
import math
import keras

# Optional runtime interpreters
try:
    import tflite_micro
    tflite_micro_imported = True
except ModuleNotFoundError:
    print("WARNING: tflite_micro not installed, skipping tests using this interpreter.")
    tflite_micro_imported = False

try:
    from tflite_runtime.interpreter import Interpreter as TfliteRuntimeInterpreter
    from tflite_runtime.interpreter import OpResolverType as TfliteRuntimeOpResolverType

    tflite_runtime_imported = True
except ModuleNotFoundError:
    print("WARNING: tflite_runtime not installed, skipping tests using this interpreter.")
    tflite_runtime_imported = False


def generate(params, args, fpaths):
    """ Create a test with given parameters """

    # Check if test is valid, skip otherwise
    if (params["interpreter"] == "tflite_runtime") and (not tflite_runtime_imported):
        print("Skipping due to tflite_runtime not being installed...")
        return
    if (params["interpreter"] == "tflite_micro") and (not tflite_micro_imported):
        print("Skipping due to tflite_micro not being installed...")
        return

    op_type = get_op_type(params["op_type"])
    shapes = op_type.get_shapes(params)

    # Create test related fpaths
    fpaths["data_folder"] = pathlib.Path("TestCases") / "TestData" / params["name"]
    fpaths["tflite"] = fpaths["data_folder"] / f"{params['name']}.tflite"
    fpaths["config_data"] = fpaths["data_folder"] / "config_data.h"
    fpaths["test_data"] = fpaths["data_folder"] / "test_data.h"

    # Generate reference data
    if params["tflite_generator"] == "keras":
        keras_model = op_type.generate_keras_model(shapes, params)

        per_tensor_quant_for_dense = not params.get("per_channel_quant", True)
        bias_dtype = params.get("bias_data_type")

        convert_keras_to_tflite(fpaths["tflite"],
                                keras_model,
                                quantize=True,
                                dtype=params["input_data_type"],
                                bias_dtype=bias_dtype,
                                shape=shapes,
                                per_tensor_quant_for_dense=per_tensor_quant_for_dense)

        data = op_type.generate_data_tflite(fpaths["tflite"], params)

    elif params["tflite_generator"] == "json":
        data = op_type.generate_data_json(shapes, params)
        json_template_fpath = fpaths["json_template_folder"] / f"{params['json_template']}"
        json_output_fpath = fpaths["data_folder"] / f"{params['name']}.json"
        replacements = {**params, **data.params, **data.scales}
        convert_json_to_tflite(json_template_fpath, json_output_fpath, data.tensors, replacements, args.schema)

    else:
        raise ValueError(f"Invalid tflite generator in {params['name']}")

    data = op_type.post_model_update(fpaths["tflite"], data, params)

    params.update(data.params)

    # Quantize scales
    for name, scale in data.effective_scales.items():
        mult, shift = quantize_scale(scale)
        params[name + "_multiplier"] = mult
        params[name + "_shift"] = shift
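    # e.g. an effective scale of ~0.00052 is stored as multiplier 1143492093
    # (Q31) with shift -10; see quantize_scale() and the worked example below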

    # Run reference model
    minval = params.get("input_min", Lib.op_utils.get_dtype_min(params["input_data_type"]))
    maxval = params.get("input_max", Lib.op_utils.get_dtype_max(params["input_data_type"]))

    dtype = Lib.op_utils.get_tf_dtype(params["input_data_type"])

    # Initialize input tensors
    input_tensors = {}
    for shape_name, shape in shapes.items():
        if "input_tensor" in shape_name:
            if shape_name in data.tensors:
                input_tensors[shape_name] = data.tensors[shape_name]
            else:
                input_tensors[shape_name] = Lib.op_utils.generate_tf_tensor(shape, minval, maxval, decimals=0, datatype=dtype)
                data.tensors[shape_name] = input_tensors[shape_name].numpy()

    if not input_tensors:
        raise ValueError("Op_type must initialize at least one input shape")

    if params["interpreter"] == "tensorflow":
        data.tensors["output"] = invoke_tflite(fpaths["tflite"], input_tensors)
    elif params["interpreter"] == "tflite_runtime":
        data.tensors["output"] = invoke_tflite_runtime(fpaths["tflite"], input_tensors)
    elif params["interpreter"] == "tflite_micro":
        if "arena_size" in params:
            data.tensors["output"] = invoke_tflite_micro(fpaths["tflite"], input_tensors, params["arena_size"])
        else:
            data.tensors["output"] = invoke_tflite_micro(fpaths["tflite"], input_tensors)
    else:
        raise ValueError(f"Invalid interpreter in {params['name']}")

    if "activation_min" in params:
        data.tensors["output"] = np.maximum(data.tensors["output"], params["activation_min"])
    if "activation_max" in params:
        data.tensors["output"] = np.minimum(data.tensors["output"], params["activation_max"])

    # Write data
    header = get_header(params["tflite_generator"], params["interpreter"])

    def include_in_config(key):
        return key not in [
            "suite_name", "name", "input_data_type", "op_type", "weights_data_type",
            "bias_data_type", "shift_and_mult_data_type", "interpreter", "tflite_generator", "json_template",
            "groups", "generate_bias", "bias_min", "bias_max", "weights_min", "weights_max", "bias_zp", "w_zp",
            "input_zp", "output_zp", "w_scale", "bias_scale", "input_scale", "output_scale", "arena_size"
        ]

    config_params = {key: val for key, val in params.items() if include_in_config(key)}
    write_config(fpaths["config_data"], config_params, params["name"], fpaths["test_data"], header)

    for name, tensor in data.tensors.items():
        dtype = Lib.op_utils.get_dtype(name, params)
        fpaths[name] = fpaths["data_folder"] / f"{name}.h"
        if name == "output" and "out_activation_min" in params and "out_activation_max" in params:
            tensor = np.clip(tensor, params["out_activation_min"], params["out_activation_max"])
        write_c_array(tensor, fpaths[name], dtype, params["name"], name, fpaths["test_data"], header)

        if name in data.aliases:
            append_alias_to_c_array_file(fpaths[name], dtype, params["name"], name, data.aliases[name])


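# A minimal sketch of the params dict that generate() consumes; the key set
# is inferred from the lookups in generate() and all values are hypothetical:
#
#     params = {
#         "name": "basic_conv",
#         "op_type": "conv",
#         "tflite_generator": "keras",
#         "interpreter": "tensorflow",
#         "input_data_type": "int8_t",
#     }
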
def get_op_type(op_type_string):
    if op_type_string == "lstm":
        return Lib.op_lstm.Op_lstm
    elif op_type_string == "conv":
        return Lib.op_conv.Op_conv
    elif op_type_string == "batch_matmul":
        return Lib.op_batch_matmul.Op_batch_matmul
    elif op_type_string == "fully_connected":
        return Lib.op_fully_connected.Op_fully_connected
    elif op_type_string == "avgpool" or op_type_string == "maxpool":
        return Lib.op_pooling.Op_pooling
    elif op_type_string == "pad":
        return Lib.op_pad.Op_pad
    elif op_type_string == "maximum_minimum":
        return Lib.op_maximum_minimum.Op_maximum_minimum
    elif op_type_string == "transpose":
        return Lib.op_transpose.Op_transpose
    else:
        raise ValueError(f"Unknown op type '{op_type_string}'")


def convert_keras_to_tflite(
        output_fpath, keras_model, quantize, dtype, bias_dtype, shape, per_tensor_quant_for_dense=False):
    """ Convert a model generated with keras to tflite-format """
    keras_model.compile(loss=keras.losses.categorical_crossentropy,
                        metrics=['accuracy'])
    n_inputs = len(keras_model.inputs)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    if quantize:
        if shape.get("different_in_shapes") is True:
            def representative_dataset():
                for _ in range(100):
                    data1 = np.random.rand(*shape["representational_dataset"])
                    data2 = np.random.rand(*shape["representational_dataset2"])
                    yield [data1.astype(np.float32), data2.astype(np.float32)]
        else:
            def representative_dataset():
                for _ in range(n_inputs):
                    data = np.random.rand(*shape["representational_dataset"])
                    yield [data.astype(np.float32)]

        converter.representative_dataset = representative_dataset
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.inference_input_type = Lib.op_utils.get_tf_dtype(dtype)
        converter.inference_output_type = Lib.op_utils.get_tf_dtype(dtype)
        converter._experimental_disable_per_channel_quantization_for_dense_layers = per_tensor_quant_for_dense

        if dtype == "int8_t":
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        else:
            if bias_dtype == "int32_t":
                converter._experimental_full_integer_quantization_bias_type = tf.int32
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            ]

    # Convert outside the quantize branch so a non-quantized model is still
    # written out instead of raising a NameError on tflite_model below
    tflite_model = converter.convert()

    output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with output_fpath.open("wb") as f:
        f.write(tflite_model)


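# A hedged usage sketch of the converter above; the path, model and shapes
# are hypothetical, the keyword arguments mirror the signature:
#
#     convert_keras_to_tflite(pathlib.Path("out/model.tflite"),
#                             keras_model,
#                             quantize=True,
#                             dtype="int8_t",
#                             bias_dtype=None,
#                             shape={"representational_dataset": (1, 8, 8, 3)},
#                             per_tensor_quant_for_dense=False)
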
def invoke_tflite(tflite_path, input_tensor):
    interpreter = Interpreter(str(tflite_path), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
    interpreter.allocate_tensors()

    for i, val in enumerate(input_tensor.values()):
        input_index = interpreter.get_input_details()[i]["index"]
        interpreter.set_tensor(input_index, val)

    interpreter.invoke()
    output_index = interpreter.get_output_details()[0]["index"]
    data = interpreter.get_tensor(output_index)

    return data.flatten()


def invoke_tflite_runtime(tflite_path, input_tensor):
    interpreter = TfliteRuntimeInterpreter(str(tflite_path),
                                           experimental_op_resolver_type=TfliteRuntimeOpResolverType.BUILTIN_REF)
    interpreter.allocate_tensors()

    for i, val in enumerate(input_tensor.values()):
        input_index = interpreter.get_input_details()[i]["index"]
        interpreter.set_tensor(input_index, val)

    interpreter.invoke()
    output_index = interpreter.get_output_details()[0]["index"]
    data = interpreter.get_tensor(output_index)

    return data.flatten()


def invoke_tflite_micro(tflite_path, input_tensor, arena_size=30000):
    interpreter = tflite_micro.runtime.Interpreter.from_file(model_path=str(tflite_path), arena_size=arena_size)

    for i, val in enumerate(input_tensor.values()):
        interpreter.set_input(val, i)

    interpreter.invoke()
    data = interpreter.get_output(0)

    return data.flatten()


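# All three invoke_* helpers above share the same contract: they take a dict
# of named input tensors (whose iteration order must match the model's input
# order) and return the first output tensor flattened to 1-D, e.g.
# (hypothetical tensor name):
#
#     output = invoke_tflite(fpaths["tflite"], {"input_tensor_1": tensor})
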
def write_config(config_fpath, params, prefix, test_data_fpath, header):
    config_fpath.parent.mkdir(parents=True, exist_ok=True)
    with config_fpath.open("w+") as f:
        f.write(header)
        f.write("#pragma once\n\n")

        for key, val in params.items():
            if isinstance(val, list):
                f.write("#define " + f"{prefix}_{key} ".upper() + "{")
                for v in val:
                    f.write(f"{v}, ")
                f.write("}\n")
                continue
            if isinstance(val, bool):
                val = "true" if val else "false"

            f.write("#define " + f"{prefix}_{key} ".upper() + f"{val}\n")
    format_output_file(config_fpath)

    with test_data_fpath.open("w") as f:
        f.write(header)
        f.write(f'#include "{config_fpath.name}"\n')


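# For, say, prefix "basic_conv" and params {"batches": 1, "pads": [1, 2]}
# (hypothetical values), write_config() above emits:
#
#     #define BASIC_CONV_BATCHES 1
#     #define BASIC_CONV_PADS {1, 2, }
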
def write_c_array(data, fname, dtype, prefix, tensor_name, test_data_fpath, header):
    tf.experimental.numpy.experimental_enable_numpy_behavior()

    size = 0 if data is None else data.size

    # Check that the data looks reasonable; guard with size so empty or None
    # data does not crash np.unique()/max() below
    if size:
        values, counts = np.unique(data, return_counts=True)
        if len(values) < size / 2 or max(counts) > size / 2:
            print(f"WARNING: {fname} has repeating values, is this intended?")
        if size > 500:
            print(f"WARNING: {fname} has more than 500 values, is this intended?")

    with fname.open("w+") as f:
        f.write(header)
        f.write("#pragma once\n")
        f.write("#include <stdint.h>\n\n")
        if size > 0:
            data_shape = data.shape
            format_width = len(str(data.max())) + 1
            data = data.flatten()

            f.write(f"const {dtype} {prefix}_{tensor_name}[{len(data)}] = \n" + "{")

            for i in range(len(data) - 1):
                if i % data_shape[-1] == 0:
                    f.write("\n")
                f.write(f"{data[i]: {format_width}n}, ")

            if (len(data) - 1) % data_shape[-1] == 0:
                f.write("\n")
            f.write(f"{data[len(data) - 1]: {format_width}n}" + "\n};\n")

        else:
            f.write(f"const {dtype} *const {prefix}_{tensor_name} = NULL;\n")

    with test_data_fpath.open("a") as f:
        f.write(f'#include "{fname.name}"\n')

    format_output_file(fname)
    format_output_file(test_data_fpath)


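# write_c_array() above produces headers of roughly this shape before
# clang-format runs (tensor name and values illustrative):
#
#     const int8_t basic_conv_input_tensor[8] = {
#      12,  -3,   7,   1,
#      -9,   4,   0,  25};
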
def append_alias_to_c_array_file(fname, dtype, prefix, tensor_name, alias_name):
    with fname.open("a") as f:
        f.write(f"\nconst {dtype} *const {prefix}_{alias_name} = {prefix}_{tensor_name};\n")


def format_output_file(file):
    CLANG_FORMAT = 'clang-format-12 -i'  # For formatting generated headers.
    command_list = CLANG_FORMAT.split(' ')
    command_list.append(str(file))
    try:
        process = subprocess.run(command_list)
        if process.returncode != 0:
            print(f"ERROR: {command_list = }")
            sys.exit(1)
    except Exception as e:
        raise RuntimeError(f"{e} from: {command_list = }")


def generate_test_from_template(name, test_functions_fpath, template_fpath, unity_fpath):
    template = template_fpath.read_text()
    template = template.replace("template", name)

    with test_functions_fpath.open("a") as f:
        f.write(f'#include "../TestData/{name}/test_data.h"\n\n')
        f.write(template)

    with unity_fpath.open("a") as f:
        f.write("void test_" + name + "(void) { " + name + "(); }\n")


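# For name "basic_conv" (hypothetical), the unity file gains the wrapper:
#
#     void test_basic_conv(void) { basic_conv(); }
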
def convert_json_to_tflite(json_template_fpath, json_output_fpath, tensors, params, schema_path):
    """ Convert a model in json-format to tflite-format """

    # Generate json with values from the template, line by line, which
    # minimizes string searching/copying
    json_output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with json_template_fpath.open("r") as template:
        with json_output_fpath.open("w+") as output:
            for line in template:
                line_list = line.replace(",", "").split()
                replaced = False
                for key, val in params.items():
                    if key in line_list:
                        if isinstance(val, bool):
                            val = "true" if val else "false"
                        # To be able to handle cases like "variable_name" : variable_name,
                        # make sure to only replace the last occurrence per line
                        new_line = str(val).join(line.rsplit(key, 1))
                        output.write(new_line)
                        replaced = True
                        break

                for key in tensors:
                    if key in line:
                        dtype = Lib.op_utils.get_dtype(key, params)
                        dtype_len = Lib.op_utils.get_dtype_len(dtype)
                        np_dtype = Lib.op_utils.get_np_dtype(dtype)
                        # to_bytes() raises OverflowError on negative values
                        # unless signed is set for signed dtypes
                        signed = np.issubdtype(np_dtype, np.signedinteger)

                        # Tensors are stored byte-wise in the schema
                        weights_in_bytes = []
                        for weight in tensors[key].flatten():
                            weights_in_bytes.extend(int(np_dtype(weight)).to_bytes(dtype_len, 'little', signed=signed))

                        for byte in weights_in_bytes[:-1]:
                            output.write(f"        {byte},\n")
                        output.write(f"        {weights_in_bytes[-1]}\n")

                        replaced = True
                        break

                if not replaced:
                    output.write(line)

    # Generate tflite from json
    command = f"flatc -o {json_output_fpath.parent} -c -b {schema_path} {json_output_fpath}"
    command_list = command.split()
    try:
        process = subprocess.run(command_list, env={'PATH': os.getenv('PATH')})
        if process.returncode != 0:
            print(f"ERROR: {command = }")
            sys.exit(1)
    except Exception as e:
        raise RuntimeError(f"{e} from: {command = }. Did you install flatc?")
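
# Tensor payloads in the flatbuffer JSON are listed byte-wise in
# little-endian order, so e.g. an int16_t weight of 300 (0x012C) becomes the
# two bytes 44 and 1.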


def quantize_scale(scale):
    """ Quantize a float scale to a Q31 significand and a power-of-two shift,
    such that scale ~= significand_q31 * 2**(shift - 31) """
    significand, shift = math.frexp(scale)
    significand_q31 = round(significand * (1 << 31))
    return significand_q31, shift
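
# Worked example (values approximate): quantize_scale(0.00052) decomposes via
# frexp into (0.53248, -10), so the returned Q31 multiplier is
# round(0.53248 * 2**31) ≈ 1143492093 with shift = -10, i.e.
# 0.00052 ≈ 1143492093 * 2**(-10 - 31).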


def get_header(generator, interpreter):
    header = f"// Generated by {os.path.basename(sys.argv[0])}"
    if generator == "keras":
        header += f" using tensorflow version {tf.__version__} (Keras version {keras.__version__}).\n"
    elif generator == "json":
        command = "flatc --version"
        command_list = command.split()
        try:
            process = subprocess.Popen(command_list,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE,
                                       env={'PATH': os.getenv('PATH')})
            flatc_version, _ = process.communicate()
            if process.returncode != 0:
                print(f"ERROR: {command = }")
                sys.exit(1)
        except Exception as e:
            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?")
        header += f" using {flatc_version.decode().strip()}\n"
    else:
        raise ValueError(f"Invalid tflite generator '{generator}'")

    if interpreter == "tensorflow":
        version = tf.__version__
        revision = tf.__git_version__
        header += f"// Interpreter from tensorflow version {version} and revision {revision}.\n"
    elif interpreter == "tflite_runtime":
        import tflite_runtime as tfl_runtime

        version = tfl_runtime.__version__
        revision = tfl_runtime.__git_version__
        header += f"// Interpreter from tflite_runtime version {version} and revision {revision}.\n"
    elif interpreter == "tflite_micro":
        version = tflite_micro.__version__
        header += f"// Interpreter from tflite_micro runtime version {version}.\n"
    else:
        raise ValueError(f"Invalid interpreter '{interpreter}'")

    return header
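# get_header() returns a C comment banner along these lines (script name and
# versions purely illustrative):
#
#     // Generated by generate_test_data.py using tensorflow version 2.16.1 (Keras version 3.3.3).
#     // Interpreter from tensorflow version 2.16.1 and revision v2.16.1-0-g5bc9d26649c.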