1#!/usr/bin/env python3
2# Copyright (c) 2025 Basalte bv
3#
4# SPDX-License-Identifier: Apache-2.0
5
6"""
A script to plot the data generated by size_report as a sunburst chart.
Building the ram_report or rom_report target produces a JSON file in
the build directory that can be used as input for this script.
11
12Example:
13   ./scripts/footprint/plot.py build/ram.json
14
15Requires plotly to be installed, for example with pip:
16   pip install plotly
17"""
18
19import argparse
20import json
21import sys
22
23
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    The module docstring doubles as the ``--help`` description, so it is
    rendered verbatim (RawDescriptionHelpFormatter keeps its line breaks).

    :return: an ``argparse.Namespace`` with the attributes ``input``,
        ``html`` (None unless given) and ``depth`` (an int, 4 by default).
    """
    cli = argparse.ArgumentParser(
        allow_abbrev=False,
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    cli.add_argument('input', help='Input json file')
    cli.add_argument('--html', help='Output html file')

    depth_help = 'Maximum render depth, pass -1 to render all levels. Defaults to 4'
    cli.add_argument('--depth', type=int, default=4, help=depth_help)

    namespace = cli.parse_args()
    return namespace
41
42
def _flatten_symbol_tree(root, totalsize):
    """Flatten a size_report symbol tree into the column lists go.Sunburst takes.

    Nodes are visited depth-first, each parent before its children. A node
    without an ``identifier`` key is skipped together with its whole subtree.
    Identifiers are not unique in the report, so repeats get a ``_<n>``
    suffix; children are attached to the (possibly suffixed) identifier.

    :param root: root node dict (the report's ``symbols`` entry).
    :param totalsize: the report's ``total_size``; when it is missing or 0
        the percentage line is omitted from the hover text instead of failing.
    :return: dict with the keys ``ids``, ``labels``, ``parents``, ``values``
        and ``hovertext``, each a list with one entry per kept node, so it
        can be splatted straight into ``go.Sunburst(**columns)``.
    """
    columns = {key: [] for key in ('ids', 'labels', 'parents', 'values', 'hovertext')}
    # Mirrors columns['ids'] so the uniqueness checks below are O(1) each
    # instead of a linear scan of the list.
    seen = set()

    def visit(node, parent):
        identifier = node.get('identifier')
        if identifier is None:
            return

        if identifier in seen:
            # Identifiers aren't unique, add a suffix to make them unique
            idx = 0
            while f'{identifier}_{idx}' in seen:
                idx += 1
            identifier = f'{identifier}_{idx}'
        seen.add(identifier)

        # Use the same default for the plotted value and the percentage, so a
        # node without 'size' no longer crashes the hover-text computation.
        size = node.get('size', 0)
        columns['ids'].append(identifier)
        columns['labels'].append(node.get('name', ''))
        columns['parents'].append(parent)
        columns['values'].append(size)

        details = []
        if totalsize:
            details.append(f'percentage: {size / totalsize:.2%}')
        if 'address' in node:
            details.append(f'address: 0x{node["address"]:08x}')
        if 'section' in node:
            details.append(f'section: {node["section"]}')
        columns['hovertext'].append("<br>".join(details))

        for child in node.get('children', ()):
            visit(child, identifier)

    visit(root, '')
    return columns


def main():
    """Load a size_report JSON file and render it as a sunburst chart.

    The chart is written to the ``--html`` file when one is given; otherwise
    it is opened in the default web browser. Exits with a message if plotly
    is not installed.
    """
    args = parse_args()

    try:
        import plotly.graph_objects as go
    except ImportError:
        sys.exit("Missing dependency: You need to install plotly.")

    # JSON is UTF-8 by specification; don't depend on the locale's encoding.
    with open(args.input, encoding='utf-8') as f:
        data = json.load(f)

    columns = _flatten_symbol_tree(data.get('symbols', {}), data.get('total_size'))

    fig = go.Figure(
        go.Sunburst(
            **columns,
            # Each node's value already includes its children's sizes.
            branchvalues='total',
            maxdepth=args.depth,
        ),
        skip_invalid=True,
    )
    fig.update_layout(margin={'t': 0, 'l': 0, 'r': 0, 'b': 0})

    if args.html:
        fig.write_html(args.html, auto_open=False)
        return

    print("Opening the default browser to render the generated plot.")
    fig.show(renderer="browser")
111
112
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
115