1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import copy
19import tensorflow as tf
20import math
21import numpy as np
22
23from tensorflow.lite.python.interpreter import Interpreter
24from tensorflow.lite.python.interpreter import OpResolverType
25import tf_keras as keras
26
27
class Op_transpose(Lib.op_utils.Op_type):
    """Test-data generator for the TRANSPOSE operator.

    The methods take no ``self`` argument; they are presumably invoked on the
    class directly by the test generator (NOTE(review): confirm against
    ``Lib.op_utils``).
    """

    def get_shapes(params):
        """Return the shapes used to build the Keras model.

        Args:
            params: Test parameters; only ``params["in_dim"]`` is read.

        Returns:
            Dict with ``"input_tensor"`` and ``"representational_dataset"``,
            each an independent deep copy of ``params["in_dim"]`` so that
            mutating one entry cannot silently change the other.
        """
        input_shape = copy.deepcopy(params["in_dim"])
        return {
            "input_tensor": input_shape,
            "representational_dataset": copy.deepcopy(input_shape),
        }

    def generate_keras_model(shapes, params):
        """Build a single-op Keras model applying ``tf.transpose``.

        Args:
            shapes: Dict from ``get_shapes``; ``shapes["input_tensor"]`` is used
                as the full batch input shape.
            params: Test parameters; ``params["perm"]`` is the permutation.

        Returns:
            A ``keras.Model`` mapping the input to its transpose.
        """
        input_shape = shapes["input_tensor"]
        input_lhs = keras.layers.Input(batch_input_shape=input_shape)
        layer = tf.transpose(input_lhs, perm=params["perm"])
        model = keras.Model([input_lhs], [layer])

        return model

    def generate_data_tflite(tflite_fname, params):
        """Derive the generated parameters consumed by the C unit-test glue.

        Transpose needs no tensors, scales or aliases here. Only the element
        count, the permutation size and the output dimensions are reported.

        Args:
            tflite_fname: Path of the converted model (not used by this op).
            params: Test parameters containing ``"in_dim"`` and ``"perm"``.
                ``params["in_dim"]`` is padded IN PLACE with zeroes up to four
                entries for the C glue. Padding stops at four entries, so
                calling this again on the same ``params`` does not grow it.

        Returns:
            ``Lib.op_utils.Generated_data`` whose ``generated_params`` holds
            ``"size"`` (input element count), ``"perm_size"`` and
            ``"out_dim"`` (output dims, zero-padded to four entries).

        Raises:
            RuntimeError: If the permutation has fewer than 2 or more than 4
                entries.
        """
        max_dims = 4  # The C glue always expects four dimension slots.

        tensors = {}
        effective_scales = {}
        scales = {}
        aliases = {}

        perm = params["perm"]
        perm_size = len(perm)
        if not 2 <= perm_size <= max_dims:
            raise RuntimeError("Permutation size not supported")

        # Only the leading perm_size entries are real dimensions. Any further
        # entries are assumed to be zero padding added by a previous call on
        # the same params; including them would make "size" collapse to 0.
        input_shape = params["in_dim"][:perm_size]

        generated_params = {
            "size": math.prod(input_shape),
            "perm_size": perm_size,
            # Output dim i is the input dim selected by perm[i]; fill the
            # remaining slots with zeroes for the C unit test file glue.
            "out_dim": [input_shape[p] for p in perm] + [0] * (max_dims - perm_size),
        }

        # Pad in_dim in place the same way, but never beyond max_dims.
        in_dim = params["in_dim"]
        in_dim.extend([0] * (max_dims - len(in_dim)))

        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales, aliases)
77