# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
import numpy as np
import tensorflow as tf


class Generated_data():
    """ Container for all generated data """

    def __init__(self, params, tensors, scales, effective_scales, aliases=None):
        self.params = params  # All other params which are generated rather than given in the test-plan
        self.tensors = tensors  # All tensors
        self.scales = scales  # Scales used in the tflite model
        self.effective_scales = effective_scales  # Scales used by CMSIS-NN
        # Optional, for backward compatibility with old unit tests. A fresh dict is created per
        # instance when none is given, so instances never share (and mutate) one default dict.
        self.aliases = {} if aliases is None else aliases


class Op_type():
    """ op_type interface """

    @staticmethod
    def get_shapes(args):
        """ Returns a struct of all shapes used by the operator """
        raise NotImplementedError

    @staticmethod
    def generate_keras_model(output_path, shapes, params):
        """ Returns a non-quantized tflite-model with given shapes and params """
        raise NotImplementedError

    @staticmethod
    def generate_data_tflite(tflite_path, params) -> Generated_data:
        """
        Parses quantized tensors, scales, and other parameters from the given tflite-file, calculates effective
        scales, and returns them as a Generated_data object
        """
        raise NotImplementedError

    @staticmethod
    def generate_data_json(shapes, params) -> Generated_data:
        """
        Generates quantized tensors, scales, and other parameters with given shapes and params, calculates effective
        scales, and returns them as a Generated_data object
        """
        raise NotImplementedError

    @staticmethod
    def post_model_update(tflite_path, generated_data, params) -> Generated_data:
        """
        Optional function for updating parameters after model has been created.

        The default implementation returns generated_data unchanged.
        """
        return generated_data


def generate_tf_tensor(dims, minval, maxval, decimals=0, datatype=tf.float32):
    """
    Returns a random tensor of shape dims.

    Values are drawn with np.random.rand, scaled to the range minval..maxval, rounded to the given
    number of decimals and finally converted to a tensorflow tensor of type datatype.
    """
    array = minval + (maxval - minval) * np.random.rand(*dims)
    array = np.round(array, decimals=decimals)
    tensor = tf.convert_to_tensor(array, dtype=datatype)

    return tensor


def get_dtype(name, params):
    """
    Deduces the data type string of a tensor from substrings of its name.

    The checks are done in a fixed order (bias, weight/kernel, multiplier/shift, input/output/transpose) and
    the first match decides which entry of params is returned. Raises Exception if nothing matches.
    """
    if "bias" in name:
        return params["bias_data_type"]
    if "weight" in name or "kernel" in name:
        weights_dtype = params["weights_data_type"]
        # int4_t weights are reported as int8_t
        # NOTE(review): presumably because 4-bit weights are stored in 8-bit containers - confirm
        return "int8_t" if weights_dtype == "int4_t" else weights_dtype
    if "multiplier" in name or "shift" in name:
        return params["shift_and_mult_data_type"]
    if "input" in name or "output" in name or "transpose" in name:
        return params["input_data_type"]

    raise Exception(f"Unable to deduce dtype from name '{name}'")


# Lookup tables keyed on the C type names used in the test plans
_TF_DTYPES = {
    "int8_t": tf.int8,
    "int16_t": tf.int16,
}

# Note that the unsigned numpy type of matching width is used for each signed C type
_NP_DTYPES = {
    "int4_t": np.uint8,
    "int8_t": np.uint8,
    "int16_t": np.uint16,
    "int32_t": np.uint32,
    "int64_t": np.uint64,
}

# Size in bytes
_DTYPE_LENGTHS = {
    "int4_t": 1,
    "int8_t": 1,
    "int16_t": 2,
    "int32_t": 4,
    "float": 4,
    "int64_t": 8,
    "double": 8,
}

# (min, max) representable values
_DTYPE_LIMITS = {
    "int4_t": (-8, 7),
    "int8_t": (-128, 127),
    "int16_t": (-32768, 32767),
    "int32_t": (-2147483648, 2147483647),
    "int64_t": (-9223372036854775808, 9223372036854775807),
}


def _lookup_dtype(table, dtype):
    """ Returns table[dtype], raising Exception for anything that is not a key of the table """
    try:
        return table[dtype]
    except (KeyError, TypeError):
        raise Exception(f"Unrecognized dtype '{dtype}'") from None


def get_tf_dtype(dtype):
    """ Returns the tensorflow dtype for "int8_t" or "int16_t", raises Exception for any other dtype """
    return _lookup_dtype(_TF_DTYPES, dtype)


def get_np_dtype(dtype):
    """
    Returns the unsigned numpy dtype with the same width as the given integer dtype ("int4_t" maps to np.uint8).
    Raises Exception for unrecognized dtypes.
    """
    return _lookup_dtype(_NP_DTYPES, dtype)


def get_dtype_len(dtype):
    """ Returns the size in bytes of the given dtype, raises Exception for unrecognized dtypes """
    return _lookup_dtype(_DTYPE_LENGTHS, dtype)


def get_dtype_max(dtype):
    """ Returns the largest value of the given integer dtype, raises Exception for unrecognized dtypes """
    return _lookup_dtype(_DTYPE_LIMITS, dtype)[1]


def get_dtype_min(dtype):
    """ Returns the smallest value of the given integer dtype, raises Exception for unrecognized dtypes """
    return _lookup_dtype(_DTYPE_LIMITS, dtype)[0]


def generate_quantize_per_channel_multiplier(params, scales):
    """
    Calculates a Q31 multiplier and a shift for every output channel.

    For each channel the effective scale is input_scale * scaling_factors[channel] / output_scale, which is
    split with math.frexp into a significand (stored as a Q31 fixed point number) and a power-of-two shift.

    params must contain "out_ch" and scales must contain "input_scale", "output_scale" and "scaling_factors".
    Returns the two lists (per_channel_multiplier, per_channel_shift).
    Raises RuntimeError if the number of scaling factors differs from the number of output channels.
    """
    def quantize_scale(scale):
        significand, shift = math.frexp(scale)
        significand_q31 = round(significand * (1 << 31))

        # A significand just below 1.0 rounds up to 2^31, which does not fit in a signed 32-bit multiplier.
        # Renormalize by halving the multiplier and compensating in the shift.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            shift += 1

        return significand_q31, shift

    num_channels = params["out_ch"]
    scaling_factors = scales["scaling_factors"]

    if len(scaling_factors) != num_channels:
        raise RuntimeError("Missing scaling factors")

    per_channel_multiplier = []
    per_channel_shift = []

    for scaling_factor in scaling_factors:
        effective_output_scale = scales["input_scale"] * scaling_factor / scales["output_scale"]
        quantized_multiplier, shift = quantize_scale(effective_output_scale)

        per_channel_multiplier.append(quantized_multiplier)
        per_channel_shift.append(shift)

    return per_channel_multiplier, per_channel_shift