# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import tensorflow as tf


class Generated_data():
    """Container for all generated data.

    This is the object returned by the ``Op_type.generate_data_*`` methods.
    """

    def __init__(self, params, tensors, scales, effective_scales, aliases=None):
        self.params = params  # All other params which are generated rather than given in the test-plan
        self.tensors = tensors  # All tensors
        self.scales = scales  # Scales used in the tflite model
        self.effective_scales = effective_scales  # Scales used by CMSIS-NN
        # Optional, for backward compatibility with old unit tests.
        # A fresh dict is created per instance: a literal ``{}`` default would be a
        # single dict shared by every instance constructed without ``aliases``.
        self.aliases = {} if aliases is None else aliases


class Op_type():
    """op_type interface.

    Operator-specific subclasses implement these static methods; the base
    versions only raise NotImplementedError.
    """

    @staticmethod
    def get_shapes(args):
        """Return a struct of all shapes used by the operator."""
        raise NotImplementedError

    @staticmethod
    def generate_keras_model(output_path, shapes, params):
        """Return a non-quantized tflite-model with the given shapes and params."""
        raise NotImplementedError

    @staticmethod
    def generate_data_tflite(tflite_path, params) -> Generated_data:
        """
        Parse quantized tensors, scales, and other parameters from the given tflite-file, calculate effective
        scales, and return them bundled in a Generated_data object.
        """
        raise NotImplementedError

    @staticmethod
    def generate_data_json(shapes, params) -> Generated_data:
        """
        Generate quantized tensors, scales, and other parameters with the given shapes and params, calculate
        effective scales, and return them bundled in a Generated_data object.
        """
        raise NotImplementedError


def generate_tf_tensor(dims, minval, maxval, decimals=0, datatype=tf.float32):
    """Return a TensorFlow tensor of shape ``dims`` filled with random values.

    Values are drawn uniformly from [minval, maxval) using numpy's global random
    state, so reproducibility depends on the caller seeding ``np.random``. They are
    then rounded to ``decimals`` decimal places (so a rounded value may equal
    ``maxval``) and converted to ``datatype``.
    """
    array = minval + (maxval - minval) * np.random.rand(*dims)
    array = np.round(array, decimals=decimals)
    return tf.convert_to_tensor(array, dtype=datatype)


def get_dtype(name, params):
    """Deduce a tensor's C data type name from keywords contained in ``name``.

    The keyword checks run in a fixed order (bias, weight/kernel, multiplier/shift,
    input/output), so a name containing several keywords resolves to the first
    match. Output tensors share ``params["input_data_type"]``.

    Raises:
        Exception: if ``name`` contains none of the recognised keywords.
    """
    if "bias" in name:
        return params["bias_data_type"]
    elif "weight" in name or "kernel" in name:
        return params["weights_data_type"]
    elif "multiplier" in name or "shift" in name:
        return params["shift_and_mult_data_type"]
    elif "input" in name or "output" in name:
        return params["input_data_type"]
    else:
        raise Exception(f"Unable to deduce dtype from name '{name}'")


def get_tf_dtype(dtype):
    """Map a C integer type name ("int8_t" / "int16_t") to the TensorFlow dtype.

    Raises:
        Exception: for any other type name.
    """
    if dtype == "int8_t":
        return tf.int8
    elif dtype == "int16_t":
        return tf.int16
    else:
        raise Exception(f"Unrecognized dtype '{dtype}'")


def get_np_dtype(dtype):
    """Map a signed C integer type name to the *unsigned* numpy type of equal width.

    NOTE(review): the unsigned result looks deliberate, presumably so that values
    are reinterpreted as raw two's-complement bit patterns by callers. Confirm
    against the call sites before changing.

    Raises:
        Exception: for unrecognised type names.
    """
    if dtype == "int8_t":
        return np.uint8
    elif dtype == "int16_t":
        return np.uint16
    elif dtype == "int32_t":
        return np.uint32
    elif dtype == "int64_t":
        return np.uint64
    else:
        raise Exception(f"Unrecognized dtype '{dtype}'")


def get_dtype_len(dtype):
    """Return the size in bytes of the named C type.

    Supported names are int8_t, int16_t, int32_t, int64_t, float and double.

    Raises:
        Exception: for unrecognised type names.
    """
    if dtype == "int8_t":
        return 1
    elif dtype == "int16_t":
        return 2
    elif dtype == "int32_t" or dtype == "float":
        return 4
    elif dtype == "int64_t" or dtype == "double":
        return 8
    else:
        raise Exception(f"Unrecognized dtype '{dtype}'")


def get_dtype_max(dtype):
    """Return the largest value representable by the named signed integer type.

    Supported names are int4_t, int8_t, int16_t, int32_t and int64_t.

    Raises:
        Exception: for unrecognised type names.
    """
    if dtype == "int4_t":
        return 7
    elif dtype == "int8_t":
        return 127
    elif dtype == "int16_t":
        return 32767
    elif dtype == "int32_t":
        return 2147483647
    elif dtype == "int64_t":
        return 9223372036854775807
    else:
        raise Exception(f"Unrecognized dtype '{dtype}'")
def get_dtype_min(dtype):
    """Return the smallest value representable by the named signed integer type.

    Recognised names: int4_t, int8_t, int16_t, int32_t, int64_t. The minimum of an
    N-bit two's-complement integer is -(2 ** (N - 1)).

    Raises:
        Exception: for any other type name.
    """
    known_widths = (
        ("int4_t", 4),
        ("int8_t", 8),
        ("int16_t", 16),
        ("int32_t", 32),
        ("int64_t", 64),
    )
    for type_name, bit_count in known_widths:
        if dtype == type_name:
            return -(1 << (bit_count - 1))
    raise Exception(f"Unrecognized dtype '{dtype}'")