1# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
2#
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the License); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an AS IS BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17import Lib.op_utils
18import math
19
20import numpy as np
21
22
class Op_fully_connected(Lib.op_utils.Op_type):
    """Shape and test-data generation for a fully connected operator.

    The weights produced here are packed two 4-bit values per byte.

    Both methods are declared without ``self``, so they are meant to be called
    on the class itself, e.g. ``Op_fully_connected.get_shapes(params)``.
    """

    def get_shapes(params):
        """Fill in default parameters and derive the tensor shapes.

        ``params`` is modified in place: every optional key that is missing
        receives its default value, and ``json_template`` is set according to
        whether a bias tensor is generated.

        Args:
            params: Test configuration. ``in_ch`` and ``out_ch`` are required;
                ``input_data_type`` is required unless both
                ``out_activation_min`` and ``out_activation_max`` are given.

        Returns:
            A dict with the keys ``input``, ``weight_shape`` and ``bias_shape``.
        """
        shapes = {}

        # Common default parameters
        if "batch_size" not in params:
            params["batch_size"] = 1
        if "generate_bias" not in params:
            params["generate_bias"] = True

        # The activation range defaults to the full range of the input data type.
        if "out_activation_min" not in params:
            params["out_activation_min"] = Lib.op_utils.get_dtype_min(params["input_data_type"])
        if "out_activation_max" not in params:
            params["out_activation_max"] = Lib.op_utils.get_dtype_max(params["input_data_type"])

        # Bias and weight limits default to the int32_t range.
        if "bias_min" not in params:
            params["bias_min"] = Lib.op_utils.get_dtype_min("int32_t")
        if "bias_max" not in params:
            params["bias_max"] = Lib.op_utils.get_dtype_max("int32_t")
        if "weights_min" not in params:
            params["weights_min"] = Lib.op_utils.get_dtype_min("int32_t")
        if "weights_max" not in params:
            params["weights_max"] = Lib.op_utils.get_dtype_max("int32_t")

        in_ch = params["in_ch"]
        out_ch = params["out_ch"]

        shapes["input"] = (params["batch_size"], in_ch)
        shapes["weight_shape"] = (in_ch, 1, 1, out_ch)

        # The bias shape and the JSON template are chosen together.
        if params["generate_bias"]:
            shapes["bias_shape"] = [out_ch]
            params["json_template"] = "fully_connected.json"
        else:
            shapes["bias_shape"] = []
            params["json_template"] = "fully_connected_null_bias.json"

        return shapes

    def generate_data_json(shapes, params):
        """Generate random weights/bias and the derived operator parameters.

        Args:
            shapes: Dict providing ``weight_shape`` and ``bias_shape`` (the
                latter is only read when a bias is generated).
            params: Test configuration. Reads ``batch_size``, ``in_ch``,
                ``out_ch``, ``input_zp``, ``output_zp``, ``generate_bias``,
                ``input_scale``, ``w_scale`` and ``output_scale``.
                ``batch_size`` and ``generate_bias`` are the keys that
                ``get_shapes`` fills in when they are absent.

        Returns:
            ``Lib.op_utils.Generated_data`` built from the generated
            parameters, the tensors, two empty dicts (scales and effective
            scales) and the tensor-name aliases.

        Raises:
            RuntimeError: If ``input_scale * w_scale`` is negative.
        """
        tensors = {}
        effective_scales = {}
        scales = {}
        generated_params = {}
        aliases = {}

        generated_params["input_batches"] = params["batch_size"]
        generated_params["input_w"] = 1
        generated_params["input_h"] = 1
        generated_params["dst_size"] = params["out_ch"] * params["batch_size"]
        generated_params["accumulation_depth"] = params["in_ch"]
        # The input offset is the negated input zero point.
        generated_params["input_offset"] = -params["input_zp"]
        generated_params["output_offset"] = params["output_zp"]

        # To be removed
        aliases["input_bias"] = "biases"
        aliases["output"] = "output_ref"
        aliases["input_weights"] = "weights"

        # TODO: the value range is hard-coded here; params["weights_min"/"weights_max"]
        # and params["bias_min"/"bias_max"] are not used. randint() excludes the upper
        # bound, so values are drawn from [-7, 7].
        minval = -7
        maxval = 8
        weights = np.random.randint(minval, maxval, size=shapes["weight_shape"])

        # Two values share one byte, so pad an odd element count with a zero.
        # np.append() also flattens the array.
        uneven = weights.size % 2
        if uneven:
            weights = np.append(weights, 0)

        # Row i of `temp` holds elements 2*i and 2*i + 1 of the flattened weights.
        # The first goes into the low nibble and the second into the high nibble.
        temp = np.reshape(weights, (weights.size // 2, 2)).astype(np.uint8)
        weights = 0xff & ((0xf0 & (temp[:, 1] << 4)) | (temp[:, 0] & 0xf))
        tensors["input_weights"] = weights

        if params["generate_bias"]:
            tensors["input_bias"] = np.random.randint(minval, maxval, size=shapes["bias_shape"])
        else:
            tensors["input_bias"] = None

        def quantize_multiplier(input_scale, weights_scale, output_scale):
            """Return ``(multiplier, shift)`` for input_scale * weights_scale / output_scale."""

            def quantize_scale(scale):
                """Split ``scale`` into a Q31 significand and a power-of-two exponent."""
                significand, shift = math.frexp(scale)
                significand_q31 = round(significand * (1 << 31))
                # frexp() returns a significand in [0.5, 1). A value just below 1.0
                # rounds up to exactly 1 << 31, which is one past the largest Q31
                # value. Halving it and bumping the exponent keeps the represented
                # value unchanged while bringing the significand back into range.
                if significand_q31 == (1 << 31):
                    significand_q31 //= 2
                    shift += 1
                return significand_q31, shift

            input_product_scale = input_scale * weights_scale
            if input_product_scale < 0:
                raise RuntimeError("negative input product scale")
            real_multiplier = input_product_scale / output_scale
            return quantize_scale(real_multiplier)

        generated_params["output_multiplier"], generated_params["output_shift"] = quantize_multiplier(
            params["input_scale"], params["w_scale"], params["output_scale"])

        return Lib.op_utils.Generated_data(generated_params, tensors, scales, effective_scales, aliases)
113