1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import math
19import numpy as np
20import tensorflow as tf
21
22
class Generated_data():
    """Container for all data generated for one operator test case."""

    def __init__(self, params, tensors, scales, effective_scales, aliases=None):
        self.params = params  # All other params which are generated rather than given in the test-plan
        self.tensors = tensors  # All tensors
        self.scales = scales  # Scales used in the tflite model
        self.effective_scales = effective_scales  # Scales used by CMSIS-NN
        # Optional, for backward compatibility with old unit tests. A fresh dict is created per
        # instance so that instances never share (and leak mutations through) a default dict.
        self.aliases = {} if aliases is None else aliases
32
33
class Op_type():
    """Operator-type interface used by the test generator.

    Every hook is a static method. get_shapes, generate_keras_model, generate_data_tflite and
    generate_data_json raise NotImplementedError here and must be overridden by a concrete operator.
    post_model_update is optional; the default implementation returns its input unchanged.
    """

    @staticmethod
    def get_shapes(args):
        """Return a struct of all shapes used by the operator."""
        raise NotImplementedError

    @staticmethod
    def generate_keras_model(output_path, shapes, params):
        """Return a non-quantized tflite model built with the given shapes and params."""
        raise NotImplementedError

    @staticmethod
    def generate_data_tflite(tflite_path, params) -> Generated_data:
        """
        Parse quantized tensors, scales and other parameters from the given tflite file, calculate the
        effective scales, and return everything bundled in a Generated_data object.
        """
        raise NotImplementedError

    @staticmethod
    def generate_data_json(shapes, params) -> Generated_data:
        """
        Generate quantized tensors, scales and other parameters from the given shapes and params,
        calculate the effective scales, and return everything bundled in a Generated_data object.
        """
        raise NotImplementedError

    @staticmethod
    def post_model_update(tflite_path, generated_data, params) -> Generated_data:
        """
        Optional hook for updating parameters after the model has been created.

        The default implementation returns generated_data unchanged.
        """
        return generated_data
69
70
def generate_tf_tensor(dims, minval, maxval, decimals=0, datatype=tf.float32):
    """Build a TensorFlow tensor of random values.

    Values are drawn from NumPy's global random generator, uniformly in [minval, maxval), and then
    rounded to `decimals` decimal places (so a rounded value may reach maxval).

    Args:
        dims: Sequence of dimension sizes, unpacked into np.random.rand.
        minval: Lower bound of the sampling range.
        maxval: Upper bound of the sampling range.
        decimals: Number of decimals kept by np.round (0 gives integral values).
        datatype: dtype passed to tf.convert_to_tensor.

    Returns:
        The tensor produced by tf.convert_to_tensor.
    """
    span = maxval - minval
    rounded = np.round(minval + span * np.random.rand(*dims), decimals=decimals)
    return tf.convert_to_tensor(rounded, dtype=datatype)
77
78
def get_dtype(name, params):
    """Deduce the C data type of a tensor from substrings of its name.

    Checks are done in priority order: "bias", then "weight"/"kernel", then "multiplier"/"shift",
    then "input"/"output"/"transpose". int4_t weights are reported as "int8_t".

    Raises:
        Exception: If no known substring occurs in `name`.
    """
    if "bias" in name:
        return params["bias_data_type"]

    if any(token in name for token in ("weight", "kernel")):
        weights_type = params["weights_data_type"]
        return "int8_t" if weights_type == "int4_t" else weights_type

    if any(token in name for token in ("multiplier", "shift")):
        return params["shift_and_mult_data_type"]

    if any(token in name for token in ("input", "output", "transpose")):
        return params["input_data_type"]

    raise Exception(f"Unable to deduce dtype from name '{name}'")
92
93
def get_tf_dtype(dtype):
    """Map a C integer type name to the corresponding TensorFlow dtype.

    Supports "int8_t", "int16_t", "int32_t" and "int64_t". This is the same set of integer types
    accepted by get_np_dtype/get_dtype_len/get_dtype_max/get_dtype_min in this module, so e.g.
    bias data types can be resolved as well.

    Raises:
        Exception: If `dtype` is not one of the supported names.
    """
    if dtype == "int8_t":
        return tf.int8
    if dtype == "int16_t":
        return tf.int16
    if dtype == "int32_t":
        return tf.int32
    if dtype == "int64_t":
        return tf.int64
    raise Exception(f"Unrecognized dtype '{dtype}'")
101
102
def get_np_dtype(dtype):
    """Return the *unsigned* NumPy integer type with the same byte width as the C type `dtype`.

    "int4_t" shares the 8-bit type with "int8_t".
    NOTE(review): signed names deliberately map to unsigned types here; presumably callers use this
    to reinterpret raw bit patterns - confirm against callers before changing.

    Raises:
        Exception: If `dtype` is not a recognized integer type name.
    """
    if dtype in ("int8_t", "int4_t"):
        return np.uint8
    if dtype == "int16_t":
        return np.uint16
    if dtype == "int32_t":
        return np.uint32
    if dtype == "int64_t":
        return np.uint64
    raise Exception(f"Unrecognized dtype '{dtype}'")
114
115
def get_dtype_len(dtype):
    """Return the size in bytes reported for the C type name `dtype`.

    "int4_t" reports 1 byte, the same as "int8_t".

    Raises:
        Exception: If `dtype` is not a recognized type name.
    """
    size_table = (
        (("int8_t", "int4_t"), 1),
        (("int16_t",), 2),
        (("int32_t", "float"), 4),
        (("int64_t", "double"), 8),
    )
    for type_names, num_bytes in size_table:
        if dtype in type_names:
            return num_bytes
    raise Exception(f"Unrecognized dtype '{dtype}'")
127
128
def get_dtype_max(dtype):
    """Return the largest value of the signed two's-complement integer type `dtype`.

    Raises:
        Exception: If `dtype` is not one of int4_t/int8_t/int16_t/int32_t/int64_t.
    """
    for type_name, bit_width in (("int4_t", 4), ("int8_t", 8), ("int16_t", 16), ("int32_t", 32), ("int64_t", 64)):
        if dtype == type_name:
            # Largest signed value: 2**(bits - 1) - 1
            return (1 << (bit_width - 1)) - 1
    raise Exception(f"Unrecognized dtype '{dtype}'")
142
143
def get_dtype_min(dtype):
    """Return the smallest value of the signed two's-complement integer type `dtype`.

    Raises:
        Exception: If `dtype` is not one of int4_t/int8_t/int16_t/int32_t/int64_t.
    """
    for type_name, bit_width in (("int4_t", 4), ("int8_t", 8), ("int16_t", 16), ("int32_t", 32), ("int64_t", 64)):
        if dtype == type_name:
            # Smallest signed value: -2**(bits - 1)
            return -(1 << (bit_width - 1))
    raise Exception(f"Unrecognized dtype '{dtype}'")
157
def generate_quantize_per_channel_multiplier(params, scales):
    """Quantize the per-channel effective output scales into (multiplier, shift) pairs.

    For each output channel i the effective scale
    ``input_scale * scaling_factors[i] / output_scale`` is decomposed as
    ``multiplier * 2**(shift - 31)``, where ``multiplier`` is a Q31 fixed-point value. The steps
    mirror TFLite's QuantizeMultiplier so the generated parameters agree with the reference model.

    Args:
        params: Mapping with at least "out_ch", the number of output channels.
        scales: Mapping with "input_scale", "output_scale" and "scaling_factors"
            (one scaling factor per output channel).

    Returns:
        Tuple (per_channel_multiplier, per_channel_shift) of two lists of ints.

    Raises:
        RuntimeError: If the number of scaling factors differs from params["out_ch"].
    """
    def quantize_scale(scale):
        significand, shift = math.frexp(scale)
        scaled = significand * (1 << 31)
        # Round half away from zero (like std::round in TFLite); Python's round() rounds ties to even.
        significand_q31 = math.floor(abs(scaled) + 0.5)
        if scaled < 0:
            significand_q31 = -significand_q31
        # A significand just below 1.0 can round up to 2**31, which does not fit in int32.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            shift += 1
        # Shifts below -31 would shift out every bit; flush such tiny scales to zero like TFLite.
        if shift < -31:
            return 0, 0
        return significand_q31, shift

    num_channels = params["out_ch"]
    scaling_factors = scales["scaling_factors"]

    if len(scaling_factors) != num_channels:
        raise RuntimeError("Missing scaling factors")

    per_channel_multiplier = []
    per_channel_shift = []
    for scaling_factor in scaling_factors:
        effective_output_scale = scales["input_scale"] * scaling_factor / scales["output_scale"]
        quantized_multiplier, shift = quantize_scale(effective_output_scale)
        per_channel_multiplier.append(quantized_multiplier)
        per_channel_shift.append(shift)

    return per_channel_multiplier, per_channel_shift
179