1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2024 Antmicro
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18# SPDX-License-Identifier: Apache-2.0
19
20from typing import Union, Optional
21
22from systemrdl.node import FieldNode, RegNode, RootNode, AddrmapNode
23from systemrdl.rdltypes import OnReadType, OnWriteType
24from itertools import chain
25from functools import reduce
26
27from .csharp import ast as ast
28from .util import PascalCase
29from .scanner import ScannedState, RdlDesignScanner
30from .csharp.process import process_ast
31
# Short aliases for the C# accessibility modifiers used when building the AST.
PUBLIC, PROTECTED, PRIVATE = (
    ast.AccessibilityMod.PUBLIC,
    ast.AccessibilityMod.PROTECTED,
    ast.AccessibilityMod.PRIVATE,
)
35
36class CSharpGenerator:
37    ty_IFlagRegisterField = ast.Type('IFlagRegisterField')
38    ty_IValueRegisterField = ast.Type('IValueRegisterField')
39
    def __init__(self, scanned: ScannedState, name: Optional[str], namespace: str,
                 make_all_public: bool = False) -> None:
        """Build the C# AST of a peripheral from a scanned RDL design.

        The whole AST is constructed eagerly here; `generate_code` only
        emits it.

        Args:
            scanned: Scanner output. `registers`, `register_arrays` and
                `top_name` are read here; `resets` is read while the
                per-register classes are generated.
            name: Name of the generated peripheral class. When `None`,
                `scanned.top_name` is used instead.
            namespace: Name of the namespace that directly contains the
                peripheral class. It is nested inside
                `Antmicro` / `Renode` / `Peripherals`.
            make_all_public: Forwarded unchanged to `process_ast`.
        """
        self.scanned = scanned
        self.name = name if name is not None else scanned.top_name


        # These attributes must exist before `reg_classes` is built below:
        # `generate_value_container_class` reads `ty_register`,
        # `iprovides_register_collection` and `scanned`.
        self.ty_register = ast.Type('DoubleWordRegister')
        self.ty_register_collection = ast.Type('DoubleWordRegisterCollection')

        # Used both as the type of the `parent` constructor argument of the
        # per-register classes and as a base of the peripheral class.
        self.iprovides_register_collection = ast.Class(
            name = f'IProvidesRegisterCollection<{self.ty_register_collection}>'
        )

        # One value-container class per scanned register, keyed by the
        # register's instance name.
        self.reg_classes = {
            reg.inst_name: self.generate_value_container_class(reg)
            for reg in scanned.registers
        }

        # Container type of every register array, as produced by the array's
        # own `generate_csharp_container_type()`, keyed by array name.
        regarray_containers = {reg_arr.name: reg_arr.generate_csharp_container_type() \
                            for reg_arr in scanned.register_arrays}

        # Protected member declaration holding each register array container.
        regarray_fields = {
            ra.name: ast.VariableDecl(
                name = PascalCase(ra.name),
                ty = regarray_containers[ra.name].type,
                access = PROTECTED,
                doc = f'Memory "{ra.name}" at {hex(ra.addr)}'
            )
            for ra in scanned.register_arrays
        }

        # Statement assigning to member `c` a new instance of its type,
        # constructed with `this` as the only argument.
        def make_reg_instance_assignement(c: ast.VariableDecl):
            return ast.Assign(
                lhs = c.ref(),
                rhs = ast.New(c.type, ast.Arg(ast.This()))
            ).into_stmt()

        # Statement assigning to member `c` a new instance of its type,
        # constructed without arguments.
        def make_regarray_assignement(c: ast.VariableDecl):
            return ast.Assign(
                lhs = c.ref(),
                rhs = ast.New(c.type)
            ).into_stmt()

        # Explicit implementation of `IDoubleWordPeripheral.ReadDoubleWord`.
        def generate_read_method() -> ast.MethodDefinition:
            offset_arg = ast.ArgDecl(name='offset', ty=ast.Type.long)

            return ast.MethodDefinition(
                name = 'IDoubleWordPeripheral.ReadDoubleWord',
                args = offset_arg,
                body = ast.Node.join(
                    # The read logic of every register array comes first ...
                    [
                        regarray.generate_dword_read_logic(
                            regarray_fields[regarray.name],
                            offset_arg
                        )
                        for regarray in scanned.register_arrays
                    ] + [
                        # ... followed by a return of `Read(offset)` called on
                        # `RegistersCollection`.
                        ast.Return(ast.Call(
                            ast.MethodDefinition(name='Read'),
                            ast.Arg(ast.HardCode('offset')),
                            object = ast.HardCode('RegistersCollection')
                        ))
                    ]
                ),
                ret_ty = ast.Type.uint
            )

        # Explicit implementation of `IDoubleWordPeripheral.WriteDoubleWord`.
        def generate_write_method() -> ast.MethodDefinition:
            offset_arg = ast.ArgDecl(name='offset', ty=ast.Type.long)
            value_arg = ast.ArgDecl(name='value', ty=ast.Type.uint)

            return ast.MethodDefinition(
                name = 'IDoubleWordPeripheral.WriteDoubleWord',
                args = ast.Node.join([offset_arg, value_arg]),
                body = ast.Node.join(
                    # The write logic of every register array comes first ...
                    [
                        regarray.generate_dword_write_logic(
                            regarray_fields[regarray.name],
                            offset_arg,
                            value_arg
                        )
                        for regarray in scanned.register_arrays
                    ] + [
                        # ... followed by `Write(offset, value)` called on
                        # `RegistersCollection`.
                        ast.Call(
                            'Write',
                            ast.Arg(ast.HardCode('offset')),
                            ast.Arg(ast.HardCode('value')),
                            object=ast.HardCode('RegistersCollection')
                        ).into_stmt()
                    ]
                ),
            )

        # Protected member of the peripheral class for every scanned
        # register, typed with the matching class from `reg_classes`.
        reg_instances = [
            ast.VariableDecl(
                name = PascalCase(reg.inst_name),
                ty = self.reg_classes[reg.inst_name].type,
                access = PROTECTED,
                doc = f'Register "{reg.inst_name}" at {hex(reg.absolute_address)}'
            )
            for reg in self.scanned.registers
        ]

        # Partial methods; no body is supplied here.
        # NOTE(review): presumably implemented by hand-written code in
        # another part of the partial class - confirm.
        init_method = ast.MethodDefinition(name='Init', partial=True)
        reset_method = ast.MethodDefinition(name='Reset', partial=True)

        self.peripheral_class = ast.Class(
            name = self.name,
            access = PUBLIC,
            fields = ast.Node.join(chain(reg_instances, regarray_fields.values())),
            properties = ast.PropertyDefintion(
                name = 'RegistersCollection',
                access = PUBLIC,
                ret_ty = self.ty_register_collection,
                get = True
            ),
            methods = ast.Node.join([
                # Constructor: create the register collection, then the
                # register instances, then the register array containers,
                # and finally call the partial `Init`.
                ast.MethodDefinition(
                    name = self.name,
                    constructor = True,
                    access = PUBLIC,
                    body = ast.Node.join([
                        ast.Assign(
                            lhs = ast.HardExpr('RegistersCollection', self.ty_register_collection),
                            rhs = ast.New(self.ty_register_collection, ast.This())
                        ).into_stmt(),
                        *map(make_reg_instance_assignement, reg_instances),
                        *(make_regarray_assignement(container_member)
                          for container_member in regarray_fields.values()),
                        ast.Call(init_method, object=ast.This()).into_stmt()
                    ])
                ),
                init_method,
                reset_method,
                # `IPeripheral.Reset` calls the partial `Reset` first and then
                # `Reset` on `RegistersCollection`.
                ast.MethodDefinition(name='IPeripheral.Reset', body = ast.Node.join([
                    ast.Call(reset_method, object=ast.This()).into_stmt(),
                    ast.Call('Reset', object=ast.HardCode('RegistersCollection')).into_stmt()
                ])),
                generate_read_method(),
                generate_write_method()
            ]),
            # Nested classes: per-register classes and register array
            # containers.
            classes = ast.Node.join(chain(
                self.reg_classes.values(),
                regarray_containers.values()
            )),
            derives = [
                (None, self.iprovides_register_collection),
                (None, ast.Class(name = "IPeripheral")),
                (None, ast.Class(name = "IDoubleWordPeripheral"))
            ],
            partial = True
        )

        self.namespace = ast.Namespace(namespace, classes=self.peripheral_class)

        # Full namespace hierarchy around `self.namespace`; this is what gets
        # handed to `process_ast`.
        self.root = ast.Namespace('Antmicro', namespaces=[
            ast.Namespace('Renode', namespaces=[
                ast.Namespace('Peripherals', namespaces=[
                    self.namespace
                ])
            ])
        ])

        process_ast(self.root, make_all_public=make_all_public)
204
205    def generate_code(self) -> str:
206        code = \
207            '// Generated by PeakRDL-renode\n\n' + \
208            'using Antmicro.Renode.Core.Structure.Registers;\n' + \
209            'using Antmicro.Renode.Peripherals.Bus;\n\n' + \
210            ast.CodeGenerator.emit(self.namespace, docs=True)
211
212        return '\n'.join(line.rstrip() for line in code.splitlines()) + '\n'
213
214    def generate_field_modifier(self, field: FieldNode) -> list[ast.Arg]:
215        onread = field.get_property('onread')
216        onwrite = field.get_property('onwrite')
217
218        match onread:
219            case OnReadType.rclr: read_flag = 'FieldMode.ReadToClear'
220            case OnReadType.rset: read_flag = 'FieldMode.ReadToSet'
221            case OnReadType.ruser:
222                read_flag = 'FieldMode.Read'
223            case _: read_flag = 'FieldMode.Read' if field.is_sw_readable else None
224
225        match onwrite:
226            case OnWriteType.woset: write_flag = 'FieldMode.Set'
227            case OnWriteType.woclr: write_flag = 'FieldMode.WriteOneToClear'
228            case OnWriteType.wot: write_flag = 'FieldMode.Toggle'
229            case OnWriteType.wzs: write_flag = 'FieldMode.WriteZeroToSet'
230            case OnWriteType.wzc: write_flag = 'FieldMode.WriteZeroToClear'
231            case OnWriteType.wzt: write_flag = 'FieldMode.WriteZeroToToggle'
232            case OnWriteType.wclr: write_flag = 'FieldMode.WriteToClearAll'
233            case OnWriteType.wset: write_flag = 'FieldMode.WriteToSetAll'
234            case OnWriteType.wuser:
235                write_flag = 'FieldMode.Write'
236            case _: write_flag = 'FieldMode.Write' if field.is_sw_writable else None
237
238        match (read_flag, write_flag):
239            case ('FieldMode.Read', 'FieldMode.Write'): return []
240            case (str(f), None) | (None, str(f)): return [ast.Arg(f, name = 'mode')]
241            case (str(rd), str(wr)): return [ast.Arg(rd + ' | ' + wr, name = 'mode')]
242            case (None, None): raise RuntimeError('Can\'t calculate field access flags')
243
244    def generate_field_decl(self, field: FieldNode,
245                            underlying_var: Optional[str] = None) -> ast.Call:
246        field_width = field.high - field.low + 1
247        field_name_arg = ast.StringLit(field.inst_name.upper())
248
249        match (field_width == 1, underlying_var):
250            case (True, str(out_var)): return ast.Call(
251                'WithFlag',
252                ast.Arg(field.low),
253                ast.Arg(out_var, out=True),
254                *self.generate_field_modifier(field),
255                ast.Arg(field_name_arg, name='name'),
256                ret_ty=self.ty_register
257            )
258            case (True, None): return ast.Call(
259                'WithTaggedFlag',
260                ast.Arg(field_name_arg),
261                ast.Arg(field.position),
262                ret_ty=self.ty_register
263            )
264            case (False, out_var): return ast.Call(
265                'WithValueField',
266                ast.Arg(field.low),
267                ast.Arg(field_width),
268                *([ast.Arg(out_var, out=True)] if type(out_var) is str else []),
269                *self.generate_field_modifier(field),
270                ast.Arg(field_name_arg, name='name'),
271                ret_ty=self.ty_register
272            )
273            case _: raise RuntimeError('Unhandled field configuration')
274
275    def register_class_name(self, register: RegNode) -> str:
276        return PascalCase(register.inst_name) + 'Type'
277
278    def generate_value_container_class(
279        self,
280        register: RegNode
281    ) -> ast.Class:
282
283        def make_var_decl(field: FieldNode):
284            field_width = field.high - field.low + 1
285            return ast.VariableDecl(
286                name = field.inst_name.upper(),
287                ty = self.ty_IFlagRegisterField if field_width == 1
288                     else self.ty_IValueRegisterField,
289                access = PUBLIC,
290                doc = f'Field "{field.inst_name}" at {hex(field.low)}, ' +
291                      f'width: {field.high - field.low + 1} bits'
292            )
293
294        def add_field_impl(obj: ast.Node, field: FieldNode):
295            call = self.generate_field_decl(field, field.inst_name.upper())
296            call.object = obj
297            call.breakline = True
298            return call
299
300        name = self.register_class_name(register)
301
302        methods = [
303            ast.MethodDefinition(
304                constructor = True,
305                access = PUBLIC,
306                name = name,
307                args = ast.ArgDecl(
308                    name = 'parent',
309                    ty = self.iprovides_register_collection.type
310                ),
311                body = reduce(add_field_impl, register.fields(),
312                    ast.Call(
313                        'DefineRegister',
314                        ast.Arg(register.absolute_address),
315                        ast.Arg(self.scanned.resets[register.inst_name]),
316                        ast.Arg(ast.BoolLit(True)),
317                        object = ast.HardCode('parent.RegistersCollection'),
318                        ret_ty=self.ty_register
319                    )
320                ).into_stmt()
321            )
322        ]
323
324        return ast.Class(
325            name = name,
326            access = PUBLIC,
327            fields = ast.Node.join(map(make_var_decl, register.fields())),
328            methods = ast.Node.join(methods),
329            struct = True
330        )
331
332    def generate_value_container_instance(self, register: RegNode) -> ast.VariableDecl:
333        return ast.VariableDecl(
334            PascalCase(register.inst_name),
335            self.reg_classes[register.inst_name].type
336        )
337
338    @staticmethod
339    def add_indents(text: str, base_indent: int) -> str:
340        text = str(text)
341
342        def add_indent(line):
343            if line.strip() == '':
344                return ''
345            return '    ' * base_indent + line
346
347        match text.splitlines():
348            case [head, *tail]:
349                indented = '\n'.join(add_indent(line) for line in tail)
350                return head + '\n' + indented
351            case _: return text
352
class CSharpExporter:
    """Writes the C# peripheral generated from an RDL design to a file."""

    def export(self, node: Union[RootNode, AddrmapNode], path: str,
               name: Optional[str], namespace: str, all_public: bool) -> None:
        """Scan *node*, generate the C# code and write it to *path*.

        Args:
            node: Design to export. For a `RootNode` its `top` node is
                scanned; an `AddrmapNode` is scanned directly.
            path: Destination file; it is created or overwritten.
            name: Name of the generated peripheral class, passed to
                `CSharpGenerator` (which substitutes the scanned top name
                for `None`).
            namespace: Namespace name passed to `CSharpGenerator`.
            all_public: Passed to `CSharpGenerator` as `make_all_public`.
        """
        top_node = node.top if isinstance(node, RootNode) else node

        scanned = RdlDesignScanner(top_node).run()
        csharp = CSharpGenerator(
            scanned = scanned,
            name = name,
            namespace = namespace,
            make_all_public = all_public
        ).generate_code()

        # Write UTF-8 explicitly so the output does not depend on the
        # platform's default encoding.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(csharp)
368