1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2024 Antmicro
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18# SPDX-License-Identifier: Apache-2.0
19
from collections import deque
from itertools import chain
from types import SimpleNamespace
from typing import Any, Callable, Iterable

from . import ast
25
class Hole(ast.Node):
    """Placeholder node marking a spot in the AST that was never filled in.

    It has to be replaced before code generation, because tokenizing it
    always raises.
    """

    def __init__(self, **node_kwargs):
        # Keyword arguments only; every one is forwarded unchanged to ast.Node.
        super().__init__(**node_kwargs)

    def tokenize(self, _: ast.CodeGenerator) -> ast.Iterable[str | ast.CodeCC]:
        """Refuse to emit code for an unfilled hole.

        Raises:
            RuntimeError: always.
        """
        message = 'Incomplete AST - Hole in AST'
        raise RuntimeError(message)
32
class TemplateHole(ast.Node):
    """Optionally named placeholder node left inside an AST template.

    The ``name`` is forwarded to ``ast.Node``. Tokenizing this node always
    raises, so it has to be substituted before code generation.
    """

    def __init__(self, name: str | None = None, **extra) -> None:
        # `name` is passed first, ahead of any other keyword arguments,
        # exactly as before.
        super().__init__(name=name, **extra)

    def tokenize(self, _: ast.CodeGenerator) -> ast.Iterable[str | ast.CodeCC]:
        """Refuse to emit code for a template hole that was never processed.

        Raises:
            RuntimeError: always.
        """
        message = 'Incomplete AST - template not processed'
        raise RuntimeError(message)
39
class Visitor:
    """Base class for AST passes that dispatch on the node's class.

    Subclasses define ``visit_<ClassName>`` methods. A node is handled by the
    method registered for its own class or, failing that, for the first
    ``ast.Node`` ancestor reached by a breadth-first walk over ``__bases__``.
    The default ``visit_Node`` recurses into the node's children.
    Constructing a visitor immediately traverses the given tree.
    """

    def __init__(self, nodes: ast.Node, verbose: bool = False) -> None:
        """Build the dispatch table and visit ``nodes`` right away.

        Args:
            nodes: Root node of the tree to traverse.
            verbose: When true, print a trace of every visited node type.
        """
        # Map ast.Node and each of its (transitive) subclasses to the bound
        # `visit_<ClassName>` method, when this visitor defines one. Only
        # subclasses that exist at construction time are discovered.
        self.visitor_methods: dict[type, Callable[[Any], None]] = {}
        for cls in chain([ast.Node], Visitor.m_get_all_subclasses(ast.Node)):
            handler = getattr(self, 'visit_' + cls.__name__, None)
            if handler is not None:
                self.visitor_methods[cls] = handler

        self.verbose = verbose
        # Current nesting level; only used to indent the verbose trace.
        self.depth = 0
        if self.verbose:
            print(f'Visitor {type(self)}:')

        self.visit(nodes)

    @staticmethod
    def m_get_all_subclasses(ty: type) -> Iterable[type]:
        """Yield every transitive subclass of ``ty`` in depth-first pre-order.

        A class reachable through several parents is yielded once per path.
        """
        for sub in ty.__subclasses__():
            yield sub
            yield from Visitor.m_get_all_subclasses(sub)

    @staticmethod
    def m_iterate_class_hierarchy(ty: type) -> Iterable[type]:
        """Yield ``ty`` and then its ``ast.Node`` ancestors, breadth-first.

        ``ty`` itself is always yielded first. Bases that are not ``ast.Node``
        subclasses (mixins, ``object``) are skipped. Each class is yielded at
        most once, even in diamond-shaped hierarchies. This keeps the same
        order of first occurrences as a plain BFS, so dispatch is unaffected.
        """
        queue = deque([ty])
        seen = {ty}
        while queue:
            current = queue.popleft()
            yield current
            for base in current.__bases__:
                if base not in seen and issubclass(base, ast.Node):
                    seen.add(base)
                    queue.append(base)

    def iterate_children_dfs(self, node: ast.Node) -> None:
        """Visit each child returned by ``node.children()``, in order."""
        for child in node.children():
            self.visit(child)

    def visit_Node(self, node: ast.Node) -> None:
        """Fallback handler for any ``ast.Node``: recurse into its children."""
        self.iterate_children_dfs(node)

    def visit(self, node: ast.Node) -> None:
        """Dispatch ``node`` to the most specific registered handler.

        Nodes whose ``null`` attribute is truthy are skipped entirely.

        Raises:
            RuntimeError: if no handler is registered for the node's class or
                for any of its ``ast.Node`` ancestors.
        """
        if node.null:
            return

        if self.verbose:
            print(' ' * (self.depth * 2) + f'* VISIT {type(node)}')

        self.depth += 1
        try:
            for cls in Visitor.m_iterate_class_hierarchy(type(node)):
                handler = self.visitor_methods.get(cls)
                if handler is not None:
                    handler(node)
                    return
        finally:
            # Restore the nesting level even when a handler raises, so an
            # exception caught further up does not skew later trace output.
            self.depth -= 1

        # Build a single message string. A stray comma here previously split
        # it into two exception arguments.
        raise RuntimeError(f'No visitor found for type {type(node)}, '
                           f'bases found: {list(Visitor.m_iterate_class_hierarchy(type(node)))}, '
                           f'visitor keys: {self.visitor_methods.keys()}')
101
class TemplatedAST(Visitor):
    """Gather the ``TemplateHole`` nodes of an AST into a namespace by name.

    After construction, ``template`` returns a ``SimpleNamespace``:
    - each visited hole is stored under an attribute equal to its ``name``,
      and a later hole with the same name replaces an earlier one;
    - the attribute ``ast`` holds the root node that was traversed.

    The children of a hole are not visited.
    """

    def __init__(self, nodes: ast.Node, verbose: bool = False) -> None:
        # The namespace must exist before Visitor.__init__, because the
        # traversal (and therefore visit_TemplateHole) runs inside it.
        self.obj = SimpleNamespace()
        super().__init__(nodes, verbose)
        # Set after the traversal, so it overrides any hole named 'ast'.
        self.obj.ast = nodes

    @property
    def template(self) -> Any:
        """Namespace holding the collected holes plus the root as ``ast``."""
        return self.obj

    def visit_TemplateHole(self, node: TemplateHole):
        # The attribute name comes from the hole itself, so it must be
        # dynamic. NOTE(review): a hole whose name is None makes setattr
        # raise TypeError.
        hole_name = node.name
        setattr(self.obj, hole_name, node)
114