1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import os
18import Lib.op_lstm
19import Lib.op_conv
20import Lib.op_fully_connected
21import tensorflow as tf
22import numpy as np
23from tensorflow.lite.python.interpreter import Interpreter
24from tensorflow.lite.python.interpreter import OpResolverType
25import pathlib
26import subprocess
27import sys
28import math
29import tf_keras as keras
30
31# Optional runtime interpreters
32try:
33    import tflite_micro
34    tflite_micro_imported = True
35except ModuleNotFoundError:
36    print("WARNING: tflite_micro not installed, skipping tests using this interpreter.")
37    tflite_micro_imported = False
38
39try:
40    from tflite_runtime.interpreter import Interpreter as TfliteRuntimeInterpreter
41    from tflite_runtime.interpreter import OpResolverType as TfliteRuntimeOpResolverType
42
43    tflite_runtime_imported = True
44except ModuleNotFoundError:
45    print("WARNING: tflite_runtime not installed, skipping tests using this interpreter.")
46    tflite_runtime_imported = False
47
48
def generate(params, args, fpaths):
    """Create the reference data and generated headers for one test.

    Nothing is generated (a notice is printed) when the requested interpreter is an optional
    package that failed to import. Otherwise a tflite model is built with the keras or json
    generator, a generated input is run through the selected interpreter, and the config
    values plus every tensor are written as headers into the test's data folder.

    params is updated in place with the generator's parameters and the quantized
    multiplier/shift pairs; fpaths gains the per-test paths. args is only used for
    args.schema in the json generator path.

    Raises ValueError for an unknown tflite generator or interpreter.
    """
    # Tests that depend on an optional interpreter are skipped when it is not installed
    requested_interpreter = params["interpreter"]
    runtime_missing = requested_interpreter == "tflite_runtime" and not tflite_runtime_imported
    micro_missing = requested_interpreter == "tflite_micro" and not tflite_micro_imported
    if runtime_missing or micro_missing:
        print("Skipping...")
        return

    op_type = get_op_type(params["op_type"])
    shapes = op_type.get_shapes(params)

    # All files of a test live in one folder named after the test
    data_folder = pathlib.Path("TestCases") / "TestData" / params["name"]
    fpaths["data_folder"] = data_folder
    fpaths["tflite"] = data_folder / f"{params['name']}.tflite"
    fpaths["config_data"] = data_folder / "config_data.h"
    fpaths["test_data"] = data_folder / "test_data.h"

    # Produce the tflite model and the generator specific data
    generator = params["tflite_generator"]
    if generator == "keras":
        keras_model = op_type.generate_keras_model(shapes, params)
        convert_keras_to_tflite(fpaths["tflite"],
                                keras_model,
                                quantize=True,
                                dtype=params["input_data_type"],
                                bias_dtype=params["bias_data_type"],
                                shape=shapes["representational_dataset"])
        data = op_type.generate_data_tflite(fpaths["tflite"], params)
    elif generator == "json":
        data = op_type.generate_data_json(shapes, params)
        json_template_fpath = fpaths["json_template_folder"] / f"{params['json_template']}"
        json_output_fpath = fpaths["data_folder"] / f"{params['name']}.json"
        # Later dictionaries win when the same key appears more than once
        replacements = {**params, **data.params, **data.scales}
        convert_json_to_tflite(json_template_fpath, json_output_fpath, data.tensors, replacements, args.schema)
    else:
        raise ValueError(f"Invalid tflite generator in {params['name']}")

    params.update(data.params)

    # Each effective scale is exported as a <name>_multiplier / <name>_shift pair
    for scale_name, scale in data.effective_scales.items():
        multiplier, shift = quantize_scale(scale)
        params[scale_name + "_multiplier"] = multiplier
        params[scale_name + "_shift"] = shift

    # Input range defaults to the full range of the input data type unless overridden
    if "input_min" in params:
        minval = params["input_min"]
    else:
        minval = Lib.op_utils.get_dtype_min(params["input_data_type"])
    if "input_max" in params:
        maxval = params["input_max"]
    else:
        maxval = Lib.op_utils.get_dtype_max(params["input_data_type"])

    tf_dtype = Lib.op_utils.get_tf_dtype(params["input_data_type"])
    input_tensor = Lib.op_utils.generate_tf_tensor(shapes["input"], minval, maxval, decimals=0, datatype=tf_dtype)
    data.tensors["input"] = input_tensor.numpy()

    # Run the reference model with the selected interpreter
    invokers = {
        "tensorflow": invoke_tflite,
        "tflite_runtime": invoke_tflite_runtime,
        "tflite_micro": invoke_tflite_micro,
    }
    invoke = invokers.get(params["interpreter"])
    if invoke is None:
        raise ValueError(f"Invalid interpreter in {params['name']}")
    data.tensors["output"] = invoke(fpaths["tflite"], input_tensor)

    # Write data
    header = get_header(params["tflite_generator"], params["interpreter"])

    # Parameters that steer the generation itself and are not written to the config header
    excluded_from_config = frozenset((
        "suite_name", "name", "input_data_type", "op_type", "weights_data_type", "bias_data_type",
        "shift_and_mult_data_type", "interpreter", "tflite_generator", "json_template", "groups",
        "generate_bias", "bias_min", "bias_max", "weights_min", "weights_max", "bias_zp", "w_zp", "input_zp",
        "output_zp", "w_scale", "bias_scale", "input_scale", "output_scale"
    ))
    config_params = {key: val for key, val in params.items() if key not in excluded_from_config}
    write_config(fpaths["config_data"], config_params, params["name"], fpaths["test_data"], header)

    for tensor_name, tensor in data.tensors.items():
        c_dtype = Lib.op_utils.get_dtype(tensor_name, params)
        fpaths[tensor_name] = fpaths["data_folder"] / f"{tensor_name}.h"

        # The expected output is clamped to the activation range when one is given
        if tensor_name == "output" and "out_activation_min" in params and "out_activation_max" in params:
            tensor = np.clip(tensor, params["out_activation_min"], params["out_activation_max"])
        write_c_array(tensor, fpaths[tensor_name], c_dtype, params["name"], tensor_name, fpaths["test_data"], header)

        if tensor_name in data.aliases:
            append_alias_to_c_array_file(fpaths[tensor_name], c_dtype, params["name"], tensor_name,
                                         data.aliases[tensor_name])
138
139
def get_op_type(op_type_string):
    """Return the op class from Lib that implements the given op type.

    Supported names are "lstm", "conv" and "fully_connected".

    Raises ValueError for any other name.
    """
    if op_type_string == "lstm":
        return Lib.op_lstm.Op_lstm
    if op_type_string == "conv":
        return Lib.op_conv.Op_conv
    if op_type_string == "fully_connected":
        return Lib.op_fully_connected.Op_fully_connected

    raise ValueError(f"Unknown op type '{op_type_string}'")
149
150
def convert_keras_to_tflite(output_fpath, keras_model, quantize, dtype, bias_dtype, shape):
    """Convert a model generated with keras to tflite-format and write it to output_fpath.

    output_fpath -- pathlib path of the .tflite file; missing parent folders are created.
    keras_model  -- keras model to convert; it is compiled here before conversion.
    quantize     -- when true the converter is set up for integer quantization using a
                    random representative dataset, otherwise the model is converted as is.
    dtype        -- C type name of the model input/output ("int8_t" selects the int8 op set,
                    any other value selects 16-bit activations with 8-bit weights).
    bias_dtype   -- C type name of the bias; "int32_t" requests 32-bit bias for the 16x8 case.
    shape        -- shape of each representative dataset sample.
    """
    keras_model.compile(loss=keras.losses.categorical_crossentropy,
                        optimizer=keras.optimizers.Adam(),
                        metrics=['accuracy'])
    n_inputs = len(keras_model.inputs)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if quantize:

        def representative_dataset():
            # NOTE(review): yields n_inputs samples that each hold a single array - fine for
            # single-input models, confirm this is what multi-input models need.
            for _ in range(n_inputs):
                data = np.random.rand(*shape)
                yield [data.astype(np.float32)]

        converter.representative_dataset = representative_dataset
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.inference_input_type = Lib.op_utils.get_tf_dtype(dtype)
        converter.inference_output_type = Lib.op_utils.get_tf_dtype(dtype)

        if dtype == "int8_t":
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        else:
            if bias_dtype == "int32_t":
                converter._experimental_full_integer_quantization_bias_type = tf.int32
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
            ]

    # Convert outside the quantize branch so that quantize=False produces a (float) model
    # instead of failing on an unassigned tflite_model when writing the file.
    tflite_model = converter.convert()

    output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with output_fpath.open("wb") as f:
        f.write(tflite_model)
185
186
def invoke_tflite(tflite_path, input_tensor):
    """Run the model once with the tensorflow lite interpreter (BUILTIN_REF op resolver).

    input_tensor is set on the first model input. Returns the first model output flattened.
    """
    tfl_interpreter = Interpreter(str(tflite_path), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
    first_input = tfl_interpreter.get_input_details()[0]["index"]

    tfl_interpreter.allocate_tensors()
    tfl_interpreter.set_tensor(first_input, input_tensor)
    tfl_interpreter.invoke()

    first_output = tfl_interpreter.get_output_details()[0]["index"]
    return tfl_interpreter.get_tensor(first_output).flatten()
197
198
def invoke_tflite_runtime(tflite_path, input_tensor):
    """Run the model once with the tflite_runtime interpreter (BUILTIN_REF op resolver).

    input_tensor is set on the first model input. Returns the first model output flattened.
    """
    resolver = TfliteRuntimeOpResolverType.BUILTIN_REF
    runtime_interpreter = TfliteRuntimeInterpreter(str(tflite_path), experimental_op_resolver_type=resolver)
    first_input = runtime_interpreter.get_input_details()[0]["index"]

    runtime_interpreter.allocate_tensors()
    runtime_interpreter.set_tensor(first_input, input_tensor)
    runtime_interpreter.invoke()

    first_output = runtime_interpreter.get_output_details()[0]["index"]
    return runtime_interpreter.get_tensor(first_output).flatten()
210
211
def invoke_tflite_micro(tflite_path, input_tensor):
    """Run the model once with the tflite_micro interpreter.

    input_tensor is set as input 0. Returns output 0 flattened.
    """
    micro_interpreter = tflite_micro.runtime.Interpreter.from_file(model_path=str(tflite_path))
    micro_interpreter.set_input(input_tensor, 0)
    micro_interpreter.invoke()

    return micro_interpreter.get_output(0).flatten()
219
220
def write_config(config_fpath, params, prefix, test_data_fpath, header):
    """Write the test parameters as C defines and start a fresh test data header.

    Every entry of params becomes "#define <PREFIX>_<KEY> <value>" in config_fpath, with the
    macro name upper-cased and Python booleans written as true/false. test_data_fpath is
    overwritten so that it only includes the config header.
    """
    with config_fpath.open("w+") as config_file:
        config_file.write(header)
        config_file.write("#pragma once\n")

        for key, val in params.items():
            if isinstance(val, bool):
                val = "true" if val else "false"

            # Only the macro name is upper-cased, the value is written unchanged
            macro_name = f"{prefix}_{key} ".upper()
            config_file.write(f"#define {macro_name}{val}\n")

    with test_data_fpath.open("w") as test_data_file:
        test_data_file.write(f'#include "{config_fpath.name}"\n')
237
238
def write_c_array(data, fname, dtype, prefix, tensor_name, test_data_fpath, header):
    """Write a tensor as a C array header and include it from the test data header.

    data            -- numpy array to write, or None. None and empty arrays are written as a
                       NULL pointer constant instead of an array.
    fname           -- pathlib path of the header to (over)write.
    dtype           -- C type name of the array elements.
    prefix          -- test name, used as prefix of the C symbol.
    tensor_name     -- tensor name, used as suffix of the C symbol.
    test_data_fpath -- pathlib path of the test data header; an include of fname is appended.
    header          -- text written at the top of fname.

    Both files are run through format_output_file afterwards.
    """
    size = 0 if data is None else data.size

    # Check that the data looks reasonable. Only done for non-empty data: np.unique gives no
    # counts for an empty array, so max() would raise before the NULL case could be written.
    if size > 0:
        values, counts = np.unique(data, return_counts=True)
        if len(values) < size / 2 or max(counts) > size / 2:
            print(f"WARNING: {fname} has repeating values, is this intended?")
        # Compare the total number of values, not just the length of the first dimension
        if size > 500:
            print(f"WARNING: {fname} has more than 500 values, is this intended?")

    with fname.open("w+") as f:
        f.write(header)
        f.write("#pragma once\n")
        f.write("#include <stdint.h>\n\n")
        if size > 0:
            # One line per innermost row; a 0-d array is written as a single value
            row_length = data.shape[-1] if data.ndim > 0 else 1
            format_width = len(str(data.max())) + 1
            flat = data.flatten()
            last_index = len(flat) - 1
            f.write(f"const {dtype} {prefix}_{tensor_name}[{len(flat)}] = \n" + "{")

            for i, value in enumerate(flat):
                if i % row_length == 0:
                    f.write("\n")
                # The last value closes the initializer instead of being followed by a comma
                terminator = ", " if i < last_index else "\n};\n"
                f.write(f"{value: {format_width}n}{terminator}")

        else:
            f.write(f"const {dtype} *const {prefix}_{tensor_name} = NULL;\n")

    with test_data_fpath.open("a") as f:
        f.write(f'#include "{fname.name}"\n')

    format_output_file(fname)
    format_output_file(test_data_fpath)
278
279
def append_alias_to_c_array_file(fname, dtype, prefix, tensor_name, alias_name):
    """Append a const pointer named <prefix>_<alias_name> that points at <prefix>_<tensor_name>."""
    alias_declaration = f"\nconst {dtype} *const {prefix}_{alias_name} = {prefix}_{tensor_name};\n"
    with fname.open("a") as array_file:
        array_file.write(alias_declaration)
283
284
def format_output_file(file):
    """Format a generated header in place with clang-format-12.

    Prints the command and exits the process with status 1 when the formatter returns a
    non-zero code. Raises RuntimeError when running the command itself fails.
    """
    # clang-format-12 -i <file>: -i rewrites the file in place
    command_list = ["clang-format-12", "-i", file]
    try:
        formatter_failed = subprocess.run(command_list).returncode != 0
        if formatter_failed:
            print(f"ERROR: {command_list = }")
            sys.exit(1)
    except Exception as e:
        raise RuntimeError(f"{e} from: {command_list = }")
296
297
def generate_test_from_template(name, test_functions_fpath, template_fpath, unity_fpath):
    """Append a test function for the given test name to the test sources.

    The template text has every occurrence of "template" replaced by name and is appended,
    preceded by an include of the test's test_data.h, to test_functions_fpath. A wrapper
    function test_<name> calling <name> is appended to unity_fpath.
    """
    test_function = template_fpath.read_text().replace("template", name)

    with test_functions_fpath.open("a") as functions_file:
        functions_file.write(f'#include "../TestData/{name}/test_data.h"\n\n')
        functions_file.write(test_function)

    with unity_fpath.open("a") as unity_file:
        unity_file.write(f"void test_{name}(void) {{ {name}(); }}\n")
308
309
def convert_json_to_tflite(json_template_fpath, json_output_fpath, tensors, params, schema_path):
    """Fill in a json model template and compile it to tflite-format with flatc.

    json_template_fpath -- pathlib path of the template to read.
    json_output_fpath   -- pathlib path of the json to write; flatc writes its output to the
                           same folder.
    tensors             -- mapping of tensor name to array; a template line containing a
                           tensor name is replaced by the tensor's bytes.
    params              -- mapping of placeholder name to value for the template.
    schema_path         -- schema file handed to flatc.

    Prints the command and exits the process with status 1 when flatc returns a non-zero
    code. Raises RuntimeError when running flatc itself fails.
    """
    # The template is processed line by line, which keeps string searching/copying low
    json_output_fpath.parent.mkdir(parents=True, exist_ok=True)
    with json_template_fpath.open("r") as template, json_output_fpath.open("w+") as output:
        for line in template:
            tokens = line.replace(",", "").split()
            handled = False

            # Only the first parameter that appears as a whole token on the line is replaced
            hit = next(((key, val) for key, val in params.items() if key in tokens), None)
            if hit is not None:
                placeholder, val = hit
                if isinstance(val, bool):
                    val = "true" if val else "false"
                # To be able to handle cases like "variable_name" : variable_name
                # only the last occurrence on the line is replaced
                output.write(str(val).join(line.rsplit(placeholder, 1)))
                handled = True

            for tensor_name in tensors:
                if tensor_name not in line:
                    continue

                dtype = Lib.op_utils.get_dtype(tensor_name, params)
                dtype_len = Lib.op_utils.get_dtype_len(dtype)
                np_dtype = Lib.op_utils.get_np_dtype(dtype)

                # Tensors are stored byte-wise in schema, each value in little endian order
                tensor_bytes = [
                    byte for weight in tensors[tensor_name].flatten()
                    for byte in int(np_dtype(weight)).to_bytes(dtype_len, 'little')
                ]

                # Every byte gets its own line, all but the last one followed by a comma
                output.writelines(f"        {byte},\n" for byte in tensor_bytes[:-1])
                output.write(f"        {tensor_bytes[-1]}\n")

                handled = True
                break

            if not handled:
                output.write(line)

    # Generate tflite from json
    command = f"flatc  -o {json_output_fpath.parent} -c -b {schema_path} {json_output_fpath}"
    command_list = command.split()
    try:
        flatc_failed = subprocess.run(command_list, env={'PATH': os.getenv('PATH')}).returncode != 0
        if flatc_failed:
            print(f"ERROR: {command = }")
            sys.exit(1)
    except Exception as e:
        raise RuntimeError(f"{e} from: {command = }. Did you install flatc?")
366
367
def quantize_scale(scale):
    """Split a floating point scale into a Q31 multiplier and a power-of-two shift.

    The scale is decomposed with math.frexp into significand * 2**shift and the significand
    is rounded to a Q31 fixed point value, so that scale ~= multiplier * 2**(shift - 31).

    Returns the tuple (multiplier, shift).
    """
    significand, shift = math.frexp(scale)
    significand_q31 = round(significand * (1 << 31))

    # A significand just below 1.0 rounds up to 2**31, one past the largest Q31 value.
    # Represent it as 2**30 with the exponent increased by one instead.
    if significand_q31 == (1 << 31):
        significand_q31 >>= 1
        shift += 1

    return significand_q31, shift
372
373
def get_header(generator, interpreter):
    """Build the comment header that is written at the top of every generated file.

    The first line names the generating script and the tool versions of the tflite generator
    ("keras" or "json"), the second line names the interpreter ("tensorflow",
    "tflite_runtime" or "tflite_micro") and its version.

    For the json generator the flatc version is queried by running flatc: the process exits
    with status 1 when flatc returns a non-zero code, and RuntimeError is raised when flatc
    cannot be run.

    Raises ValueError for an unknown generator or interpreter.
    """
    header = f"// Generated by {os.path.basename(sys.argv[0])}"
    if generator == "keras":
        header += f" using tensorflow version {tf.__version__} (Keras version {keras.__version__}).\n"
    elif generator == "json":
        command = "flatc  --version"
        command_list = command.split()
        try:
            process = subprocess.run(command_list,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     env={'PATH': os.getenv('PATH')})
        except Exception as e:
            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?") from e
        if process.returncode != 0:
            print(f"ERROR: {command = }")
            sys.exit(1)
        # Decode the captured bytes instead of slicing their repr, so the result does not
        # depend on the exact line ending of the flatc output
        flatc_version = process.stdout.decode(errors="replace").strip()
        header += f" using {flatc_version}\n"
    else:
        raise ValueError(f"Invalid tflite generator '{generator}'")

    if interpreter == "tensorflow":
        version = tf.__version__
        revision = tf.__git_version__
        header += f"// Interpreter from tensorflow version {version} and revision {revision}.\n"
    elif interpreter == "tflite_runtime":
        # Imported here since the package is optional
        import tflite_runtime as tfl_runtime

        version = tfl_runtime.__version__
        revision = tfl_runtime.__git_version__
        header += f"// Interpreter from tflite_runtime version {version} and revision {revision}.\n"
    elif interpreter == "tflite_micro":
        version = tflite_micro.__version__
        header += f"// Interpreter from tflite_micro runtime version {version}.\n"
    else:
        raise ValueError(f"Invalid interpreter '{interpreter}'")

    return header
413