# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import Lib.op_utils
import copy
import tensorflow as tf
import math
import numpy as np

from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.python.interpreter import OpResolverType
import tf_keras as keras


class Op_transpose(Lib.op_utils.Op_type):
    """Shape, model and parameter generation hooks for the transpose operator.

    None of the functions below take ``self`` or ``cls``.
    NOTE(review): presumably they are invoked through the class object
    (e.g. ``Op_transpose.get_shapes(params)``) - confirm against the caller.
    """

    def get_shapes(params):
        """Return the tensor shapes derived from ``params["in_dim"]``.

        The input dimensions are deep-copied, so the returned shapes are
        independent of ``params``. Both returned keys refer to the same
        copied list object.

        Args:
            params: Mapping holding at least the key ``"in_dim"``.

        Returns:
            A dict with the keys ``"input_tensor"`` and
            ``"representational_dataset"``.
        """
        in_dims = copy.deepcopy(params["in_dim"])
        return {
            "input_tensor": in_dims,
            "representational_dataset": in_dims,
        }

    def generate_keras_model(shapes, params):
        """Build a Keras model consisting of a single ``tf.transpose``.

        Args:
            shapes: Mapping holding the key ``"input_tensor"``; its value is
                passed to the input layer as ``batch_input_shape``.
            params: Mapping holding the key ``"perm"``, forwarded to
                ``tf.transpose`` as the permutation.

        Returns:
            A ``keras.Model`` with one input and the transposed output.
        """
        model_input = keras.layers.Input(batch_input_shape=shapes["input_tensor"])
        transposed = tf.transpose(model_input, perm=params["perm"])
        return keras.Model([model_input], [transposed])

    def generate_data_tflite(tflite_fname, params):
        """Derive the generated parameters for a transpose test case.

        Computes the element count, the permutation length and the permuted
        output dimensions. The output dimensions are zero-padded to four
        entries, and ``params["in_dim"]`` is padded *in place* with the same
        number of zeroes (for the C unit test file glue).

        Args:
            tflite_fname: Not used by this function.
            params: Mapping holding ``"in_dim"`` (a list, extended in place
                when the permutation has fewer than four entries) and
                ``"perm"``.

        Returns:
            A ``Lib.op_utils.Generated_data`` built from the generated
            parameters and four empty dicts (tensors, scales, effective
            scales, aliases).

        Raises:
            RuntimeError: If the permutation does not have 2, 3 or 4 entries.
        """
        in_dims = params["in_dim"]
        perm = params["perm"]
        rank = len(perm)

        # "size" is computed before any zero padding is added to in_dims.
        generated_params = {
            "size": math.prod(in_dims),
            "perm_size": rank,
        }

        if rank not in (2, 3, 4):
            raise RuntimeError("Permutation size not supported")

        # Derive output dims and fill with zeroes for C unit test file glue.
        zero_fill = [0] * (4 - rank)
        generated_params["out_dim"] = [in_dims[axis] for axis in perm] + zero_fill
        if zero_fill:
            # Pad only after out_dim has been read from the unpadded dims.
            params["in_dim"].extend(zero_fill)

        tensors = {}
        scales = {}
        effective_scales = {}
        aliases = {}
        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales, aliases)