1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import tensorflow as tf
19import math
20import numpy as np
21
22from tensorflow.lite.python.interpreter import Interpreter
23from tensorflow.lite.python.interpreter import OpResolverType
24import tf_keras as keras
25
class Op_pooling(Lib.op_utils.Op_type):
    """Test-data generator for the 'avgpool' and 'maxpool' operator tests.

    Every method is static: none of them reads instance state, and they are
    expected to be called on the class itself (e.g. ``Op_pooling.get_shapes(params)``).
    """

    @staticmethod
    def get_shapes(params):
        """Return the tensor shapes used to build and calibrate the model.

        Args:
            params: test parameters; reads "batch_size", "input_h", "input_w"
                and "input_c".

        Returns:
            dict mapping "input_tensor" to the tuple
            (batch_size, input_h, input_w, input_c), and
            "representational_dataset" to that same tuple.
        """
        input_shape = (params["batch_size"], params["input_h"], params["input_w"], params["input_c"])
        return {
            "input_tensor": input_shape,
            # The representative dataset is fed through the same input tensor.
            "representational_dataset": input_shape,
        }

    @staticmethod
    def generate_keras_model(shapes, params):
        """Build a Sequential model holding an InputLayer and one pooling layer.

        Args:
            shapes: dict as returned by get_shapes(); "input_tensor" is
                (batch_size, h, w, c).
            params: test parameters; reads "op_type" ('avgpool' or 'maxpool'),
                "filter_h", "filter_w", "stride_h", "stride_w" and "pad".

        Returns:
            The keras.models.Sequential model.

        Raises:
            RuntimeError: if params["op_type"] is neither 'avgpool' nor 'maxpool'.
        """
        pooling_layers = {
            'avgpool': keras.layers.AveragePooling2D,
            'maxpool': keras.layers.MaxPooling2D,
        }
        op_type = params["op_type"]
        # Validate before constructing anything so a bad config fails fast.
        if op_type not in pooling_layers:
            raise RuntimeError(f"Wrong test type: {op_type}")

        input_shape = shapes["input_tensor"]
        model = keras.models.Sequential()
        model.add(keras.layers.InputLayer(input_shape=input_shape[1:], batch_size=input_shape[0]))
        model.add(
            pooling_layers[op_type](pool_size=(params["filter_h"], params["filter_w"]),
                                    strides=(params["stride_h"], params["stride_w"]),
                                    padding=params["pad"],
                                    input_shape=input_shape[1:]))
        return model

    @staticmethod
    def _same_padding(in_size, filter_size, stride, out_size):
        """Return half (rounded down) of the total SAME padding along one axis.

        The total is the number of extra elements needed so that out_size
        windows of filter_size, placed stride apart, span in_size elements;
        it is clamped at zero before halving.
        """
        total = max((out_size - 1) * stride + filter_size - in_size, 0)
        return total // 2

    @staticmethod
    def generate_data_tflite(tflite_fname, params):
        """Read output dimensions from a converted model and derive padding.

        Pooling needs no generated tensors or scales, so only
        generated_params is populated.

        Args:
            tflite_fname: path to the .tflite file (passed through str()).
            params: test parameters; reads "pad", "input_h", "input_w",
                "filter_h", "filter_w", "stride_h" and "stride_w".

        Returns:
            Lib.op_utils.Generated_data whose generated_params holds
            "output_c", "output_w", "output_h", "padding_h" and "padding_w",
            with empty tensor, scale and effective-scale dicts.
        """
        interpreter = Interpreter(str(tflite_fname), experimental_op_resolver_type=OpResolverType.BUILTIN_REF)
        interpreter.allocate_tensors()
        output_shape = interpreter.get_output_details()[0]['shape']

        # Output shape is indexed as (batch, h, w, c).
        generated_params = {
            "output_c": output_shape[3],
            "output_w": output_shape[2],
            "output_h": output_shape[1],
        }

        # Keras lower-cases the padding string, so "same" and "SAME" both build
        # a SAME-padded model; compare case-insensitively so the reported
        # padding matches the model that was actually built.
        if str(params["pad"]).upper() == "SAME":
            generated_params["padding_h"] = Op_pooling._same_padding(
                params["input_h"], params["filter_h"], params["stride_h"], generated_params["output_h"])
            generated_params["padding_w"] = Op_pooling._same_padding(
                params["input_w"], params["filter_w"], params["stride_w"], generated_params["output_w"])
        else:
            generated_params["padding_h"] = 0
            generated_params["padding_w"] = 0

        tensors = {}
        scales = {}
        effective_scales = {}
        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales)
81