1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2024 Antmicro
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18# SPDX-License-Identifier: Apache-2.0
19
20from .ast import *
21from . import operators as op
22from .helper import Visitor, Hole
23from itertools import chain
24
class Parenthesis(Expr):
    """Expression node that prints its wrapped expression inside round brackets.

    It takes over the wrapped expression's type, so it can stand in for that
    expression anywhere in the tree.
    """

    def __init__(self, expr: Expr, **kwargs):
        # Forward the inner expression's type to the Expr base constructor.
        super().__init__(expr.type, **kwargs)
        self.expr = expr

    def children(self) -> op.Iterable[op.Node]:
        """Return the single wrapped expression as a one-element list."""
        wrapped = self.expr
        return [wrapped]

    def tokenize(self, cg: op.CodeGenerator) -> op.Iterable[str | op.CodeCC]:
        """Yield '(' , then the inner expression's tokens, then ')'."""
        # The inner tokenize() is called right away (not lazily), as before;
        # only the concatenation with the brackets is deferred.
        inner_tokens = self.expr.tokenize(cg)
        return chain(('(',), inner_tokens, (')',))
35
36class OrderOperators(Visitor):
37    def __init__(self, nodes: Node, verbose: bool = False) -> None:
38        super().__init__(nodes, verbose)
39
40    @staticmethod
41    def get_precedence(expr: Expr) -> int:
42        match expr:
43            case op.Mul() | op.Div(): return 13
44            case op.Add() | op.Sub(): return 12
45            case op.SHL() | op.SHR() | op.USHR(): return 11
46            case op.GT() | op.LT() | op.GTE() | op.LTE(): return 10
47            case op.EQ() | op.NEQ(): return 9
48            case op.AND(): return 8
49            case op.OR(): return 6
50            case op.LAND(): return 5
51            case op.LOR(): return 4
52            case op.Cond(): return 2
53            case BinaryOp(): return 0
54        return 18
55
56    @staticmethod
57    def m_parenthesize(expr: Expr) -> None:
58        hole = Hole()
59        expr.replace(hole)
60        parenthesis = Parenthesis(expr)
61        hole.replace(parenthesis)
62
63    @staticmethod
64    def m_process_Op(node: Expr) -> None:
65        for child in node.children():
66            if OrderOperators.get_precedence(node) > OrderOperators.get_precedence(child):
67                OrderOperators.m_parenthesize(child)
68
69    def visit_Cast(self, node: Cast) -> None:
70        if isinstance(node.expr, (BinaryOp, op.Cond)):
71            OrderOperators.m_parenthesize(node.expr)
72
73        self.iterate_children_dfs(node)
74
75    def visit_BinaryOp(self, node: BinaryOp) -> None:
76        OrderOperators.m_process_Op(node)
77        self.iterate_children_dfs(node)
78
79    def visit_Cond(self, node: op.Cond) -> None:
80        OrderOperators.m_process_Op(node)
81        self.iterate_children_dfs(node)
82