#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import json
import argparse
import subprocess

import numpy as np
import tensorflow as tf

from conv_settings import ConvSettings
from softmax_settings import SoftmaxSettings
from fully_connected_settings import FullyConnectedSettings


class MODEL_EXTRACTOR(SoftmaxSettings, FullyConnectedSettings, ConvSettings):
    """Extract per-operator reference data from a quantized TFLite model.

    For every supported operator in the model's first subgraph this emits C arrays
    (weights, bias, per-channel output multipliers/shifts) and a config header with
    the operator's dimensions and quantization parameters. It also emits the model
    input and the reference output, so CMSIS-NN can process a real model directly.
    """

    def __init__(self, dataset, schema_file, tflite_model):
        # The positional flags are interpreted by the settings base classes (not visible
        # here); presumably schema_file ends up as self.schema_file, which generate_data() reads.
        super().__init__(dataset, None, True, True, True, schema_file)

        self.tflite_model = tflite_model

        (self.quantized_multiplier, self.quantized_shift) = 0, 0
        self.is_int16xint8 = False  # Only 8-bit supported.
        self.diff_min, self.input_multiplier, self.input_left_shift = 0, 0, 0

        self.supported_ops = ["CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED", "AVERAGE_POOL_2D", "SOFTMAX"]

    def from_bytes(self, tensor_data, type_size) -> list:
        """Decode little-endian signed integers from a list of byte values.

        Args:
            tensor_data: Iterable of byte values (0..255), e.g. a flatc JSON buffer's ``data``.
            type_size: Width in bytes of each integer; must be 1, 2 or 4.

        Returns:
            List of Python ints. Trailing bytes that do not form a complete
            ``type_size`` group are ignored.

        Raises:
            RuntimeError: If ``type_size`` is not 1, 2 or 4.
            ValueError: If a value in ``tensor_data`` is outside 0..255.
        """
        if type_size not in (1, 2, 4):
            raise RuntimeError("Size not supported: {}".format(type_size))

        raw = bytes(tensor_data)
        usable = len(raw) - len(raw) % type_size
        return [int.from_bytes(raw[i:i + type_size], 'little', signed=True) for i in range(0, usable, type_size)]

    def tflite_to_json(self, tflite_input, schema):
        """Convert a .tflite flatbuffer to JSON with ``flatc`` and return the JSON path.

        The JSON file is written next to ``tflite_input`` with a ``.json`` extension.

        Raises:
            RuntimeError: If no schema is given or flatc cannot be launched.
            SystemExit: If flatc runs but returns a non-zero exit code.
        """
        if schema is None:
            raise RuntimeError("A schema file is required.")

        name_without_ext, _ = os.path.splitext(tflite_input)
        new_name = name_without_ext + '.json'
        # flatc writes <basename>.json into the -o directory. Use '.' when the input has
        # no directory component so an empty string is never passed as an argument.
        dirname = os.path.dirname(tflite_input) or '.'

        # Build argv directly (no shell, no str.split) so paths containing spaces survive.
        command_list = ["flatc", "-o", str(dirname), "--strict-json", "-t", str(schema), "--", str(tflite_input)]
        command = " ".join(command_list)
        try:
            process = subprocess.run(command_list)
        except Exception as e:
            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?") from e
        if process.returncode != 0:
            print(f"ERROR: {command = }")
            sys.exit(1)

        return new_name

    def write_c_config_header(self, name_prefix, op_name, op_index) -> None:
        """Write ``<name_prefix>_config_data.h`` with #defines describing one operator.

        Every define is named ``<op_name>_<op_index>_<FIELD>``. SOFTMAX gets row and
        quantization parameters. All other operators get filter/stride/padding/dilation
        geometry, offsets and activation range. FULLY_CONNECTED and DEPTHWISE_CONV_2D
        get extra fields.
        """
        filename = f"{name_prefix}_config_data.h"

        self.generated_header_files.append(filename)
        filepath = self.headers_dir + filename

        prefix = f'{op_name}_{op_index}'

        print("Writing C header with config data {}...".format(filepath))

        defines = [
            ("OUT_CH", self.output_ch),
            ("IN_CH", self.input_ch),
            ("INPUT_W", self.x_input),
            ("INPUT_H", self.y_input),
            ("DST_SIZE", self.x_output * self.y_output * self.output_ch * self.batches),
        ]
        if op_name == "SOFTMAX":
            defines += [
                ("NUM_ROWS", self.y_input),
                ("ROW_SIZE", self.x_input),
                ("MULT", self.input_multiplier),
                ("SHIFT", self.input_left_shift),
            ]
            if not self.is_int16xint8:
                defines.append(("DIFF_MIN", -self.diff_min))
        else:
            # The *_W/*_H aliases duplicate the *_X/*_Y values under a second naming scheme.
            defines += [
                ("FILTER_X", self.filter_x),
                ("FILTER_Y", self.filter_y),
                ("FILTER_W", self.filter_x),
                ("FILTER_H", self.filter_y),
                ("STRIDE_X", self.stride_x),
                ("STRIDE_Y", self.stride_y),
                ("STRIDE_W", self.stride_x),
                ("STRIDE_H", self.stride_y),
                ("PAD_X", self.pad_x),
                ("PAD_Y", self.pad_y),
                ("PAD_W", self.pad_x),
                ("PAD_H", self.pad_y),
                ("OUTPUT_W", self.x_output),
                ("OUTPUT_H", self.y_output),
                ("INPUT_OFFSET", -self.input_zero_point),
                ("INPUT_SIZE", self.x_input * self.y_input * self.input_ch),
                ("OUT_ACTIVATION_MIN", self.out_activation_min),
                ("OUT_ACTIVATION_MAX", self.out_activation_max),
                ("INPUT_BATCHES", self.batches),
                ("OUTPUT_OFFSET", self.output_zero_point),
                ("DILATION_X", self.dilation_x),
                ("DILATION_Y", self.dilation_y),
                ("DILATION_W", self.dilation_x),
                ("DILATION_H", self.dilation_y),
            ]
            if op_name == "FULLY_CONNECTED":
                defines += [
                    ("OUTPUT_MULTIPLIER", self.quantized_multiplier),
                    ("OUTPUT_SHIFT", self.quantized_shift),
                ]
            if op_name == "DEPTHWISE_CONV_2D":
                defines.append(("ACCUMULATION_DEPTH", self.input_ch * self.x_input * self.y_input))

        with open(filepath, "w+") as f:
            self.write_c_common_header(f)
            for field, value in defines:
                f.write("#define {}_{} {}\n".format(prefix, field, value))

        # Format only after the file is closed so the formatter sees fully flushed content.
        self.format_output_file(filepath)

    def shape_to_config(self, input_shape, filter_shape, output_shape, layer_name):
        """Copy operator tensor shapes into the geometry attributes and compute padding.

        Input/filter shapes are interpreted per ``layer_name``. A 2-D or 4-D output shape
        is unpacked only when the input shape is 4-D (or, for FULLY_CONNECTED, always).
        Finally ``calculate_padding`` is called with the resulting output/input sizes.
        """
        if layer_name == "AVERAGE_POOL_2D":
            # Pooling window is taken as the full spatial input (global pooling).
            # NOTE(review): batches/x_input/y_input/input_ch are not updated here and keep
            # the previous operator's values — confirm this is intended.
            [_, self.filter_y, self.filter_x, _] = input_shape

        elif layer_name == "CONV_2D" or layer_name == "DEPTHWISE_CONV_2D":
            [self.batches, self.y_input, self.x_input, self.input_ch] = input_shape
            [_, self.filter_y, self.filter_x, self.input_ch] = filter_shape

        elif layer_name == "FULLY_CONNECTED":
            [self.batches, self.input_ch] = input_shape
            # NOTE(review): TFLite documents FULLY_CONNECTED weights as [num_units, input_depth],
            # the opposite order of this unpacking. Left as-is because C test cases reading
            # *_IN_CH / *_OUT_CH may depend on it — verify.
            [self.input_ch, self.output_ch] = filter_shape
            [self.y_output, self.x_output] = output_shape
            self.x_input = 1
            self.y_input = 1

        elif layer_name == "SOFTMAX":
            [self.y_input, self.x_input] = input_shape

        if len(input_shape) == 4:
            if len(output_shape) == 2:
                [self.y_output, self.x_output] = output_shape
            else:
                [_, self.y_output, self.x_output, _] = output_shape

        self.calculate_padding(self.x_output, self.y_output, self.x_input, self.y_input)

    def extract_from_model(self, json_file, tensor_details):
        """Generate C data and config headers for each supported operator in subgraph 0.

        Args:
            json_file: Path to the flatc JSON dump of the model (tensor/buffer layout).
            tensor_details: ``Interpreter.get_tensor_details()`` output, indexed by tensor
                index. It supplies scales/zero points at full precision.

        Raises:
            RuntimeError: If a weighted operator has no bias tensor.
        """
        with open(json_file, 'r') as in_file:
            data = json.load(in_file)

        subgraph = data['subgraphs'][0]
        tensors = subgraph['tensors']
        operators = subgraph['operators']
        operator_codes = data['operator_codes']
        buffers = data['buffers']

        weighted_ops = ("CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED")

        for op_index, op in enumerate(operators):
            # flatc omits fields whose value equals the schema default: a missing
            # 'opcode_index' means 0 and a missing 'builtin_code' means ADD.
            opcode = operator_codes[op.get('opcode_index', 0)]
            builtin_name = opcode.get('builtin_code', 'ADD')

            # Get stride and padding. Values persist from earlier operators when absent.
            if 'builtin_options' in op:
                builtin_options = op['builtin_options']
                if 'stride_w' in builtin_options:
                    self.stride_x = builtin_options['stride_w']
                if 'stride_h' in builtin_options:
                    self.stride_y = builtin_options['stride_h']
                # SAME is the schema default, so flatc emits the key only for VALID
                # (unless defaults are forced into the JSON).
                self.padding = builtin_options.get('padding', 'SAME')
                self.has_padding = self.padding == 'SAME'

            # Generate weights, bias, multipliers, shifts and config.
            if builtin_name not in self.supported_ops:
                print(f"WARNING: skipping unsupported operator {builtin_name}")
                continue

            input_tensor = tensor_details[op['inputs'][0]]
            output_tensor = tensor_details[op['outputs'][0]]
            input_scale = input_tensor['quantization'][0]
            output_scale = output_tensor['quantization'][0]
            self.input_zero_point = input_tensor['quantization'][1]
            self.output_zero_point = output_tensor['quantization'][1]

            input_shape = input_tensor['shape']
            output_shape = output_tensor['shape']

            if builtin_name in weighted_ops:
                weights_index = op['inputs'][1]
                bias_index = op['inputs'][2]
                # TFLite marks an omitted optional input with -1; indexing with it would
                # silently pick the last tensor in the subgraph as the bias.
                if bias_index < 0:
                    raise RuntimeError(f"{builtin_name} without bias is not supported.")

                weight_tensor = tensor_details[weights_index]
                scaling_factors = weight_tensor['quantization_parameters']['scales'].tolist()

                weights = tensors[weights_index]
                bias = tensors[bias_index]

                weights_data = self.from_bytes(buffers[weights['buffer']]['data'], 1)
                bias_data = self.from_bytes(buffers[bias['buffer']]['data'], 4)

                self.output_ch = len(scaling_factors)

                filter_shape = weights['shape']
            else:
                filter_shape = []

            self.input_scale, self.output_scale = input_scale, output_scale

            if builtin_name == "SOFTMAX":
                self.calc_softmax_params()

            self.shape_to_config(input_shape, filter_shape, output_shape, builtin_name)

            nice_name = f"layer_{op_index}_{builtin_name.lower()}"

            if builtin_name in weighted_ops:
                self.generate_c_array(nice_name + "_weights", weights_data)
                self.generate_c_array(nice_name + "_bias", bias_data, datatype='int32_t')

                if builtin_name == "FULLY_CONNECTED":
                    # Per-tensor quantization: only the first weight scale is used.
                    self.weights_scale = scaling_factors[0]
                    self.quantize_multiplier()
                else:
                    self.scaling_factors = scaling_factors
                    per_channel_multiplier, per_channel_shift = self.generate_quantize_per_channel_multiplier()

                    self.generate_c_array(f"{nice_name}_output_mult", per_channel_multiplier, datatype='int32_t')
                    self.generate_c_array(f"{nice_name}_output_shift", per_channel_shift, datatype='int32_t')

            self.write_c_config_header(nice_name, builtin_name, op_index)

    def generate_data(self, input_data=None, weights=None, biases=None) -> None:
        """Run the model with the reference kernels and emit all test data.

        Args:
            input_data: Optional input forwarded to ``get_randomized_input_data``.
            weights: Unused here (presumably mirrors the settings base-class signature).
            biases: Unused here (presumably mirrors the settings base-class signature).

        Raises:
            RuntimeError: If the model has more than one input.
        """
        interpreter = self.Interpreter(model_path=str(self.tflite_model),
                                       experimental_op_resolver_type=self.OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Needed for input/output scale/zp as equivalent json file data has too low precision.
        tensor_details = interpreter.get_tensor_details()

        output_details = interpreter.get_output_details()
        (self.output_scale, self.output_zero_point) = output_details[0]['quantization']

        input_details = interpreter.get_input_details()
        if len(input_details) != 1:
            raise RuntimeError("Only single input supported.")
        input_shape = input_details[0]['shape']
        input_data = self.get_randomized_input_data(input_data, input_shape)
        interpreter.set_tensor(input_details[0]["index"], tf.cast(input_data, tf.int8))

        self.generate_c_array("input", input_data)

        json_file = self.tflite_to_json(self.tflite_model, self.schema_file)
        self.extract_from_model(json_file, tensor_details)

        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]["index"])
        self.generate_c_array("output_ref", np.clip(output_data, self.out_activation_min, self.out_activation_max))

        self.write_c_header_wrapper()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Extract operator data from given model if operator is supported. "
                                     "This provides a way for CMSIS-NN to directly process a model.")
    parser.add_argument('--schema-file', type=str, required=True, help="Path to schema file.")
    parser.add_argument('--tflite-model', type=str, required=True, help="Path to tflite file.")
    parser.add_argument('--model-name',
                        type=str,
                        help="Descriptive model name. If left out it will be inferred from actual model.")

    args = parser.parse_args()

    schema_file = args.schema_file
    tflite_model = args.tflite_model

    if args.model_name:
        dataset = args.model_name
    else:
        dataset, _ = os.path.splitext(os.path.basename(tflite_model))

    model_extractor = MODEL_EXTRACTOR(dataset, schema_file, tflite_model)
    model_extractor.generate_data()