1#
2# Copyright (c) 2023, Arm Limited. All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5#
6
7from anytree import RenderTree
8from anytree.importer import DictImporter
9from prettytable import PrettyTable
10
11
class TfaPrettyPrinter:
    """A class for printing the memory layout of ELF files.

    This class provides interfaces for printing various memory layout views of
    ELF files in a TF-A build. It can be used to understand how the memory is
    structured and consumed.

    Every ``print_*`` method writes to standard output; the symbol table and
    the memory tree additionally keep the rendered lines on the instance
    (``_symbol_map`` and ``_tree`` respectively).
    """

    def __init__(self, columns: int = None, as_decimal: bool = False):
        """Create a printer.

        Args:
            columns: Width of the output in characters. Values that are
                missing or not greater than 120 are replaced by 120.
            as_decimal: Print integers in decimal instead of hexadecimal.
        """
        self.term_size = columns if columns and columns > 120 else 120
        self._tree = None
        self._footprint = None
        self._symbol_map = None
        self.as_decimal = as_decimal

    def format_args(self, *args, width=10, fmt=None):
        """Format every positional argument with one shared format spec.

        Args:
            *args: Values to format.
            width: Field width used for the default integer format.
            fmt: Explicit format spec. When omitted and the first value is an
                ``int``, a right-aligned hexadecimal (or decimal, when
                ``as_decimal`` is set) spec of ``width`` characters is used.

        Returns:
            A list with one entry per argument: the formatted strings, or the
            untouched values when no format spec applies. An empty list when
            called without values.
        """
        # Nothing to inspect or format; avoids indexing into an empty tuple.
        if not args:
            return []

        # The spec is chosen from the first value only and applied to all of
        # them. The exact type check leaves bool and int subclasses alone.
        if not fmt and type(args[0]) is int:
            fmt = f">{width}" if self.as_decimal else f">{width}x"

        if not fmt:
            return list(args)
        return [f"{arg:{fmt}}" for arg in args]

    def format_row(self, leading, *args, width=10, fmt=None):
        """Return ``leading`` followed by the space-separated formatted args.

        Args:
            leading: Text placed at the start of the row.
            *args: Values forwarded to :meth:`format_args`.
            width: Field width forwarded to :meth:`format_args`.
            fmt: Format spec forwarded to :meth:`format_args`.
        """
        formatted_args = self.format_args(*args, width=width, fmt=fmt)
        return leading + " ".join(formatted_args)

    @staticmethod
    def map_elf_symbol(
        leading: str,
        section_name: str,
        rel_pos: int,
        columns: int,
        width: int = None,
        is_edge: bool = False,
    ):
        """Render one row of the symbol map.

        The row consists of ``columns`` cells of ``width`` characters. The
        cell at index ``rel_pos`` holds the symbol label framed by ``+``
        markers; the remaining cells are empty.

        Args:
            leading: Text placed before the first cell.
            section_name: Label to draw in the occupied cell.
            rel_pos: Zero-based index of the occupied cell.
            columns: Total number of cells in the row.
            width: Width of a single cell in characters (required).
            is_edge: Draw empty cells as a horizontal border (``+`` and
                ``-``) instead of blank cells delimited by ``|``.

        Returns:
            The rendered row; excluding ``leading`` it is always
            ``columns * width + 1`` characters long.

        Raises:
            TypeError: If ``width`` is not supplied.
        """
        if width is None:
            raise TypeError("map_elf_symbol() requires an integer width")

        empty_col = "{:{}{}}"

        # The label field is one character narrower than the cell because
        # the occupied cell carries a '+' marker on both sides.
        label_width = max(width - 1, 0)

        # Some symbols are longer than the column width. Trim only the excess
        # characters, split across both ends so the middle of the name is
        # kept, until we find a more elegant way to display them!
        excess = len(section_name) - label_width
        if excess > 0:
            trim_left = excess // 2
            section_name = section_name[trim_left : trim_left + label_width]

        sec_row = f"+{section_name:-^{label_width}}+"
        sep, fill = ("+", "-") if is_edge else ("|", "")

        # Cells left of the label start with the separator, cells right of
        # it end with the separator.
        sec_row_l = empty_col.format(sep, fill + "<", width) * rel_pos
        sec_row_r = empty_col.format(sep, fill + ">", width) * (
            columns - rel_pos - 1
        )

        return leading + sec_row_l + sec_row + sec_row_r

    def print_footprint(
        self, app_mem_usage: dict, sort_key: str = None, fields: list = None
    ):
        """Print one usage table per memory type.

        Args:
            app_mem_usage: Mapping of module name to a mapping of memory type
                to a mapping of values. The values are looked up by the
                lower-cased names of every field after the first.
            sort_key: Field name the tables are sorted by; defaults to the
                first field.
            fields: Column titles; the first one labels the module column.
                Defaults to Component/Start/Limit/Size/Free/Total.
        """
        assert len(app_mem_usage), "Empty memory layout dictionary!"
        if not fields:
            fields = ["Component", "Start", "Limit", "Size", "Free", "Total"]

        sort_key = sort_key or fields[0]
        value_keys = [field.lower() for field in fields[1:]]

        # Iterate through all the memory types, create a table for each
        # type, rows represent a single module.
        mem_types = sorted(
            {mem for usage in app_mem_usage.values() for mem in usage}
        )
        for mem in mem_types:
            table = PrettyTable(
                sortby=sort_key,
                title=f"Memory Usage (bytes) [{mem.upper()}]",
                field_names=fields,
            )

            for mod, usage in app_mem_usage.items():
                # Modules that do not use this memory type get no row.
                if mem not in usage:
                    continue
                region = usage[mem]
                table.add_row(
                    [
                        mod.upper(),
                        *self.format_args(*(region[k] for k in value_keys)),
                    ]
                )
            print(table, "\n")

    def print_symbol_table(
        self,
        symbols: list,
        modules: list,
        start: int = 12,
    ):
        """Print a map placing each symbol in the column of its module.

        The rendered lines are also stored in ``self._symbol_map``. Rows are
        printed in the reverse of the order given in ``symbols``, followed by
        the row of module names.

        Args:
            symbols: Sequence of ``(name, address, module)`` tuples.
            modules: Module names; one column is drawn per module, in sorted
                order.
            start: Width of the leading address column, including the single
                space that separates it from the first module column.
        """
        assert len(symbols), "Empty symbol list!"
        modules = sorted(modules)
        col_width = (self.term_size - start) // len(modules)

        # The address and its trailing space must fill exactly ``start``
        # characters so that rows with and without an address line up.
        address_width = max(start - 1, 0)
        if self.as_decimal:
            num_fmt = f">{address_width}"
        else:
            num_fmt = f"0=#0{address_width}x"

        rows = [
            " " * start
            + "".join(self.format_args(*modules, fmt=f"^{col_width}"))
        ]
        last_addr = None
        last_index = len(symbols) - 1

        for i, (name, addr, mod) in enumerate(symbols):
            # Do not print out an address twice if two symbols overlap,
            # for example, at the end of one region and start of another.
            if addr != last_addr:
                leading = f"{addr:{num_fmt}}" + " "
            else:
                leading = " " * start

            rows.append(
                self.map_elf_symbol(
                    leading,
                    name,
                    modules.index(mod),
                    len(modules),
                    width=col_width,
                    # The first and last symbols are drawn as borders.
                    is_edge=(i == 0 or i == last_index),
                )
            )

            last_addr = addr

        self._symbol_map = ["Memory Layout:"]
        self._symbol_map += list(reversed(rows))
        print("\n".join(self._symbol_map))

    def print_mem_tree(
        self, mem_map_dict, modules, depth=1, min_pad=12, node_right_pad=12
    ):
        """Print a tree of memory nodes with their start, end and size.

        The rendered lines are also stored in ``self._tree``.

        Args:
            mem_map_dict: Mapping of module name to a dictionary that
                ``anytree``'s ``DictImporter`` can turn into a tree. Each
                node is read for ``name``, ``start``, ``end`` and ``size``.
            modules: Module names to render, printed in sorted order.
            depth: Maximum tree level passed to ``RenderTree``.
            min_pad: Minimum width of the name column.
            node_right_pad: Extra name-column width added for every level
                beyond the first.
        """
        # Start column should have some padding between itself and its data
        # values.
        anchor = min_pad + node_right_pad * (depth - 1)
        headers = ["start", "end", "size"]

        self._tree = [
            (f"{'name':<{anchor}}" + " ".join(f"{arg:>10}" for arg in headers))
        ]

        for mod in sorted(modules):
            root = DictImporter().import_(mem_map_dict[mod])
            for pre, _, node in RenderTree(root, maxlevel=depth):
                leading = f"{pre}{node.name}".ljust(anchor)
                self._tree.append(
                    self.format_row(
                        leading,
                        node.start,
                        node.end,
                        node.size,
                    )
                )
        print("\n".join(self._tree), "\n")
164