1# Copyright (c) 2024 Vestas Wind Systems A/S
2#
3# SPDX-License-Identifier: Apache-2.0
4
5"""
6Zephyr CAN shell module support for providing a python-can bus interface for testing.
7"""
8
9import logging
10import re
11
12from can import BusABC, CanProtocol, Message
13from can.exceptions import CanInitializationError, CanOperationError
14from can.typechecking import CanFilters
15from twister_harness import DeviceAdapter, Shell
16
# Module logger; the bus below logs every sent and received frame at DEBUG level.
logger = logging.getLogger(__name__)
18
19
class CanShellBus(BusABC):  # pylint: disable=abstract-method
    """
    A CAN interface using the Zephyr CAN shell module.

    Frames are transmitted with ``can send`` shell commands and received by
    parsing the frame lines printed on the device console. RX filters are
    installed on the device with ``can filter add``/``can filter remove``.
    """

    def __init__(
        self,
        dut: DeviceAdapter,
        shell: Shell,
        channel: str,
        can_filters: CanFilters | None = None,
        **kwargs,
    ) -> None:
        """
        Configure and start the CAN controller ``channel`` on the device under test.

        :param dut: device adapter whose console output is read for TX
            confirmations and received frames.
        :param shell: shell used to execute the ``can`` shell commands.
        :param channel: CAN device name as passed to the ``can`` shell commands.
        :param can_filters: optional RX filters to install on the device.
        :param kwargs: forwarded to :class:`can.BusABC`.
        :raises CanOperationError: if the capabilities cannot be read or the
            mode cannot be set.
        :raises CanInitializationError: if the controller fails to start.
        """
        self._dut = dut
        self._shell = shell
        self._device = channel
        # True while user-supplied (not accept-all) filters are installed
        self._is_filtered = False
        # IDs of the filters this bus has installed on the device
        self._filter_ids: list[int] = []

        self.channel_info = f'Zephyr CAN shell, device "{self._device}"'

        # Enable CAN FD only when the controller advertises the capability
        mode = 'normal'
        if 'fd' in self._get_capabilities():
            self._can_protocol = CanProtocol.CAN_FD
            mode += ' fd'
        else:
            self._can_protocol = CanProtocol.CAN_20

        self._set_mode(mode)
        self._start()

        try:
            # BusABC.__init__ is expected to install can_filters through
            # set_filters() -> _apply_filters(), which talks to the device.
            super().__init__(channel=channel, can_filters=can_filters, **kwargs)
        except Exception:
            # Construction failed, so shutdown() will never run: do not leave
            # the controller started on the device.
            self._stop()
            raise

    def _retval(self) -> int:
        """
        Get return value of last shell command.

        :raises CanOperationError: if the ``retval`` output is empty or not an integer.
        """
        output = self._shell.get_filtered_output(self._shell.exec_command('retval'))
        try:
            return int(output[0])
        except (IndexError, ValueError) as err:
            raise CanOperationError(f'failed to read shell return value (output: {output!r})') from err

    def _get_capabilities(self) -> list[str]:
        """
        Return the capability keywords reported by ``can show``.

        :raises CanOperationError: if no ``capabilities:`` line is found.
        """
        cmd = f'can show {self._device}'

        lines = self._shell.get_filtered_output(self._shell.exec_command(cmd))
        regex_compiled = re.compile(r'capabilities:\s+(?P<caps>.*)')
        for line in lines:
            m = regex_compiled.match(line)
            if m:
                return m.group('caps').split()

        raise CanOperationError('capabilities not found')

    def _set_mode(self, mode: str) -> None:
        """
        Set the controller mode (e.g. ``normal`` or ``normal fd``).

        :raises CanOperationError: if the shell command reports an error.
        """
        self._shell.exec_command(f'can mode {self._device} {mode}')
        retval = self._retval()
        if retval != 0:
            raise CanOperationError(f'failed to set mode "{mode}" (err {retval})')

    def _start(self) -> None:
        """
        Start the controller.

        :raises CanInitializationError: if the shell command reports an error.
        """
        self._shell.exec_command(f'can start {self._device}')
        retval = self._retval()
        if retval != 0:
            raise CanInitializationError(f'failed to start (err {retval})')

    def _stop(self) -> None:
        """Stop the controller; the command's return value is not checked."""
        self._shell.exec_command(f'can stop {self._device}')

    def send(self, msg: Message, timeout: float | None = None) -> None:
        """
        Transmit ``msg`` with ``can send`` and wait for the TX-complete report.

        :param msg: frame to send; ID type, RTR, FD and BRS flags map to the
            ``-e``, ``-r``, ``-f`` and ``-b`` options.
        :param timeout: forwarded to ``DeviceAdapter.readlines_until`` while
            waiting for the "successfully sent" line.
        :raises CanOperationError: if the shell does not report the frame as enqueued.
        """
        logger.debug('sending: %s', msg)

        cmd = f'can send {self._device}'
        cmd += ' -e' if msg.is_extended_id else ''
        cmd += ' -r' if msg.is_remote_frame else ''
        cmd += ' -f' if msg.is_fd else ''
        cmd += ' -b' if msg.bitrate_switch else ''

        if msg.is_extended_id:
            cmd += f' {msg.arbitration_id:08x}'
        else:
            cmd += f' {msg.arbitration_id:03x}'

        if msg.data:
            cmd += ' ' + msg.data.hex(' ', 1)

        lines = self._shell.exec_command(cmd)
        regex_compiled = re.compile(r'enqueuing\s+CAN\s+frame\s+#(?P<id>\d+)')
        frame_num = None
        for line in lines:
            m = regex_compiled.match(line)
            if m:
                frame_num = m.group('id')
                break

        if frame_num is None:
            raise CanOperationError('frame not enqueued')

        # The shell reports TX completion asynchronously, keyed by frame number
        tx_regex = r'CAN\s+frame\s+#' + frame_num + r'\s+successfully\s+sent'
        self._dut.readlines_until(regex=tx_regex, timeout=timeout)

    def _add_filter(self, can_id: int, can_mask: int, extended: bool) -> None:
        """
        Add RX filter and remember the filter ID reported by the shell.

        :raises CanOperationError: if the shell does not report a filter ID.
        """
        cmd = f'can filter add {self._device}'
        cmd += ' -e' if extended else ''

        if extended:
            cmd += f' {can_id:08x}'
            cmd += f' {can_mask:08x}'
        else:
            cmd += f' {can_id:03x}'
            cmd += f' {can_mask:03x}'

        lines = self._shell.exec_command(cmd)
        regex_compiled = re.compile(r'filter\s+ID:\s+(?P<id>\d+)')
        for line in lines:
            m = regex_compiled.match(line)
            if m:
                filter_id = int(m.group('id'))
                self._filter_ids.append(filter_id)
                return

        raise CanOperationError('filter_id not found')

    def _remove_filter(self, filter_id: int) -> None:
        """
        Remove RX filter ``filter_id`` from the device and from local tracking.

        :raises CanOperationError: if the shell command reports an error.
        """
        if filter_id in self._filter_ids:
            self._filter_ids.remove(filter_id)

        self._shell.exec_command(f'can filter remove {self._device} {filter_id}')
        retval = self._retval()
        if retval != 0:
            raise CanOperationError(f'failed to remove filter ID {filter_id} (err {retval})')

    def _remove_all_filters(self) -> None:
        """Remove all RX filters installed by this bus."""
        # Iterate over a copy: _remove_filter() mutates self._filter_ids
        for filter_id in self._filter_ids[:]:
            self._remove_filter(filter_id)

    def _apply_filters(self, filters: CanFilters | None) -> None:
        """
        Replace the installed RX filters with ``filters``.

        Without filters, accept-all standard and extended filters are
        installed and received frames are reported as unfiltered.
        A filter dict without an ``extended`` key is installed as a standard
        (11-bit) filter.
        """
        self._remove_all_filters()

        if filters:
            self._is_filtered = True
        else:
            # Accept all frames if no hardware filters provided
            filters = [
                {'can_id': 0x0, 'can_mask': 0x0},
                {'can_id': 0x0, 'can_mask': 0x0, 'extended': True},
            ]
            self._is_filtered = False

        for can_filter in filters:
            can_id = can_filter['can_id']
            can_mask = can_filter['can_mask']
            extended = can_filter.get('extended', False)
            self._add_filter(can_id, can_mask, extended)

    def _recv_internal(self, timeout: float | None) -> tuple[Message | None, bool]:
        """
        Wait for the shell to print a received frame on this device and parse it.

        :param timeout: forwarded to ``DeviceAdapter.readlines_until``.
            NOTE(review): how that helper treats ``None``/``0`` and what it
            does on expiry is not visible here — confirm against twister_harness.
        :returns: ``(message, filtered)`` where ``filtered`` tells python-can
            whether the frames were already filtered on the device.
        """
        # The CAN ID is printed in hex (3 digits standard, 8 digits extended),
        # and remote frames print a text marker instead of data bytes.
        frame_regex = (
            r'.*'
            + re.escape(self._device)
            + r'\s+(?P<brs>\S)(?P<esi>\S)\s+(?P<can_id>[0-9a-fA-F]+)\s+'
            + r'\[(?P<dlc>\d+)\]\s*'
            + r'(?:(?P<rtr>remote\s+transmission\s+request)|(?P<data>[0-9a-fA-F ]*))'
        )
        lines = self._dut.readlines_until(regex=frame_regex, timeout=timeout)
        msg = None

        regex_compiled = re.compile(frame_regex)
        for line in lines:
            m = regex_compiled.match(line)
            if not m:
                continue

            can_id_str = m.group('can_id')
            dlc_str = m.group('dlc')
            is_remote = m.group('rtr') is not None
            msg = Message(
                arbitration_id=int(can_id_str, 16),
                # Extended IDs are formatted with 8 hex digits
                is_extended_id=len(can_id_str) == 8,
                is_remote_frame=is_remote,
                # Remote frames carry no payload
                data=None if is_remote else bytearray.fromhex(m.group('data')),
                dlc=int(dlc_str),
                # CAN FD frames print a two-digit length, e.g. "[08]"
                is_fd=len(dlc_str) == 2,
                bitrate_switch=m.group('brs') == 'B',
                error_state_indicator=m.group('esi') == 'P',
                channel=self._device,
                check=True,
            )
            logger.debug('received: %s', msg)

        return msg, self._is_filtered

    def shutdown(self) -> None:
        """Shut down the bus: stop the controller and remove installed RX filters."""
        if not self._is_shutdown:
            super().shutdown()
            self._stop()
            self._remove_all_filters()
214