1#!/usr/bin/env python3
2#
3# SPDX-FileCopyrightText: Copyright 2010-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the License); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an AS IS BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19import os
20import sys
21import json
22import argparse
23import subprocess
24
25import numpy as np
26import tensorflow as tf
27
28from generate_test_data import SoftmaxSettings, FullyConnectedSettings, ConvSettings, Interpreter, OpResolverType
29
30
class MODEL_EXTRACTOR(SoftmaxSettings, FullyConnectedSettings, ConvSettings):
    """Extract per-operator test data from a TFLite model for CMSIS-NN.

    For every supported operator in the model's first subgraph this generates C arrays
    (weights, bias, per-channel output multipliers/shifts where applicable) and a C header
    with the operator's configuration #defines. It also generates the model input and the
    reference output obtained by running the model with the TFLite reference kernels.
    """

    def __init__(self, dataset, schema_file, tflite_model):
        """
        Args:
            dataset: Name of the generated data set, passed on to the base settings classes.
            schema_file: Path to the TFLite flatbuffer schema used by flatc.
            tflite_model: Path to the .tflite model to extract data from.
        """
        # NOTE(review): the positional flags below are interpreted by the base settings
        # classes (generate_test_data); their meaning is not visible from this file.
        super().__init__(dataset, None, True, True, True, schema_file)

        self.tflite_model = tflite_model

        self.quantized_multiplier, self.quantized_shift = 0, 0
        self.is_int16xint8 = False  # Only 8-bit supported.
        self.diff_min, self.input_multiplier, self.input_left_shift = 0, 0, 0

        self.supported_ops = ["CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED", "AVERAGE_POOL_2D", "SOFTMAX"]

    def from_bytes(self, tensor_data, type_size) -> list:
        """Decode a flat sequence of byte values into signed little-endian integers.

        Args:
            tensor_data: Sequence of byte values (0-255), e.g. the 'data' field of a buffer
                in the flatc JSON dump of the model.
            type_size: Size in bytes of each element: 1, 2 or 4.

        Returns:
            List of Python ints. Trailing bytes that do not form a whole element are ignored.

        Raises:
            RuntimeError: If type_size is not 1, 2 or 4.
        """
        dtypes = {1: '<i1', 2: '<i2', 4: '<i4'}
        if type_size not in dtypes:
            raise RuntimeError("Size not supported: {}".format(type_size))

        raw = bytes(tensor_data)
        # Drop an incomplete trailing element rather than failing on it.
        usable = len(raw) - len(raw) % type_size
        if usable == 0:
            return []
        return np.frombuffer(raw[:usable], dtype=dtypes[type_size]).tolist()

    def tflite_to_json(self, tflite_input, schema):
        """Convert a .tflite flatbuffer to JSON using flatc.

        The JSON file is written next to the input model, with the same base name.

        Args:
            tflite_input: Path to the .tflite model.
            schema: Path to the TFLite flatbuffer schema.

        Returns:
            Path of the generated JSON file.

        Raises:
            RuntimeError: If no schema is given or flatc could not be executed.

        Exits the process with status 1 if flatc runs but returns a non-zero status.
        """
        if schema is None:
            raise RuntimeError("A schema file is required.")

        name_without_ext, _ = os.path.splitext(tflite_input)
        new_name = name_without_ext + '.json'
        # A model given as a bare file name lives in the current directory; never pass
        # an empty output directory to flatc.
        dirname = os.path.dirname(tflite_input) or '.'

        # Pass arguments as a list (no string splitting) so paths containing spaces survive.
        command_list = ["flatc", "-o", str(dirname), "--strict-json", "-t", str(schema), "--", str(tflite_input)]
        command = " ".join(command_list)
        try:
            process = subprocess.run(command_list)
        except OSError as e:
            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?") from e

        if process.returncode != 0:
            print(f"ERROR: {command = }")
            sys.exit(1)

        return new_name

    def write_c_config_header(self, name_prefix, op_name, op_index) -> None:
        """Write a C header with the configuration #defines of one operator.

        Every define is named '<op_name>_<op_index>_<SUFFIX>'. The header file is
        '<name_prefix>_config_data.h' in headers_dir and is registered in
        generated_header_files.

        Args:
            name_prefix: Base name of the header file.
            op_name: TFLite builtin operator name, e.g. "CONV_2D".
            op_index: Index of the operator in the subgraph.
        """
        filename = f"{name_prefix}_config_data.h"

        self.generated_header_files.append(filename)
        filepath = self.headers_dir + filename

        prefix = f'{op_name}_{op_index}'

        print("Writing C header with config data {}...".format(filepath))

        # Ordered (suffix, value) pairs; each becomes "#define <prefix>_<suffix> <value>".
        defines = [
            ("OUT_CH", self.output_ch),
            ("IN_CH", self.input_ch),
            ("INPUT_W", self.x_input),
            ("INPUT_H", self.y_input),
            ("DST_SIZE", self.x_output * self.y_output * self.output_ch * self.batches),
        ]
        if op_name == "SOFTMAX":
            defines += [
                ("NUM_ROWS", self.y_input),
                ("ROW_SIZE", self.x_input),
                ("MULT", self.input_multiplier),
                ("SHIFT", self.input_left_shift),
            ]
            if not self.is_int16xint8:
                defines.append(("DIFF_MIN", -self.diff_min))
        else:
            # The *_W/*_H aliases duplicate *_X/*_Y for consumers that use either naming.
            defines += [
                ("FILTER_X", self.filter_x),
                ("FILTER_Y", self.filter_y),
                ("FILTER_W", self.filter_x),
                ("FILTER_H", self.filter_y),
                ("STRIDE_X", self.stride_x),
                ("STRIDE_Y", self.stride_y),
                ("STRIDE_W", self.stride_x),
                ("STRIDE_H", self.stride_y),
                ("PAD_X", self.pad_x),
                ("PAD_Y", self.pad_y),
                ("PAD_W", self.pad_x),
                ("PAD_H", self.pad_y),
                ("OUTPUT_W", self.x_output),
                ("OUTPUT_H", self.y_output),
                ("INPUT_OFFSET", -self.input_zero_point),
                ("INPUT_SIZE", self.x_input * self.y_input * self.input_ch),
                ("OUT_ACTIVATION_MIN", self.out_activation_min),
                ("OUT_ACTIVATION_MAX", self.out_activation_max),
                ("INPUT_BATCHES", self.batches),
                ("OUTPUT_OFFSET", self.output_zero_point),
                ("DILATION_X", self.dilation_x),
                ("DILATION_Y", self.dilation_y),
                ("DILATION_W", self.dilation_x),
                ("DILATION_H", self.dilation_y),
            ]

        if op_name == "FULLY_CONNECTED":
            defines += [
                ("OUTPUT_MULTIPLIER", self.quantized_multiplier),
                ("OUTPUT_SHIFT", self.quantized_shift),
            ]

        if op_name == "DEPTHWISE_CONV_2D":
            defines.append(("ACCUMULATION_DEPTH", self.input_ch * self.x_input * self.y_input))

        with open(filepath, "w+") as f:
            self.write_c_common_header(f)
            f.writelines("#define {}_{} {}\n".format(prefix, suffix, value) for suffix, value in defines)

        self.format_output_file(filepath)

    def shape_to_config(self, input_shape, filter_shape, output_shape, layer_name):
        """Derive layer dimensions from tensor shapes, then update output dims and padding.

        Args:
            input_shape: Shape of the operator's input tensor.
            filter_shape: Shape of the weights tensor ([] for ops without weights).
            output_shape: Shape of the operator's output tensor.
            layer_name: TFLite builtin operator name.
        """
        if layer_name == "AVERAGE_POOL_2D":
            [self.batches, self.y_input, self.x_input, self.input_ch] = input_shape
            # NOTE(review): the pooling window is taken to be the whole input plane, i.e. global
            # average pooling is assumed; the filter_width/filter_height options are not read.
            self.filter_y, self.filter_x = self.y_input, self.x_input
            # Pooling does not change the channel count.
            self.output_ch = output_shape[-1]

        elif layer_name in ("CONV_2D", "DEPTHWISE_CONV_2D"):
            [self.batches, self.y_input, self.x_input, self.input_ch] = input_shape
            # NOTE(review): for DEPTHWISE_CONV_2D the TFLite filter layout is [1, H, W, out_ch], so
            # input_ch ends up as the output channel count (equal for depth multiplier 1) - confirm.
            [_, self.filter_y, self.filter_x, self.input_ch] = filter_shape

        elif layer_name == "FULLY_CONNECTED":
            [self.batches, self.input_ch] = input_shape
            # TFLite stores fully connected weights as [output_depth, input_depth].
            [self.output_ch, self.input_ch] = filter_shape
            [self.y_output, self.x_output] = output_shape
            self.x_input = 1
            self.y_input = 1

        elif layer_name == "SOFTMAX":
            [self.y_input, self.x_input] = input_shape

        if len(input_shape) == 4:
            if len(output_shape) == 2:
                [self.y_output, self.x_output] = output_shape
            else:
                [_, self.y_output, self.x_output, _] = output_shape

        self.set_output_dims_and_padding(self.x_output, self.y_output)

    def _read_bias_data(self, op, tensors, buffers, builtin_name, filter_shape) -> list:
        """Return the int32 bias values of a weighted op, or zeros if the optional bias is absent."""
        inputs = op['inputs']
        # TFLite marks an omitted optional input with index -1.
        bias_index = inputs[2] if len(inputs) > 2 else -1
        if bias_index >= 0:
            bias = tensors[bias_index]
            return self.from_bytes(buffers[bias['buffer']]['data'], 4)

        # One zero per output channel: last filter dim for depthwise, first otherwise.
        num_output_ch = filter_shape[-1] if builtin_name == "DEPTHWISE_CONV_2D" else filter_shape[0]
        return [0] * num_output_ch

    def extract_from_model(self, json_file, tensor_details):
        """Generate data arrays and config headers for each supported operator.

        Args:
            json_file: flatc JSON dump of the model (see tflite_to_json).
            tensor_details: Result of Interpreter.get_tensor_details(), indexed by tensor
                index; used for input/output quantization parameters and shapes.
        """
        with open(json_file, 'r') as in_file:
            data = json.load(in_file)
        subgraph = data['subgraphs'][0]
        tensors = subgraph['tensors']
        operators = subgraph['operators']
        operator_codes = data['operator_codes']
        buffers = data['buffers']

        for op_index, op in enumerate(operators):
            # flatc omits fields equal to their schema default, so a missing opcode_index
            # means 0 and a missing builtin_code means ADD.
            builtin_name = operator_codes[op.get('opcode_index', 0)].get('builtin_code', 'ADD')

            # Get stride and padding.
            builtin_options = op.get('builtin_options', {})
            if 'stride_w' in builtin_options:
                self.stride_x = builtin_options['stride_w']
            if 'stride_h' in builtin_options:
                self.stride_y = builtin_options['stride_h']
                # Padding defaults to SAME in the schema, so the key is only present for VALID.
                if 'padding' in builtin_options:
                    self.has_padding = False
                    self.padding = 'VALID'
                else:
                    self.has_padding = True
                    self.padding = 'SAME'

            if builtin_name not in self.supported_ops:
                print(f"WARNING: skipping unsupported operator {builtin_name}")
                continue

            input_tensor = tensor_details[op['inputs'][0]]
            output_tensor = tensor_details[op['outputs'][0]]
            # Scale/zero point come from the interpreter; the JSON values are less precise.
            self.input_scale = input_tensor['quantization'][0]
            self.output_scale = output_tensor['quantization'][0]
            self.input_zero_point = input_tensor['quantization'][1]
            self.output_zero_point = output_tensor['quantization'][1]

            input_shape = input_tensor['shape']
            output_shape = output_tensor['shape']

            has_weights = builtin_name in ("CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED")
            if has_weights:
                weights = tensors[op['inputs'][1]]
                filter_shape = weights['shape']
                weights_data = self.from_bytes(buffers[weights['buffer']]['data'], 1)
                bias_data = self._read_bias_data(op, tensors, buffers, builtin_name, filter_shape)

                scaling_factors = weights['quantization']['scale']
                self.output_ch = len(scaling_factors)
            else:
                filter_shape = []

            if builtin_name == "SOFTMAX":
                self.calc_softmax_params()

            self.shape_to_config(input_shape, filter_shape, output_shape, builtin_name)

            nice_name = f"layer_{op_index}_{builtin_name.lower()}"

            if has_weights:
                self.generate_c_array(nice_name + "_weights", weights_data)
                self.generate_c_array(nice_name + "_bias", bias_data, datatype='int32_t')

            if builtin_name == "FULLY_CONNECTED":
                # Per-tensor quantized: a single weights scale.
                self.weights_scale = scaling_factors[0]
                self.quantize_multiplier()

            elif builtin_name in ("CONV_2D", "DEPTHWISE_CONV_2D"):
                self.scaling_factors = scaling_factors
                per_channel_multiplier, per_channel_shift = self.generate_quantize_per_channel_multiplier()

                self.generate_c_array(f"{nice_name}_output_mult", per_channel_multiplier, datatype='int32_t')
                self.generate_c_array(f"{nice_name}_output_shift", per_channel_shift, datatype='int32_t')

            self.write_c_config_header(nice_name, builtin_name, op_index)

    def generate_data(self, input_data=None, weights=None, biases=None) -> None:
        """Generate the model input, per-layer data and the reference output.

        Args:
            input_data: Optional input; handed to get_randomized_input_data.
            weights: Unused here (presumably mirrors the base-class signature).
            biases: Unused here (presumably mirrors the base-class signature).

        Raises:
            RuntimeError: If the model has more than one input.
        """
        interpreter = Interpreter(model_path=str(self.tflite_model),
                                  experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Needed for input/output scale/zp as equivalent json file data has too low precision.
        tensor_details = interpreter.get_tensor_details()

        output_details = interpreter.get_output_details()
        (self.output_scale, self.output_zero_point) = output_details[0]['quantization']

        input_details = interpreter.get_input_details()
        if len(input_details) != 1:
            raise RuntimeError("Only single input supported.")
        input_shape = input_details[0]['shape']
        input_data = self.get_randomized_input_data(input_data, input_shape)
        interpreter.set_tensor(input_details[0]["index"], tf.cast(input_data, tf.int8))

        self.generate_c_array("input", input_data)

        json_file = self.tflite_to_json(self.tflite_model, self.schema_file)
        self.extract_from_model(json_file, tensor_details)

        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]["index"])
        self.generate_c_array("output_ref", np.clip(output_data, self.out_activation_min, self.out_activation_max))

        self.write_c_header_wrapper()
306
307
if __name__ == '__main__':
    # Command-line entry point: extract per-operator CMSIS-NN test data from a TFLite model.
    parser = argparse.ArgumentParser(description="Extract operator data from given model if operator is supported. "
                                     "This provides a way for CMSIS-NN to directly process a model.")
    parser.add_argument('--schema-file', type=str, required=True, help="Path to schema file.")
    parser.add_argument('--tflite-model', type=str, required=True, help="Path to tflite file.")
    parser.add_argument('--model-name',
                        type=str,
                        help="Descriptive model name. If left out it will be inferred from actual model.")

    args = parser.parse_args()

    schema_file = args.schema_file
    tflite_model = args.tflite_model

    # Fall back to the model's file name without extension when no name is given.
    dataset = args.model_name or os.path.splitext(os.path.basename(tflite_model))[0]

    model_extractor = MODEL_EXTRACTOR(dataset, schema_file, tflite_model)
    model_extractor.generate_data()
329