1# Copyright (c) 2024 Vestas Wind Systems A/S
2#
3# SPDX-License-Identifier: Apache-2.0
4
5"""
6Zephyr CAN shell module support for providing a python-can bus interface for testing.
7"""
8
9import re
10import logging
11from typing import Optional, Tuple
12
13from can import BusABC, CanProtocol, Message
14from can.exceptions import CanInitializationError, CanOperationError
15from can.typechecking import CanFilters
16
17from twister_harness import DeviceAdapter, Shell
18
# Logger for this module; sent and received frames are logged at DEBUG level.
logger = logging.getLogger(name=__name__)
20
class CanShellBus(BusABC): # pylint: disable=abstract-method
    """
    A CAN interface using the Zephyr CAN shell module.

    Frames are transmitted with ``can send`` and received by parsing the frame
    lines the CAN shell prints on the device output. RX filters are installed
    on the device with ``can filter add``.
    """

    # Fixed patterns for parsing CAN shell command output, compiled once.
    _CAPS_REGEX = re.compile(r'capabilities:\s+(?P<caps>.*)')
    _ENQUEUED_REGEX = re.compile(r'enqueuing\s+CAN\s+frame\s+#(?P<id>\d+)')
    _FILTER_ID_REGEX = re.compile(r'filter\s+ID:\s+(?P<id>\d+)')

    def __init__(self, dut: DeviceAdapter, shell: Shell, channel: str,
                 can_filters: Optional[CanFilters] = None, **kwargs) -> None:
        """
        Configure and start the CAN controller through the shell.

        :param dut: device adapter used to wait for asynchronous device output
            (TX confirmations and received frames).
        :param shell: shell used to execute CAN shell commands.
        :param channel: name of the CAN device on the target.
        :param can_filters: optional RX filters; when omitted, accept-all
            filters for standard and extended IDs are installed.
        :raises CanOperationError: if the capabilities cannot be read or the
            mode cannot be set.
        :raises CanInitializationError: if the controller fails to start.
        """
        self._dut = dut
        self._shell = shell
        self._device = channel
        self._is_filtered = False
        # IDs of the RX filters this bus added on the device.
        self._filter_ids: list[int] = []

        self.channel_info = f'Zephyr CAN shell, device "{self._device}"'

        # Select CAN FD mode when the controller reports the 'fd' capability.
        mode = 'normal'
        if 'fd' in self._get_capabilities():
            self._can_protocol = CanProtocol.CAN_FD
            mode += ' fd'
        else:
            self._can_protocol = CanProtocol.CAN_20

        self._set_mode(mode)
        self._start()

        # python-can's BusABC.__init__() calls set_filters(), which ends up in
        # _apply_filters(); the shell state above must be ready by then.
        super().__init__(channel=channel, can_filters=can_filters, **kwargs)

    def _retval(self) -> int:
        """
        Get return value of last shell command.

        :raises CanOperationError: if the ``retval`` output is empty or not an
            integer.
        """
        lines = self._shell.get_filtered_output(self._shell.exec_command('retval'))
        try:
            return int(lines[0])
        except (IndexError, ValueError) as err:
            raise CanOperationError(f'failed to read shell return value: {lines!r}') from err

    def _get_capabilities(self) -> list[str]:
        """
        Return the capability names listed by ``can show`` for the device.

        :raises CanOperationError: if no capabilities line is found.
        """
        cmd = f'can show {self._device}'

        lines = self._shell.get_filtered_output(self._shell.exec_command(cmd))
        for line in lines:
            m = self._CAPS_REGEX.match(line)
            if m:
                return m.group('caps').split()

        raise CanOperationError('capabilities not found')

    def _set_mode(self, mode: str) -> None:
        """
        Set the controller mode (e.g. ``'normal'`` or ``'normal fd'``).

        :raises CanOperationError: if the shell reports a non-zero return value.
        """
        self._shell.exec_command(f'can mode {self._device} {mode}')
        retval = self._retval()
        if retval != 0:
            raise CanOperationError(f'failed to set mode "{mode}" (err {retval})')

    def _start(self) -> None:
        """
        Start the CAN controller.

        :raises CanInitializationError: if the shell reports a non-zero return
            value.
        """
        self._shell.exec_command(f'can start {self._device}')
        retval = self._retval()
        if retval != 0:
            raise CanInitializationError(f'failed to start (err {retval})')

    def _stop(self) -> None:
        """Stop the CAN controller (return value not checked; called from shutdown())."""
        self._shell.exec_command(f'can stop {self._device}')

    def send(self, msg: Message, timeout: Optional[float] = None) -> None:
        """
        Transmit *msg* with ``can send`` and wait for the TX confirmation.

        :param msg: frame to transmit.
        :param timeout: passed to ``readlines_until()`` while waiting for the
            "successfully sent" line for this frame.
        :raises CanOperationError: if the shell does not report the frame as
            enqueued.
        """
        logger.debug('sending: %s', msg)

        cmd = f'can send {self._device}'
        cmd += ' -e' if msg.is_extended_id else ''
        cmd += ' -r' if msg.is_remote_frame else ''
        cmd += ' -f' if msg.is_fd else ''
        cmd += ' -b' if msg.bitrate_switch else ''

        # Zero-padded hex ID: 8 digits for extended IDs, 3 for standard IDs.
        id_width = 8 if msg.is_extended_id else 3
        cmd += f' {msg.arbitration_id:0{id_width}x}'

        if msg.data:
            cmd += ' ' + msg.data.hex(' ', 1)

        lines = self._shell.exec_command(cmd)
        frame_num = None
        for line in lines:
            m = self._ENQUEUED_REGEX.match(line)
            if m:
                frame_num = m.group('id')
                break

        if frame_num is None:
            raise CanOperationError('frame not enqueued')

        # The TX confirmation is not part of the command output; wait for it on
        # the device output, keyed by the frame number reported above.
        tx_regex = r'CAN\s+frame\s+#' + frame_num + r'\s+successfully\s+sent'
        self._dut.readlines_until(regex=tx_regex, timeout=timeout)

    def _add_filter(self, can_id: int, can_mask: int, extended: bool) -> None:
        """
        Add RX filter and record the filter ID assigned by the shell.

        :raises CanOperationError: if the shell output contains no filter ID.
        """
        # ID and mask use the same zero-padded hex widths as in send().
        width = 8 if extended else 3
        cmd = f'can filter add {self._device}'
        cmd += ' -e' if extended else ''
        cmd += f' {can_id:0{width}x} {can_mask:0{width}x}'

        lines = self._shell.exec_command(cmd)
        for line in lines:
            m = self._FILTER_ID_REGEX.match(line)
            if m:
                self._filter_ids.append(int(m.group('id')))
                return

        raise CanOperationError('filter_id not found')

    def _remove_filter(self, filter_id: int) -> None:
        """
        Remove RX filter.

        The ID is dropped from the tracked list before the shell command runs,
        so a failed removal is not retried later.

        :raises CanOperationError: if the shell reports a non-zero return value.
        """
        if filter_id in self._filter_ids:
            self._filter_ids.remove(filter_id)

        self._shell.exec_command(f'can filter remove {self._device} {filter_id}')
        retval = self._retval()
        if retval != 0:
            raise CanOperationError(f'failed to remove filter ID {filter_id} (err {retval})')

    def _remove_all_filters(self) -> None:
        """Remove all RX filters added by this bus."""
        # Iterate over a copy: _remove_filter() mutates self._filter_ids.
        for filter_id in list(self._filter_ids):
            self._remove_filter(filter_id)

    def _apply_filters(self, filters: Optional[CanFilters]) -> None:
        """
        Replace the RX filters installed on the device.

        When *filters* is empty or ``None``, accept-all filters for standard and
        extended IDs are installed and received frames are reported as not
        filtered (the flag is returned by _recv_internal()).
        """
        self._remove_all_filters()

        if filters:
            self._is_filtered = True
        else:
            # Accept all frames if no hardware filters provided
            filters = [
                {'can_id': 0x0, 'can_mask': 0x0},
                {'can_id': 0x0, 'can_mask': 0x0, 'extended': True}
            ]
            self._is_filtered = False

        for can_filter in filters:
            self._add_filter(can_filter['can_id'], can_filter['can_mask'],
                             can_filter.get('extended', False))

    def _recv_internal(self, timeout: Optional[float]) -> Tuple[Optional[Message], bool]:
        """
        Wait for the CAN shell to print a received frame and parse it.

        :param timeout: passed to ``readlines_until()``.
        :returns: the parsed message (``None`` if no returned line matched) and
            whether the frame already passed the device RX filters.
        """
        # ID and payload bytes are hex. For RTR frames the Zephyr CAN shell
        # prints "remote transmission request" in place of the payload, so the
        # payload group must only accept hex digits.
        frame_regex = r'.*' + re.escape(self._device) + \
            r'\s+(?P<brs>\S)(?P<esi>\S)\s+(?P<can_id>[0-9a-fA-F]+)\s+\[(?P<dlc>\d+)\]\s*' + \
            r'(?P<data>[0-9a-fA-F ]*)(?P<rtr>remote transmission request)?'
        lines = self._dut.readlines_until(regex=frame_regex, timeout=timeout)
        msg = None

        regex_compiled = re.compile(frame_regex)
        for line in lines:
            m = regex_compiled.match(line)
            if m:
                can_id = int(m.group('can_id'), 16)
                # Extended IDs are printed with 8 hex digits (see send()).
                ext = len(m.group('can_id')) == 8
                dlc = int(m.group('dlc'))
                # FD frames print a two-digit length field.
                fd = len(m.group('dlc')) == 2
                brs = m.group('brs') == 'B'
                esi = m.group('esi') == 'P'
                rtr = m.group('rtr') is not None
                data = bytearray.fromhex(m.group('data'))
                msg = Message(arbitration_id=can_id, is_extended_id=ext,
                              data=data, dlc=dlc,
                              is_fd=fd, bitrate_switch=brs, error_state_indicator=esi,
                              is_remote_frame=rtr, channel=self._device, check=True)
                logger.debug('received: %s', msg)

        return msg, self._is_filtered

    def shutdown(self) -> None:
        """Shut down the bus, stop the controller and remove the installed RX filters."""
        if not self._is_shutdown:
            super().shutdown()
            self._stop()
            self._remove_all_filters()
198