1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2024 Antmicro
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18# SPDX-License-Identifier: Apache-2.0
19
20from systemrdl.node import RegNode, FieldNode
21from types import SimpleNamespace
22
23from .csharp import ast as ast
24from .csharp.helper import TemplatedAST, TemplateHole
25from .csharp import operators as op
26from .util import PascalCase, camelCase
27
# Short module-level aliases for the C# accessibility modifiers that the
# generators below attach to emitted classes, fields, properties and methods.
PUBLIC, PROTECTED, PRIVATE = (
    ast.AccessibilityMod.PUBLIC,
    ast.AccessibilityMod.PROTECTED,
    ast.AccessibilityMod.PRIVATE,
)
31
32class RegArray:
33    def __init__(self, name: str, register: RegNode, address: int):
34        self.name = name
35        self.register = register
36        self.addr = address
37
38    @property
39    def count(self) -> int:
40        return self.register.array_dimensions[0]
41
42    @property
43    def stride(self) -> int:
44        return self.register.array_stride
45
46    @staticmethod
47    def m_get_underlying_field_type(field: FieldNode) -> ast.Type:
48        match field.high - field.low + 1:
49            case width if width == 1: return ast.Type.bool
50            case width if width in range(2, 9): return ast.Type.byte
51            case width if width in range(9, 17): return ast.Type.ushort
52            case width if width in range(17, 33): return ast.Type.uint
53            case width if width in range(33, 65): return ast.Type.ulong
54            case _: raise RuntimeError(f'The field `{field.inst_name}` is too wide')
55
56    @staticmethod
57    def m_cast_to_field_type(ty: ast.Type, expr: ast.Expr) -> ast.Expr:
58        if ty == expr.type:
59            return expr
60
61        if ty == ast.Type.bool:
62            return op.NEQ(expr, ast.IntLit(0, expr.type.is_unsigned, expr.type.is_long))
63
64        return ast.Cast(ty, expr)
65
66    @staticmethod
67    def m_get_underlying_field_mask(field: FieldNode) -> ast.IntLit:
68        return ast.IntLit(1 << (field.high - field.low + 1) - 1, fmt='h')
69
70    @staticmethod
71    def m_generate_underlying_field_decl(field: FieldNode) -> ast.VariableDecl:
72        return ast.VariableDecl(
73            name = camelCase(field.inst_name),
74            ty = RegArray.m_get_underlying_field_type(field)
75        )
76
77    @staticmethod
78    def m_generate_underlying_property(field: FieldNode) -> ast.VariableDecl:
79        width = field.high - field.low + 1
80        bytes_to_access = (width + 7) // 8
81        first_byte = field.low // 8
82        shift = field.low % 8
83
84        field_type = RegArray.m_get_underlying_field_type(field)
85
86        def idx_byte(offset: int) -> ast.HardExpr:
87            return ast.HardExpr(f'memory[spanBegin + {first_byte + offset}]', ast.Type.byte)
88
89        def byte_mask(byte_idx) -> int:
90            return ((((1 << (width)) - 1) << shift) >> (byte_idx * 8)) & 0xff
91
92        get_tempvar = ast.VariableDecl(
93            'temp',
94            field_type,
95            init = ast.IntLit(0) if field_type is not ast.Type.bool else None
96        )
97        set_value = ast.VariableDecl('value', field_type)
98
99        def generate_getter_assignmnets() -> ast.Stmt:
100            if field_type == ast.Type.bool:
101                return ast.Assign(
102                    lhs = get_tempvar.ref(),
103                    rhs = op.NEQ(
104                        lhs = op.AND(
105                            lhs = idx_byte(0),
106                            rhs = ast.IntLit(byte_mask(0), fmt='h')
107                        ),
108                        rhs = ast.IntLit(0)
109                    )
110                ).into_stmt()
111
112            return ast.Node.join(
113                # for idx in range(bytes_to_access)
114                ast.Assign(
115                    lhs = get_tempvar.ref(),
116                    rhs = ast.Cast(field_type, op.OR(
117                        lhs = get_tempvar.ref(),
118                        rhs = # if idx == 0
119                            op.SHR(
120                                lhs = ast.Cast(field_type, op.AND(
121                                    lhs = idx_byte(idx),
122                                    rhs = ast.IntLit(byte_mask(idx), fmt='h')
123                                )),
124                                rhs = ast.IntLit(shift)
125                            )
126                            if idx == 0 else
127                            op.SHL(
128                                lhs = ast.Cast(field_type, op.AND(
129                                    lhs = idx_byte(idx),
130                                    rhs = ast.IntLit(byte_mask(idx), fmt='h')
131                                )),
132                                rhs = ast.IntLit(((idx - 1) * 8) + (8 - shift))
133                            )
134                    )
135                )).into_stmt()
136                for idx in range(bytes_to_access)
137            )
138
139        def generate_setter_assignments() -> ast.Stmt:
140            def get_field_value_masked(mask: int) -> ast.Expr:
141                if field_type == ast.Type.bool:
142                    return op.Cond(
143                        cond = set_value.ref(),
144                        then_ = ast.IntLit(1, unsigned=True),
145                        else_ = ast.IntLit(0, unsigned=True)
146                    )
147                return op.AND(
148                    lhs = set_value.ref(),
149                    rhs = ast.IntLit(mask, unsigned = True, fmt='h')
150                )
151
152            return ast.Node.join(
153                # for idx in range(bytes_to_access)
154                ast.Assign(
155                    lhs = idx_byte(idx),
156                    rhs = ast.Cast(ast.Type.byte, op.OR(
157                        lhs = op.AND(
158                            lhs = idx_byte(idx),
159                            rhs = ast.IntLit(0xff - byte_mask(idx), unsigned=True, fmt='h')
160                        ),
161                        rhs = # if idx == 0
162                            op.SHL(
163                                lhs = get_field_value_masked((1 << width) - 1),
164                                rhs = ast.IntLit(shift)
165                            )
166                            if idx == 0 else
167                            op.SHR(
168                                lhs = get_field_value_masked((1 << width) - 1),
169                                rhs = ast.IntLit(((idx - 1) * 8) + (8 - shift))
170                            )
171                    ))
172                ).into_stmt()
173                for idx in range(bytes_to_access)
174            )
175
176        return ast.PropertyDefintion(
177            name = field.inst_name.upper(),
178            access = PUBLIC,
179            doc = f'Offset: {hex(field.low)}, Width: {width} bits',
180            get = ast.Node.join([
181                get_tempvar,
182                generate_getter_assignmnets(),
183                ast.Return(get_tempvar.ref())
184            ]),
185            set = generate_setter_assignments(),
186            ret_ty = field_type
187        )
188
189    def generate_csharp_wrapper_type(self) -> ast.Class:
190        class_name = PascalCase(self.name) + '_' + PascalCase(self.register.inst_name) + 'Wrapper'
191
192        return ast.Class(
193            name = class_name,
194            struct = False,
195            access = PUBLIC,
196            fields = ast.Node.join([
197                ast.VariableDecl('spanBegin', ast.Type.long, access=PRIVATE),
198                ast.VariableDecl('memory', ast.Type.byte.array(), access=PRIVATE)
199            ]),
200            properties = ast.Node.join(
201                RegArray.m_generate_underlying_property(field)
202                for field in self.register.fields()
203            ),
204            methods = ast.MethodDefinition(
205                name = class_name,
206                constructor = True,
207                access = PUBLIC,
208                args = ast.ArgDecl('memory', ast.Type.byte.array())
209                    .then(ast.ArgDecl('spanBegin', ast.Type.long)),
210                body = ast.Node.join([
211                    ast.Assign(
212                        lhs = ast.HardExpr('this.memory', ty=ast.Type.byte.array()),
213                        rhs = ast.HardExpr('memory', ty=ast.Type.byte.array())
214                    ).into_stmt(),
215                    ast.Assign(
216                        lhs = ast.HardExpr('this.spanBegin', ast.Type.long),
217                        rhs = ast.HardExpr('spanBegin', ast.Type.long)
218                    ).into_stmt(),
219                ])
220            )
221        )
222
223    def generate_csharp_container_type(self) -> ast.Class:
224        class_name = \
225            PascalCase(self.name) + '_' + PascalCase(self.register.inst_name) + 'Container'
226
227        wrapper_type = self.generate_csharp_wrapper_type()
228
229        return ast.Class(
230            name = class_name,
231            access = PROTECTED,
232            fields = ast.VariableDecl('memory', ast.Type.byte.array(), access=PRIVATE),
233            properties = ast.Node.join([
234                ast.PropertyDefintion(
235                    name = 'Size',
236                    access = PUBLIC,
237                    ret_ty = ast.Type.long,
238                    get = ast.Return(ast.IntLit(self.count * self.stride, long=True))
239                ),
240                ast.PropertyDefintion(
241                    name = 'this[long index]',
242                    access = PUBLIC,
243                    ret_ty = wrapper_type.type,
244                    get = ast.Node.join([
245                        ast.If(
246                            condition = op.LOR(
247                                lhs = op.LT(
248                                    lhs = ast.HardExpr('index', ty=ast.Type.long),
249                                    rhs = ast.IntLit(0)
250                                ),
251                                rhs = op.GTE(
252                                    lhs = ast.HardExpr('index', ty=ast.Type.long),
253                                    rhs = ast.IntLit(self.count)
254                                ),
255                            ),
256                            then = ast.Throw(
257                                ast.New(
258                                    ast.Type('System.IndexOutOfRangeException')
259                                )
260                            )
261                        ),
262                        ast.Return(ast.New(
263                            wrapper_type.type,
264                            ast.Arg('memory'),
265                            ast.Arg(f'index * {self.stride}')
266                        ))
267                    ])
268                )
269            ]),
270            methods = ast.Node.join([
271                ast.MethodDefinition(
272                    name = class_name,
273                    access = PUBLIC,
274                    constructor = True,
275                    body = ast.Assign(
276                        lhs = ast.HardExpr('memory', ast.Type.byte.array()),
277                        rhs = ast.NewArray(ast.Type.byte, 'Size')
278                    ).into_stmt()
279                ),
280                ast.MethodDefinition(
281                    name = 'ReadDoubleWord',
282                    access = PUBLIC,
283                    args = ast.ArgDecl('offset', ast.Type.long),
284                    body = ast.Return(ast.HardExpr(
285                        f'(uint)memory[offset] + '
286                        f'((uint)memory[offset + 1] << 8) + ' +
287                        f'((uint)memory[offset + 2] << 16) + ' +
288                        f'((uint)memory[offset + 3] << 24)',
289                        ast.Type.long
290                    )),
291                    ret_ty = ast.Type.uint
292                ),
293                ast.MethodDefinition(
294                    name = 'WriteDoubleWord',
295                    access = PUBLIC,
296                    args = ast.Node.join([
297                        ast.ArgDecl('offset', ast.Type.long),
298                        ast.ArgDecl('value', ast.Type.uint)
299                    ]),
300                    body = ast.Node.join([
301                        ast.Assign(
302                            lhs = ast.HardExpr('memory[offset]', ast.Type.byte),
303                            rhs = ast.Cast(ast.Type.byte,
304                                           ast.HardExpr('value', ast.Type.uint))
305                        ).into_stmt(),
306                        ast.Assign(
307                            lhs = ast.HardExpr('memory[offset + 1]', ast.Type.byte),
308                            rhs = ast.Cast(ast.Type.byte,
309                                           ast.HardExpr('(value >> 8)', ast.Type.uint))
310                        ).into_stmt(),
311                        ast.Assign(
312                            lhs = ast.HardExpr('memory[offset + 2]', ast.Type.byte),
313                            rhs = ast.Cast(ast.Type.byte,
314                                           ast.HardExpr('(value >> 16)', ast.Type.uint))
315                        ).into_stmt(),
316                        ast.Assign(
317                            lhs = ast.HardExpr('memory[offset + 3]', ast.Type.byte),
318                            rhs = ast.Cast(ast.Type.byte,
319                                           ast.HardExpr('(value >> 24)', ast.Type.uint))
320                        ).into_stmt()
321                    ])
322                )
323            ]),
324            classes = wrapper_type
325        )
326
327    def m_generate_conditional_access(
328            self,
329            me: ast.VariableDecl,
330            offset_var: ast.VariableDecl
331        ) -> SimpleNamespace:
332        return TemplatedAST(
333            ast.If(
334                condition = op.LAND(
335                    lhs = op.GTE(
336                        lhs = offset_var.ref(),
337                        rhs = ast.IntLit(self.addr)
338                    ),
339                    rhs = op.LT(
340                        lhs = offset_var.ref(),
341                        rhs = op.Add(
342                            lhs = ast.IntLit(self.addr, long=True),
343                            rhs = ast.HardExpr(f'{me.ref()}.Size', ty='long')
344                        )
345                    )
346                ),
347                then = TemplateHole('then')
348            )
349        ).template
350
351    def generate_dword_read_logic(
352        self,
353        me: ast.VariableDecl,
354        offset_var: ast.VariableDecl
355    ) -> ast.Stmt:
356        template = self.m_generate_conditional_access(me, offset_var)
357        template.then.replace(ast.Return(
358            ast.Call(
359                'ReadDoubleWord',
360                ast.Arg(f'{offset_var.ref()} - {self.addr}'),
361                object = me.ref()
362            )
363        ))
364
365        return template.ast
366
367    def generate_dword_write_logic(
368        self,
369        me: ast.VariableDecl,
370        offset_var: ast.VariableDecl,
371        value_var: ast.VariableDecl
372    ) -> ast.Stmt:
373        template = self.m_generate_conditional_access(me, offset_var)
374        template.then.replace(ast.Node.join([
375            ast.Call(
376                'WriteDoubleWord',
377                ast.Arg(f'{offset_var.ref()} - {self.addr}'),
378                ast.Arg(str(value_var.ref())),
379                object=me.ref()
380            ).into_stmt(),
381            ast.Return()
382        ]))
383
384        return template.ast
385