#!/usr/bin/env python3
# Copyright (c) 2025 Basalte bv
#
# SPDX-License-Identifier: Apache-2.0

"""
A script to plot data in a sunburst chart generated by size_report.
When you call the ram_report or rom_report targets you end up
with a json file in the build directory that can be used as input
for this script.

Example:
   ./scripts/footprint/plot.py build/ram.json

Requires plotly to be installed, for example with pip:
   pip install plotly
"""

import argparse
import json
import sys


def parse_args():
    """Parse the command line of the plot script.

    Returns:
        The ``argparse.Namespace`` built from ``sys.argv`` with three
        attributes: ``input`` (the positional json file), ``html`` (``None``
        unless ``--html`` is given) and ``depth`` (an int, 4 unless
        ``--depth`` is given).
    """
    cli = argparse.ArgumentParser(
        allow_abbrev=False,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    # One (name, keyword options) entry per argument, registered in order so
    # the generated usage text lists them the same way.
    argument_specs = (
        ('input', {'help': 'Input json file'}),
        ('--html', {'help': 'Output html file'}),
        (
            '--depth',
            {
                'help': 'Maximum render depth, pass -1 to render all levels. Defaults to 4',
                'type': int,
                'default': 4,
            },
        ),
    )
    for name, options in argument_specs:
        cli.add_argument(name, **options)

    return cli.parse_args()


def main():
    """Render the json report named on the command line as a sunburst chart.

    The report is loaded from the ``input`` argument. The tree stored under
    its ``symbols`` key is flattened, parent before children, into the
    parallel ``ids``/``labels``/``parents``/``values``/``hovertext`` lists
    that are handed to ``plotly.graph_objects.Sunburst``. The figure is then
    either written to the file given with ``--html`` or shown through the
    "browser" renderer.

    Exits via ``sys.exit`` with a message when plotly cannot be imported.
    """
    args = parse_args()

    # Imported after argument parsing so that parsing the command line
    # (including --help) does not require plotly to be installed.
    try:
        import plotly.graph_objects as go
    except ImportError:
        sys.exit("Missing dependency: You need to install plotly.")

    # Be explicit about the encoding instead of depending on the locale.
    with open(args.input, encoding='utf-8') as f:
        data = json.load(f)

    total_size = data.get('total_size')
    ids = []
    labels = []
    parents = []
    values = []
    hovertext = []
    # Same content as ``ids``, kept as a set so the uniqueness checks below
    # don't rescan the whole list for every node.
    seen_ids = set()

    def iter_node(node: dict, parent=''):
        """Append ``node`` and, recursively, its children to the chart lists.

        A node without an ``identifier`` is skipped together with its
        children. ``parent`` is the (already made unique) id of the parent
        node, or '' for the root.
        """
        identifier = node.get('identifier')
        if identifier is None:
            return

        if identifier in seen_ids:
            # Identifiers aren't unique, add a suffix to make them unique
            idx = 0
            while f'{identifier}_{idx}' in seen_ids:
                idx += 1
            identifier = f'{identifier}_{idx}'

        # A node without a size counts as 0 for both the chart value and the
        # percentage shown in the hover text.
        size = node.get('size', 0)

        seen_ids.add(identifier)
        ids.append(identifier)
        labels.append(node.get('name', ''))
        parents.append(parent)
        values.append(size)

        details = []
        if total_size:
            # Without a usable (non-zero) total there is nothing to relate
            # the size to, so the percentage line is left out.
            details.append(f'percentage: {size / total_size:.2%}')
        if 'address' in node:
            details.append(f'address: 0x{node.get("address"):08x}')
        if 'section' in node:
            details.append(f'section: {node.get("section")}')

        hovertext.append("<br>".join(details))

        for child in node.get('children', ()):
            iter_node(child, identifier)

    iter_node(data.get('symbols', {}))

    fig = go.Figure(
        go.Sunburst(
            ids=ids,
            labels=labels,
            parents=parents,
            values=values,
            hovertext=hovertext,
            branchvalues='total',
            maxdepth=args.depth,
        ),
        skip_invalid=True,
    )
    fig.update_layout(margin={'t': 0, 'l': 0, 'r': 0, 'b': 0})

    if args.html:
        fig.write_html(args.html, auto_open=False)
        return

    print("Opening the default browser to render the generated plot.")
    fig.show(renderer="browser")


# Only run the plotter when this file is executed as a script; importing it
# as a module just defines parse_args() and main().
if __name__ == "__main__":
    main()
