#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 Antmicro
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Placeholder AST nodes and a class-name-dispatched AST visitor."""

from collections import deque
from itertools import chain
from types import SimpleNamespace
from typing import Any, Callable, Iterable

from . import ast


class Hole(ast.Node):
    """Placeholder node marking a part of the AST that was never filled in.

    Tokenizing a tree that still contains a ``Hole`` always fails.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def tokenize(self, _: ast.CodeGenerator) -> Iterable[str | ast.CodeCC]:
        """Always raise ``RuntimeError``: a hole cannot be turned into code."""
        raise RuntimeError('Incomplete AST - Hole in AST')


class TemplateHole(ast.Node):
    """Optionally named placeholder node meant to be substituted later.

    ``name`` is forwarded to the ``ast.Node`` constructor as a keyword
    argument; ``TemplatedAST`` later reads it back as ``node.name``
    (presumably stored as an attribute by ``ast.Node`` - confirm there).
    """

    def __init__(self, name: str | None = None, **kwargs) -> None:
        super().__init__(name=name, **kwargs)

    def tokenize(self, _: ast.CodeGenerator) -> Iterable[str | ast.CodeCC]:
        """Always raise ``RuntimeError``: the template was not processed."""
        raise RuntimeError('Incomplete AST - template not processed')


class Visitor:
    """Base class for AST traversals dispatched on node class names.

    A subclass handles nodes of class ``Foo`` by defining a method named
    ``visit_Foo``.  When a node is visited, the handler of the closest class
    in its hierarchy (searched breadth-first through ``ast.Node``-derived
    bases) is called.  The default ``visit_Node`` handler simply descends
    into the node's children.

    Note that the traversal is started from the constructor, so subclasses
    must initialise any state their handlers use *before* calling
    ``super().__init__``.
    """

    def __init__(self, nodes: ast.Node, verbose: bool = False) -> None:
        """Build the dispatch table and immediately visit ``nodes``.

        Args:
            nodes: Root node of the tree to traverse.
            verbose: When true, print a header and one indented line per
                visited node.
        """
        # Maps a node class to the bound `visit_<ClassName>` method, for
        # `ast.Node` and every subclass of it that exists at this point.
        self.visitor_methods: dict[type, Callable[[Any], None]] = {}

        for cls in chain([ast.Node], Visitor.m_get_all_subclasses(ast.Node)):
            method = getattr(self, 'visit_' + cls.__name__, None)
            if method is not None:
                self.visitor_methods[cls] = method

        self.verbose = verbose
        self.depth = 0
        if self.verbose:
            print(f'Visitor {type(self)}:')

        self.visit(nodes)

    @staticmethod
    def m_get_all_subclasses(ty: type):
        """Yield every direct and indirect subclass of ``ty``, depth-first.

        A class reachable through several inheritance paths is yielded once
        per path.
        """
        for sub in ty.__subclasses__():
            yield sub
            yield from Visitor.m_get_all_subclasses(sub)

    @staticmethod
    def m_iterate_class_hierarchy(ty: type) -> Iterable[type]:
        """Yield ``ty`` and then its ``ast.Node``-derived bases, breadth-first.

        Bases that are not subclasses of ``ast.Node`` are neither yielded
        nor followed.
        """
        pending = deque([ty])
        while pending:
            current = pending.popleft()
            yield current
            for base in current.__bases__:
                if issubclass(base, ast.Node):
                    pending.append(base)

    def iterate_children_dfs(self, node: ast.Node) -> None:
        """Visit every child returned by ``node.children()`` in order."""
        for child in node.children():
            self.visit(child)

    def visit_Node(self, node: ast.Node) -> None:
        """Fallback handler for any ``ast.Node``: descend into its children."""
        self.iterate_children_dfs(node)

    def visit(self, node: ast.Node) -> None:
        """Dispatch ``node`` to the most specific registered handler.

        Nodes whose ``null`` attribute is truthy are skipped.

        Raises:
            RuntimeError: If no class in the node's hierarchy has a
                registered handler.
        """
        if node.null:
            return

        if self.verbose:
            print(' ' * (self.depth * 2) + f'* VISIT {type(node)}')

        self.depth += 1
        try:
            for cls in Visitor.m_iterate_class_hierarchy(type(node)):
                visitor = self.visitor_methods.get(cls)
                if visitor is not None:
                    visitor(node)
                    return
        finally:
            # Restore the depth even when a handler raises, so verbose
            # indentation stays correct if the caller recovers.
            self.depth -= 1

        # The three fragments are concatenated into a single message string.
        raise RuntimeError(
            f'No visitor found for type {type(node)}, '
            f'bases found: {list(Visitor.m_iterate_class_hierarchy(type(node)))}, '
            f'visitor keys: {self.visitor_methods.keys()}'
        )


class TemplatedAST(Visitor):
    """Collect the ``TemplateHole`` nodes of a tree into a namespace.

    After construction, ``template`` exposes each visited hole as an
    attribute named after the hole, plus an ``ast`` attribute holding the
    root node.  ``ast`` is assigned after the traversal, so a hole named
    ``ast`` would be replaced by the root.
    """

    def __init__(self, nodes: ast.Node, verbose: bool = False) -> None:
        # The namespace must exist before the base constructor runs, because
        # the base constructor performs the traversal that fills it.
        self.obj = SimpleNamespace()
        super().__init__(nodes, verbose)
        setattr(self.obj, 'ast', nodes)

    @property
    def template(self) -> Any:
        """Namespace holding the collected holes and the root ``ast``."""
        return self.obj

    def visit_TemplateHole(self, node: TemplateHole):
        """Record ``node`` under its name; its children are not visited."""
        # NOTE(review): `node.name` may be None (the TemplateHole default),
        # in which case setattr raises TypeError - confirm holes reaching
        # this point are always named.
        setattr(self.obj, node.name, node)