1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import numpy as np
18import tensorflow as tf
19
20
class Generated_data():
    """Container for all generated data of a single test case.

    Attributes:
        params: All other params which are generated rather than given in the test-plan.
        tensors: All tensors.
        scales: Scales used in the tflite model.
        effective_scales: Scales used by CMSIS-NN.
        aliases: Optional; kept for backward compatibility with old unit tests.
            Defaults to a new, empty dict for every instance.
    """

    def __init__(self, params, tensors, scales, effective_scales, aliases=None):
        self.params = params  # All other params which are generated rather than given in the test-plan
        self.tensors = tensors  # All tensors
        self.scales = scales  # Scales used in the tflite model
        self.effective_scales = effective_scales  # Scales used by CMSIS-NN
        # Optional for backward compatibility with old unit tests. A literal `{}` default
        # would be evaluated once and shared by every instance, so build a fresh dict here.
        self.aliases = {} if aliases is None else aliases
30
31
class Op_type():
    """op_type interface.

    Every hook is a static method, and the base versions all raise NotImplementedError.
    An operator type must therefore override each hook that is called on it.
    """

    @staticmethod
    def get_shapes(args):
        """Return a struct of all shapes used by the operator."""
        raise NotImplementedError

    @staticmethod
    def generate_keras_model(output_path, shapes, params):
        """Return a non-quantized tflite model built with the given shapes and params.

        NOTE(review): `output_path` is not described by this interface; presumably it is
        where the model file is written — confirm against the implementations.
        """
        raise NotImplementedError

    @staticmethod
    def generate_data_tflite(tflite_path, params) -> Generated_data:
        """
            Parse quantized tensors, scales and other parameters from the given tflite file, calculate the
            effective scales, and return all of them bundled in a Generated_data.
        """
        raise NotImplementedError

    @staticmethod
    def generate_data_json(shapes, params) -> Generated_data:
        """
           Generate quantized tensors, scales and other parameters with the given shapes and params, calculate
           the effective scales, and return all of them bundled in a Generated_data.
        """
        raise NotImplementedError
60
61
def generate_tf_tensor(dims, minval, maxval, decimals=0, datatype=tf.float32):
    """Create a TensorFlow tensor filled with uniformly distributed random values.

    Samples come from numpy's global RNG. Before rounding they lie in
    [minval, maxval). They are rounded to `decimals` decimal places and then
    converted to a tensor of type `datatype`.

    Args:
        dims: Shape of the tensor, unpacked into `np.random.rand`.
        minval: Lower bound of the sampling range.
        maxval: Upper bound of the sampling range.
        decimals: Number of decimal places kept by `np.round`.
        datatype: TensorFlow dtype of the returned tensor.
    """
    span = maxval - minval
    samples = minval + span * np.random.rand(*dims)
    rounded = np.round(samples, decimals=decimals)
    return tf.convert_to_tensor(rounded, dtype=datatype)
68
69
def get_dtype(name, params):
    """Deduce a tensor's data type from substrings of its name.

    The rules are tried in order and the first matching one wins. For example, a
    name containing both "bias" and "output" resolves to the bias data type.
    Names containing "input" or "output" both map to params["input_data_type"].

    Raises:
        Exception: if no rule matches `name`.
    """
    rules = (
        (("bias",), "bias_data_type"),
        (("weight", "kernel"), "weights_data_type"),
        (("multiplier", "shift"), "shift_and_mult_data_type"),
        (("input", "output"), "input_data_type"),
    )
    for keywords, param_key in rules:
        if any(keyword in name for keyword in keywords):
            return params[param_key]
    raise Exception(f"Unable to deduce dtype from name '{name}'")
81
82
def get_tf_dtype(dtype):
    """Map a C integer type name to the matching TensorFlow dtype.

    Only "int8_t" and "int16_t" are supported.

    Raises:
        Exception: for any other type name.
    """
    if dtype == "int16_t":
        return tf.int16
    elif dtype == "int8_t":
        return tf.int8
    raise Exception(f"Unrecognized dtype '{dtype}'")
90
91
def get_np_dtype(dtype):
    """Map a signed C integer type name to the unsigned numpy type of the same width.

    The unsigned result is deliberate in the existing code. NOTE(review): it is
    presumably used to reinterpret two's-complement bit patterns (e.g. for hex
    output) — confirm against callers before changing it.

    Raises:
        Exception: for an unsupported type name.
    """
    unsigned_by_name = (
        ("int8_t", np.uint8),
        ("int16_t", np.uint16),
        ("int32_t", np.uint32),
        ("int64_t", np.uint64),
    )
    for c_name, np_type in unsigned_by_name:
        if dtype == c_name:
            return np_type
    raise Exception(f"Unrecognized dtype '{dtype}'")
103
104
def get_dtype_len(dtype):
    """Return the size in bytes of a C scalar type given by name.

    Supports int8_t, int16_t, int32_t, int64_t, float and double.

    Raises:
        Exception: for any other type name.
    """
    byte_sizes = (
        (("int8_t",), 1),
        (("int16_t",), 2),
        (("int32_t", "float"), 4),
        (("int64_t", "double"), 8),
    )
    for type_names, size in byte_sizes:
        if dtype in type_names:
            return size
    raise Exception(f"Unrecognized dtype '{dtype}'")
116
117
def get_dtype_max(dtype):
    """Return the largest value representable by a signed integer type given by name.

    For an N-bit two's-complement type this is 2**(N-1) - 1.

    Raises:
        Exception: for an unsupported type name.
    """
    widths = (("int4_t", 4), ("int8_t", 8), ("int16_t", 16), ("int32_t", 32), ("int64_t", 64))
    for type_name, bits in widths:
        if dtype == type_name:
            return (1 << (bits - 1)) - 1
    raise Exception(f"Unrecognized dtype '{dtype}'")
131
132
def get_dtype_min(dtype):
    """Return the smallest value representable by a signed integer type given by name.

    For an N-bit two's-complement type this is -(2**(N-1)).

    Raises:
        Exception: for an unsupported type name.
    """
    widths = (("int4_t", 4), ("int8_t", 8), ("int16_t", 16), ("int32_t", 32), ("int64_t", 64))
    for type_name, bits in widths:
        if dtype == type_name:
            return -(1 << (bits - 1))
    raise Exception(f"Unrecognized dtype '{dtype}'")
146