1#!/usr/bin/env python3
2#
3# SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the License); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an AS IS BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19import os
20import sys
21import json
22import argparse
23import subprocess
24
25import numpy as np
26import tensorflow as tf
27
28from conv_settings import ConvSettings
29from softmax_settings import SoftmaxSettings
30from fully_connected_settings import FullyConnectedSettings
31
32
33class MODEL_EXTRACTOR(SoftmaxSettings, FullyConnectedSettings, ConvSettings):
34
35    def __init__(self, dataset, schema_file, tflite_model):
36
37        super().__init__(dataset, None, True, True, True, schema_file)
38
39        self.tflite_model = tflite_model
40
41        (self.quantized_multiplier, self.quantized_shift) = 0, 0
42        self.is_int16xint8 = False  # Only 8-bit supported.
43        self.diff_min, self.input_multiplier, self.input_left_shift = 0, 0, 0
44
45        self.supported_ops = ["CONV_2D", "DEPTHWISE_CONV_2D", "FULLY_CONNECTED", "AVERAGE_POOL_2D", "SOFTMAX"]
46
47    def from_bytes(self, tensor_data, type_size) -> list:
48        result = []
49        tmp_ints = []
50
51        if not (type_size == 1 or type_size == 2 or type_size == 4):
52            raise RuntimeError("Size not supported: {}".format(type_size))
53
54        count = 0
55        for val in tensor_data:
56            tmp_ints.append(val)
57            count = count + 1
58            if count % type_size == 0:
59                tmp_bytes = bytearray(tmp_ints)
60                result.append(int.from_bytes(tmp_bytes, 'little', signed=True))
61                tmp_ints.clear()
62
63        return result
64
65    def tflite_to_json(self, tflite_input, schema):
66        name_without_ext, ext = os.path.splitext(tflite_input)
67        new_name = name_without_ext + '.json'
68        dirname = os.path.dirname(tflite_input)
69
70        if schema is None:
71            raise RuntimeError("A schema file is required.")
72        command = f"flatc -o {dirname} --strict-json -t {schema} -- {tflite_input}"
73        command_list = command.split(' ')
74        try:
75            process = subprocess.run(command_list)
76            if process.returncode != 0:
77                print(f"ERROR: {command = }")
78                sys.exit(1)
79        except Exception as e:
80            raise RuntimeError(f"{e} from: {command = }. Did you install flatc?")
81
82        return new_name
83
84    def write_c_config_header(self, name_prefix, op_name, op_index) -> None:
85        filename = f"{name_prefix}_config_data.h"
86
87        self.generated_header_files.append(filename)
88        filepath = self.headers_dir + filename
89
90        prefix = f'{op_name}_{op_index}'
91
92        print("Writing C header with config data {}...".format(filepath))
93        with open(filepath, "w+") as f:
94            self.write_c_common_header(f)
95            f.write("#define {}_OUT_CH {}\n".format(prefix, self.output_ch))
96            f.write("#define {}_IN_CH {}\n".format(prefix, self.input_ch))
97            f.write("#define {}_INPUT_W {}\n".format(prefix, self.x_input))
98            f.write("#define {}_INPUT_H {}\n".format(prefix, self.y_input))
99            f.write("#define {}_DST_SIZE {}\n".format(prefix,
100                                                      self.x_output * self.y_output * self.output_ch * self.batches))
101            if op_name == "SOFTMAX":
102                f.write("#define {}_NUM_ROWS {}\n".format(prefix, self.y_input))
103                f.write("#define {}_ROW_SIZE {}\n".format(prefix, self.x_input))
104                f.write("#define {}_MULT {}\n".format(prefix, self.input_multiplier))
105                f.write("#define {}_SHIFT {}\n".format(prefix, self.input_left_shift))
106                if not self.is_int16xint8:
107                    f.write("#define {}_DIFF_MIN {}\n".format(prefix, -self.diff_min))
108            else:
109                f.write("#define {}_FILTER_X {}\n".format(prefix, self.filter_x))
110                f.write("#define {}_FILTER_Y {}\n".format(prefix, self.filter_y))
111                f.write("#define {}_FILTER_W {}\n".format(prefix, self.filter_x))
112                f.write("#define {}_FILTER_H {}\n".format(prefix, self.filter_y))
113                f.write("#define {}_STRIDE_X {}\n".format(prefix, self.stride_x))
114                f.write("#define {}_STRIDE_Y {}\n".format(prefix, self.stride_y))
115                f.write("#define {}_STRIDE_W {}\n".format(prefix, self.stride_x))
116                f.write("#define {}_STRIDE_H {}\n".format(prefix, self.stride_y))
117                f.write("#define {}_PAD_X {}\n".format(prefix, self.pad_x))
118                f.write("#define {}_PAD_Y {}\n".format(prefix, self.pad_y))
119                f.write("#define {}_PAD_W {}\n".format(prefix, self.pad_x))
120                f.write("#define {}_PAD_H {}\n".format(prefix, self.pad_y))
121                f.write("#define {}_OUTPUT_W {}\n".format(prefix, self.x_output))
122                f.write("#define {}_OUTPUT_H {}\n".format(prefix, self.y_output))
123                f.write("#define {}_INPUT_OFFSET {}\n".format(prefix, -self.input_zero_point))
124                f.write("#define {}_INPUT_SIZE {}\n".format(prefix, self.x_input * self.y_input * self.input_ch))
125                f.write("#define {}_OUT_ACTIVATION_MIN {}\n".format(prefix, self.out_activation_min))
126                f.write("#define {}_OUT_ACTIVATION_MAX {}\n".format(prefix, self.out_activation_max))
127                f.write("#define {}_INPUT_BATCHES {}\n".format(prefix, self.batches))
128                f.write("#define {}_OUTPUT_OFFSET {}\n".format(prefix, self.output_zero_point))
129                f.write("#define {}_DILATION_X {}\n".format(prefix, self.dilation_x))
130                f.write("#define {}_DILATION_Y {}\n".format(prefix, self.dilation_y))
131                f.write("#define {}_DILATION_W {}\n".format(prefix, self.dilation_x))
132                f.write("#define {}_DILATION_H {}\n".format(prefix, self.dilation_y))
133
134            if op_name == "FULLY_CONNECTED":
135                f.write("#define {}_OUTPUT_MULTIPLIER {}\n".format(prefix, self.quantized_multiplier))
136                f.write("#define {}_OUTPUT_SHIFT {}\n".format(prefix, self.quantized_shift))
137
138            if op_name == "DEPTHWISE_CONV_2D":
139                f.write("#define {}_ACCUMULATION_DEPTH {}\n".format(prefix,
140                                                                    self.input_ch * self.x_input * self.y_input))
141
142        self.format_output_file(filepath)
143
144    def shape_to_config(self, input_shape, filter_shape, output_shape, layer_name):
145        if layer_name == "AVERAGE_POOL_2D":
146            [_, self.filter_y, self.filter_x, _] = input_shape
147
148        elif layer_name == "CONV_2D" or layer_name == "DEPTHWISE_CONV_2D":
149            [self.batches, self.y_input, self.x_input, self.input_ch] = input_shape
150            [output_ch, self.filter_y, self.filter_x, self.input_ch] = filter_shape
151
152        elif layer_name == "FULLY_CONNECTED":
153            [self.batches, self.input_ch] = input_shape
154            [self.input_ch, self.output_ch] = filter_shape
155            [self.y_output, self.x_output] = output_shape
156            self.x_input = 1
157            self.y_input = 1
158
159        elif layer_name == "SOFTMAX":
160            [self.y_input, self.x_input] = input_shape
161
162        if len(input_shape) == 4:
163            if len(output_shape) == 2:
164                [self.y_output, self.x_output] = output_shape
165            else:
166                [d, self.y_output, self.x_output, d1] = output_shape
167
168        self.calculate_padding(self.x_output, self.y_output, self.x_input, self.y_input)
169
170    def extract_from_model(self, json_file, tensor_details):
171
172        with open(json_file, 'r') as in_file:
173            data = in_file.read()
174            data = json.loads(data)
175            tensors = data['subgraphs'][0]['tensors']
176            operators = data['subgraphs'][0]['operators']
177            operator_codes = data['operator_codes']
178            buffers = data['buffers']
179
180        op_index = 0
181        for op in operators:
182            if 'opcode_index' in op:
183                builtin_name = operator_codes[op['opcode_index']]['builtin_code']
184            else:
185                builtin_name = ""
186
187            # Get stride and padding.
188            if 'builtin_options' in op:
189                builtin_options = op['builtin_options']
190                if 'stride_w' in builtin_options:
191                    self.stride_x = builtin_options['stride_w']
192                if 'stride_h' in builtin_options:
193                    self.stride_y = builtin_options['stride_h']
194                    if 'padding' in builtin_options:
195                        self.has_padding = False
196                        self.padding = 'VALID'
197                    else:
198                        self.has_padding = True
199                        self.padding = 'SAME'
200
201            # Generate weights, bias, multipliers, shifts and config.
202            if builtin_name not in self.supported_ops:
203                print(f"WARNING: skipping unsupported operator {builtin_name}")
204            else:
205
206                input_index = op['inputs'][0]
207                output_index = op['outputs'][0]
208
209                input_tensor = tensor_details[input_index]
210                output_tensor = tensor_details[output_index]
211                input_scale = input_tensor['quantization'][0]
212                output_scale = output_tensor['quantization'][0]
213                self.input_zero_point = input_tensor['quantization'][1]
214                self.output_zero_point = output_tensor['quantization'][1]
215
216                input_shape = input_tensor['shape']
217                output_shape = output_tensor['shape']
218
219                if builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D" \
220                   or builtin_name == "FULLY_CONNECTED":
221                    weights_index = op['inputs'][1]
222                    bias_index = op['inputs'][2]
223
224                    weight_tensor = tensor_details[weights_index]
225                    scaling_factors = weight_tensor['quantization_parameters']['scales'].tolist()
226
227                    bias = tensors[bias_index]
228                    weights = tensors[weights_index]
229
230                    weights_data_index = weights['buffer']
231                    weights_data_buffer = buffers[weights_data_index]
232                    weights_data = self.from_bytes(weights_data_buffer['data'], 1)
233
234                    bias_data_index = bias['buffer']
235                    bias_data_buffer = buffers[bias_data_index]
236                    bias_data = self.from_bytes(bias_data_buffer['data'], 4)
237
238                    self.output_ch = len(scaling_factors)
239
240                    filter_shape = weights['shape']
241                else:
242                    filter_shape = []
243
244                self.input_scale, self.output_scale = input_scale, output_scale
245
246                if builtin_name == "SOFTMAX":
247                    self.calc_softmax_params()
248
249                self.shape_to_config(input_shape, filter_shape, output_shape, builtin_name)
250
251                nice_name = 'layer_' + str(op_index) + '_' + builtin_name.lower()
252
253                if builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D" \
254                   or builtin_name == "FULLY_CONNECTED":
255                    self.generate_c_array(nice_name + "_weights", weights_data)
256                    self.generate_c_array(nice_name + "_bias", bias_data, datatype='int32_t')
257
258                if builtin_name == "FULLY_CONNECTED":
259                    self.weights_scale = scaling_factors[0]
260                    self.quantize_multiplier()
261
262                elif builtin_name == "CONV_2D" or builtin_name == "DEPTHWISE_CONV_2D":
263                    self.scaling_factors = scaling_factors
264                    per_channel_multiplier, per_channel_shift = self.generate_quantize_per_channel_multiplier()
265
266                    self.generate_c_array(f"{nice_name}_output_mult", per_channel_multiplier, datatype='int32_t')
267                    self.generate_c_array(f"{nice_name}_output_shift", per_channel_shift, datatype='int32_t')
268
269                self.write_c_config_header(nice_name, builtin_name, op_index)
270
271            op_index = op_index + 1
272
    def generate_data(self, input_data=None, weights=None, biases=None) -> None:
        """Generate input, per-operator and reference output data for the model.

        The model is loaded into an interpreter using the BUILTIN_REF op
        resolver, fed with the data returned by get_randomized_input_data(),
        converted to JSON for extract_from_model() and finally invoked to
        obtain the reference output.

        input_data is passed on to get_randomized_input_data(); weights and
        biases are not used by this implementation.

        Raises RuntimeError if the model does not have exactly one input.
        """

        interpreter = self.Interpreter(model_path=str(self.tflite_model),
                                       experimental_op_resolver_type=self.OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Needed for input/output scale/zp as the equivalent json file data has too low precision.
        tensor_details = interpreter.get_tensor_details()

        output_details = interpreter.get_output_details()
        (self.output_scale, self.output_zero_point) = output_details[0]['quantization']

        input_details = interpreter.get_input_details()
        if len(input_details) != 1:
            raise RuntimeError("Only single input supported.")
        input_shape = input_details[0]['shape']
        input_data = self.get_randomized_input_data(input_data, input_shape)
        # The input is cast to int8 before it is handed to the interpreter.
        interpreter.set_tensor(input_details[0]["index"], tf.cast(input_data, tf.int8))

        self.generate_c_array("input", input_data)

        # Per-operator data is extracted from the JSON dump of the model.
        json_file = self.tflite_to_json(self.tflite_model, self.schema_file)
        self.extract_from_model(json_file, tensor_details)

        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]["index"])
        # The reference output is clipped to the activation range before it is written.
        self.generate_c_array("output_ref", np.clip(output_data, self.out_activation_min, self.out_activation_max))

        self.write_c_header_wrapper()
302
303
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and extract data for every supported operator.
    # The first description literal ends with a space so the two sentences do not run together in --help.
    parser = argparse.ArgumentParser(description="Extract operator data from given model if operator is supported. "
                                     "This provides a way for CMSIS-NN to directly process a model.")
    parser.add_argument('--schema-file', type=str, required=True, help="Path to schema file.")
    parser.add_argument('--tflite-model', type=str, required=True, help="Path to tflite file.")
    parser.add_argument('--model-name',
                        type=str,
                        help="Descriptive model name. If left out it will be inferred from actual model.")

    args = parser.parse_args()

    schema_file = args.schema_file
    tflite_model = args.tflite_model

    # Without an explicit name, fall back to the model file name minus its extension.
    dataset = args.model_name or os.path.splitext(os.path.basename(tflite_model))[0]

    model_extractor = MODEL_EXTRACTOR(dataset, schema_file, tflite_model)
    model_extractor.generate_data()
325