1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import tensorflow as tf
19import math
20import numpy as np
21
22from tensorflow.lite.python.interpreter import Interpreter
23from tensorflow.lite.python.interpreter import OpResolverType
24import tf_keras as keras
25
class Op_pad(Lib.op_utils.Op_type):
    """Op_type implementation that generates test data for the PAD operator.

    The methods are called on the class itself (they take no ``self``).
    """

    def get_shapes(params):
        """Return the tensor shapes needed to build and calibrate the PAD model.

        Args:
            params: Mapping that must provide ``input_n``, ``input_h``,
                ``input_w`` and ``input_c``.

        Returns:
            dict with ``"input_tensor"`` set to the tuple
            ``(input_n, input_h, input_w, input_c)`` and
            ``"representational_dataset"`` set to the same tuple.
        """
        input_shape = (params["input_n"], params["input_h"], params["input_w"], params["input_c"])
        return {
            "input_tensor": input_shape,
            "representational_dataset": input_shape,
        }

    def generate_keras_model(shapes, params):
        """Build a Keras model consisting of a single zero-padding layer.

        Only two padding patterns are expressible with ``ZeroPadding2D``:
        padding on (w, c) when the n and h pads are zero, or padding on (h, w)
        when the n and c pads are zero.

        Args:
            shapes: dict from ``get_shapes``; the batch dimension of
                ``shapes["input_tensor"]`` is dropped for the ``InputLayer``.
            params: Mapping providing ``pre_pad_*`` / ``post_pad_*`` for each
                of ``n``, ``h``, ``w`` and ``c``.

        Returns:
            The ``keras.models.Sequential`` model.

        Raises:
            ValueError: If the requested padding matches neither supported
                pattern.
        """
        pad_n = (params["pre_pad_n"], params["post_pad_n"])
        pad_h = (params["pre_pad_h"], params["post_pad_h"])
        pad_w = (params["pre_pad_w"], params["post_pad_w"])
        pad_c = (params["pre_pad_c"], params["post_pad_c"])
        no_pad = (0, 0)

        model = keras.models.Sequential()
        model.add(keras.layers.InputLayer(input_shape=shapes["input_tensor"][1:]))

        if pad_n == no_pad and pad_h == no_pad:
            # ZeroPadding2D only pads the two "spatial" axes. With channels_first,
            # the per-sample (h, w, c) tensor is read as (channels, rows, cols),
            # so the padding is applied to w and c instead of h and w.
            model.add(keras.layers.ZeroPadding2D(padding=(pad_w, pad_c), data_format="channels_first"))
        elif pad_n == no_pad and pad_c == no_pad:
            model.add(keras.layers.ZeroPadding2D(padding=(pad_h, pad_w), data_format="channels_last"))
        else:
            raise ValueError("Keras can only generate padding for (h,w) or (w,c), the others must be zero.")

        return model

    def generate_data_tflite(tflite_fname, params):
        """Collect the generated parameters for a converted TFLite PAD model.

        Args:
            tflite_fname: Path to the ``.tflite`` file (passed through ``str``).
            params: Not read by this implementation.

        Returns:
            ``Lib.op_utils.Generated_data`` whose generated params hold
            ``"pad_value"`` and ``"output_size"`` (the element count of the
            model's first output tensor). The tensor and scale dicts are empty.

        Raises:
            ValueError: If the output tensor is not 4-dimensional.
        """
        tensors = {}
        effective_scales = {}
        scales = {}
        generated_params = {}

        # NOTE(review): hard-coded; presumably this matches the zero point the
        # quantized model ends up with. Confirm against
        # output_details[0]["quantization"] if the quantization setup changes.
        generated_params["pad_value"] = -128

        # Load the model with the reference-kernel op resolver.
        interpreter = Interpreter(str(tflite_fname), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()

        output_shape = interpreter.get_output_details()[0]["shape"]
        # The output keeps the input's (n, h, w, c) axis order (see get_shapes).
        # Convert to Python ints so the product cannot wrap around in fixed-width
        # numpy arithmetic.
        output_n, output_h, output_w, output_c = (int(dim) for dim in output_shape)

        generated_params["output_size"] = output_n * output_h * output_w * output_c

        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales)
69
70