1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import tensorflow as tf
19import math
20import numpy as np
21
22from tensorflow.lite.python.interpreter import Interpreter
23from tensorflow.lite.python.interpreter import OpResolverType
24import tf_keras as keras
25
class Op_batch_matmul(Lib.op_utils.Op_type):
    """Test-data generator for a rank-4 batched matrix multiplication.

    The model computes ``tf.matmul(lhs, rhs)`` on two 4-D inputs. Each
    operand's shape is read from ``params`` as
    (``<x>_batch``, ``<x>_height``, ``<x>_rows``, ``<x>_cols``), where ``<x>``
    is ``lhs`` or ``rhs``. The public methods take no ``self``. They are
    written to be called on the class itself.
    """

    @staticmethod
    def _operand_shape(params, prefix):
        """Return the (batch, height, rows, cols) shape tuple for operand *prefix* ("lhs"/"rhs")."""
        return (params[f"{prefix}_batch"], params[f"{prefix}_height"],
                params[f"{prefix}_rows"], params[f"{prefix}_cols"])

    @staticmethod
    def _quantize_multiplier(real_multiplier):
        """Split a real multiplier into a Q31 significand and a power-of-two shift.

        Follows TFLite's ``QuantizeMultiplier()``: the result satisfies
        ``real_multiplier ~= significand_q31 * 2**(shift - 31)``.

        Returns:
            (significand_q31, shift) as Python ints. (0, 0) for a zero multiplier.
        """
        significand, shift = math.frexp(real_multiplier)
        scaled = significand * (1 << 31)
        # Round half away from zero, like TfLiteRound/std::round.
        # Python's round() rounds half to even and can be off by one.
        significand_q31 = int(math.copysign(math.floor(abs(scaled) + 0.5), scaled))
        # A significand just below 1.0 can round up to 2**31, which does not fit in int32.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            shift += 1
        # A shift below -31 would shift every bit out. TFLite encodes that as a zero multiplier.
        if shift < -31:
            significand_q31, shift = 0, 0
        return significand_q31, shift

    def get_shapes(params):
        """Return the tensor shapes used by the generator, keyed by role.

        Both operands and their representational datasets use the shapes
        built from ``params``. ``different_in_shapes`` is set to True,
        presumably so the framework treats the two inputs separately.
        """
        lhs_shape = Op_batch_matmul._operand_shape(params, "lhs")
        rhs_shape = Op_batch_matmul._operand_shape(params, "rhs")

        shapes = {}
        shapes["lhs_input_tensor"] = lhs_shape
        shapes["rhs_input_tensor"] = rhs_shape
        shapes["representational_dataset"] = lhs_shape
        shapes["representational_dataset2"] = rhs_shape
        shapes["different_in_shapes"] = True

        return shapes

    def generate_keras_model(shapes, params):
        """Build a two-input tf_keras model computing ``tf.matmul(lhs, rhs)``.

        Args:
            shapes: Not used here; the input shapes are rebuilt from ``params``.
            params: Must provide the lhs/rhs dimension keys, plus ``adj_x`` and
                ``adj_y``, which are passed as ``transpose_a``/``transpose_b``.

        Returns:
            A ``keras.Model`` with inputs [lhs, rhs] and the matmul result as output.
        """
        # Clear the session of the Keras package that actually builds the model
        # (tf_keras). tf.keras may resolve to a different Keras, depending on the TF version.
        keras.backend.clear_session()
        input_lhs = keras.layers.Input(batch_input_shape=Op_batch_matmul._operand_shape(params, "lhs"))
        input_rhs = keras.layers.Input(batch_input_shape=Op_batch_matmul._operand_shape(params, "rhs"))

        layer = tf.matmul(input_lhs, input_rhs, transpose_a=params["adj_x"], transpose_b=params["adj_y"])
        return keras.Model([input_lhs, input_rhs], [layer])

    def generate_data_tflite(tflite_fname, params):
        """Extract tensors, quantization data and kernel parameters from a converted model.

        Args:
            tflite_fname: Path to the .tflite file (converted with ``str()``).
            params: Must provide ``input_data_type``, which selects the activation min/max.

        Returns:
            ``Lib.op_utils.Generated_data`` built from the generated params,
            tensors, scales, effective scales and aliases.
        """
        tensors = {}
        effective_scales = {}
        scales = {}
        generated_params = {}
        aliases = {}

        # To be removed
        aliases["output_multiplier"] = "output_mult"
        aliases["output"] = "output_ref"

        interpreter = Interpreter(str(tflite_fname), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        # Take both operands from the model's declared inputs, so the tensor data,
        # scales and zero points all refer to the same tensors. Picking them by
        # position in get_tensor_details() is not guaranteed to give the inputs.
        input_details = interpreter.get_input_details()
        lhs = input_details[0]
        rhs = input_details[1]
        (scales["lhs_scale"], scales["lhs_zero_point"]) = lhs['quantization']
        (scales["rhs_scale"], scales["rhs_zero_point"]) = rhs['quantization']

        output = interpreter.get_output_details()[0]
        (scales["output_scale"], scales["output_zero_point"]) = output['quantization']

        tensors["lhs_input_tensor"] = interpreter.get_tensor(lhs['index'])
        tensors["rhs_input_tensor"] = interpreter.get_tensor(rhs['index'])
        # Also provide each operand with its last two axes swapped.
        tensors["lhs_transposed_tensor"] = tf.transpose(tensors["lhs_input_tensor"], [0, 1, 3, 2]).numpy()
        tensors["rhs_transposed_tensor"] = tf.transpose(tensors["rhs_input_tensor"], [0, 1, 3, 2]).numpy()

        minval = Lib.op_utils.get_dtype_min(params["input_data_type"])
        maxval = Lib.op_utils.get_dtype_max(params["input_data_type"])

        n_output, h_output, w_output, c_output = output['shape']

        generated_params["dst_size"] = n_output * h_output * w_output * c_output
        generated_params["output_batch"] = n_output
        generated_params["output_height"] = h_output
        generated_params["output_rows"] = w_output
        generated_params["output_cols"] = c_output
        # Offsets are the negated input zero points.
        generated_params["lhs_offset"] = -lhs['quantization_parameters']['zero_points'][0]
        generated_params["rhs_offset"] = -rhs['quantization_parameters']['zero_points'][0]
        generated_params["output_offset"] = output['quantization'][1]
        generated_params["activation_min"] = minval
        generated_params["activation_max"] = maxval

        effective_output_scale = scales["lhs_scale"] * scales["rhs_scale"] / scales["output_scale"]
        mult, shift = Op_batch_matmul._quantize_multiplier(effective_output_scale)
        generated_params["output_multiplier"] = mult
        generated_params["output_shift"] = shift

        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales, aliases)