1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4
5from __future__ import annotations
6
7import logging
8import os
9if os.name != 'nt':
10    import pty
11import re
12import subprocess
13import time
14from pathlib import Path
15
16import serial
17from twister_harness.device.device_adapter import DeviceAdapter
18from twister_harness.exceptions import (
19    TwisterHarnessException,
20    TwisterHarnessTimeoutException,
21)
22from twister_harness.device.utils import log_command, terminate_process
23from twister_harness.twister_harness_config import DeviceConfig
24
25logger = logging.getLogger(__name__)
26
27
class HardwareAdapter(DeviceAdapter):
    """Adapter class for real device.

    Flashes the application with ``west flash`` and exchanges data with the
    device over a serial port, or over a pty whose other end is driven by the
    configured ``serial_pty`` command.
    """

    def __init__(self, device_config: DeviceConfig) -> None:
        super().__init__(device_config)
        self._flashing_timeout: float = device_config.flash_timeout
        self._serial_connection: serial.Serial | None = None
        self._serial_pty_proc: subprocess.Popen | None = None
        # Bytes read from the serial port that do not yet form a complete line.
        self._serial_buffer: bytearray = bytearray()

        self.device_log_path: Path = device_config.build_dir / 'device.log'
        self._log_files.append(self.device_log_path)

    def generate_command(self) -> None:
        """Build the ``west flash`` command and store it in ``self.command``.

        Runner-selection arguments are placed directly on the west command
        line; ``west_flash_extra_args`` and runner-specific extra arguments are
        appended after a ``--`` separator (only when there are any).
        """
        command = [
            self.west,
            'flash',
            '--skip-rebuild',
            '--build-dir', str(self.device_config.build_dir),
        ]

        command_extra_args = []
        if self.device_config.west_flash_extra_args:
            command_extra_args.extend(self.device_config.west_flash_extra_args)

        if self.device_config.runner:
            runner_base_args, runner_extra_args = self._prepare_runner_args()
            command.extend(runner_base_args)
            command_extra_args.extend(runner_extra_args)

        if command_extra_args:
            command.append('--')
            command.extend(command_extra_args)
        self.command = command

    def _prepare_runner_args(self) -> tuple[list[str], list[str]]:
        """Return ``(base_args, extra_args)`` for the configured runner.

        ``base_args`` go before the ``--`` separator of the flash command and
        ``extra_args`` after it. ``runner_params`` are always added to
        ``extra_args``; when a device id is configured it is passed using the
        option the given runner (and, for openocd, the probe product) expects.
        Runners not listed below get no device-id argument.
        """
        base_args: list[str] = []
        extra_args: list[str] = []
        runner = self.device_config.runner
        base_args.extend(['--runner', runner])
        if self.device_config.runner_params:
            extra_args.extend(self.device_config.runner_params)
        if board_id := self.device_config.id:
            if runner == 'pyocd':
                extra_args.append('--board-id')
                extra_args.append(board_id)
            elif runner in ('nrfjprog', 'nrfutil'):
                extra_args.append('--dev-id')
                extra_args.append(board_id)
            elif runner == 'openocd' and self.device_config.product in ['STM32 STLink', 'STLINK-V3']:
                extra_args.append('--cmd-pre-init')
                extra_args.append(f'hla_serial {board_id}')
            elif runner == 'openocd' and self.device_config.product == 'EDBG CMSIS-DAP':
                extra_args.append('--cmd-pre-init')
                extra_args.append(f'cmsis_dap_serial {board_id}')
            elif runner == "openocd" and self.device_config.product == "LPC-LINK2 CMSIS-DAP":
                extra_args.append("--cmd-pre-init")
                extra_args.append(f'adapter serial {board_id}')
            elif runner == 'jlink':
                base_args.append('--dev-id')
                base_args.append(board_id)
            elif runner == 'stm32cubeprogrammer':
                base_args.append(f'--tool-opt=sn={board_id}')
            elif runner == 'linkserver':
                base_args.append(f'--probe={board_id}')
        return base_args, extra_args

    def _flash_and_run(self) -> None:
        """Flash application on a device.

        Runs the optional pre-script, then executes ``self.command`` with
        stderr merged into stdout. The captured output is appended to
        ``device_log_path``. The optional post-flash script runs whether or
        not flashing succeeded.

        :raises TwisterHarnessException: if the command is empty, cannot be
            started, or exits with a non-zero return code.
        :raises TwisterHarnessTimeoutException: if flashing takes longer than
            the configured flash timeout, or the subprocess fails.
        """
        if not self.command:
            msg = 'Flash command is empty, please verify if it was generated properly.'
            logger.error(msg)
            raise TwisterHarnessException(msg)

        if self.device_config.pre_script:
            self._run_custom_script(self.device_config.pre_script, self.base_timeout)

        if self.device_config.id:
            logger.debug('Flashing device %s', self.device_config.id)
        log_command(logger, 'Flashing command', self.command, level=logging.DEBUG)

        process = stdout = None
        try:
            process = subprocess.Popen(self.command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=self.env)
            stdout, _ = process.communicate(timeout=self._flashing_timeout)
        except subprocess.TimeoutExpired as exc:
            process.kill()
            # Reap the killed process so it does not linger as a zombie.
            process.wait()
            msg = f'Timeout occurred ({self._flashing_timeout}s) during flashing.'
            logger.error(msg)
            raise TwisterHarnessTimeoutException(msg) from exc
        except subprocess.SubprocessError as exc:
            msg = f'Flashing subprocess failed due to SubprocessError {exc}'
            logger.error(msg)
            raise TwisterHarnessTimeoutException(msg) from exc
        except OSError as exc:
            # The flash command could not be started at all (e.g. executable
            # not found); report it with the same exception type as a failed flash.
            msg = f'Could not flash device {self.device_config.id}'
            logger.error('%s: %s', msg, exc)
            raise TwisterHarnessException(msg) from exc
        finally:
            if stdout is not None:
                stdout_decoded = stdout.decode(errors='ignore')
                with open(self.device_log_path, 'a+') as log_file:
                    log_file.write(stdout_decoded)
            if self.device_config.post_flash_script:
                self._run_custom_script(self.device_config.post_flash_script, self.base_timeout)

        # Checked outside ``finally`` so that a timeout/subprocess error raised
        # above is not replaced by this generic failure.
        if process.returncode == 0:
            logger.debug('Flashing finished')
        else:
            msg = f'Could not flash device {self.device_config.id}'
            logger.error(msg)
            raise TwisterHarnessException(msg)

    def _connect_device(self) -> None:
        """Open the serial connection to the device and clear its buffers.

        Uses the pty started for ``serial_pty`` when configured, otherwise the
        configured serial port. On failure to open, the pty process (if any) is
        terminated and the ``serial.SerialException`` is re-raised.
        """
        serial_name = self._open_serial_pty() or self.device_config.serial
        logger.debug('Opening serial connection for %s', serial_name)
        try:
            self._serial_connection = serial.Serial(
                serial_name,
                baudrate=self.device_config.baud,
                parity=serial.PARITY_NONE,
                stopbits=serial.STOPBITS_ONE,
                bytesize=serial.EIGHTBITS,
                timeout=self.base_timeout,
            )
        except serial.SerialException as exc:
            logger.exception('Cannot open connection: %s', exc)
            self._close_serial_pty()
            raise

        self._serial_connection.flush()
        self._serial_connection.reset_input_buffer()
        self._serial_connection.reset_output_buffer()

    def _open_serial_pty(self) -> str | None:
        """Open a pty pair, run the ``serial_pty`` command on it and return the tty name.

        The command string is split on commas and spaces. The child process
        gets the pty master as stdin/stdout/stderr; the returned name is that
        of the slave side.

        :returns: tty name, or ``None`` when ``serial_pty`` is not configured.
        :raises NameError: if the ``pty`` module is unavailable (not imported
            on Windows).
        :raises OSError: if the command cannot be started.
        """
        if not self.device_config.serial_pty:
            return None

        try:
            master, slave = pty.openpty()
        except NameError:
            logger.exception('PTY module is not available.')
            raise

        # Drop empty tokens produced by consecutive separators, e.g. "script.py, --arg".
        command = [arg for arg in re.split(',| ', self.device_config.serial_pty) if arg]
        try:
            self._serial_pty_proc = subprocess.Popen(
                command,
                stdout=master,
                stdin=master,
                stderr=master
            )
        except (OSError, subprocess.SubprocessError) as exc:
            # Popen raises OSError (not CalledProcessError) when the command
            # cannot be started; do not leak the pty file descriptors.
            logger.exception('Failed to run subprocess %s, error %s', self.device_config.serial_pty, str(exc))
            os.close(master)
            os.close(slave)
            raise
        return os.ttyname(slave)

    def _disconnect_device(self) -> None:
        """Close the serial connection (if any) and stop the pty process."""
        if self._serial_connection:
            serial_name = self._serial_connection.port
            self._serial_connection.close()
            logger.debug('Closed serial connection for %s', serial_name)
        self._close_serial_pty()

    def _close_serial_pty(self) -> None:
        """Terminate the process opened for serial pty script.

        The process is sent a terminate request first; if it has not exited
        within ``base_timeout`` it is killed.
        """
        if self._serial_pty_proc:
            self._serial_pty_proc.terminate()
            try:
                self._serial_pty_proc.communicate(timeout=self.base_timeout)
            except subprocess.TimeoutExpired:
                logger.warning('Process %s did not terminate in time, killing it', self.device_config.serial_pty)
                self._serial_pty_proc.kill()
                self._serial_pty_proc.wait()
            logger.debug('Process %s terminated', self.device_config.serial_pty)
            self._serial_pty_proc = None

    def _close_device(self) -> None:
        """Run the optional post-script after the device session ends."""
        if self.device_config.post_script:
            self._run_custom_script(self.device_config.post_script, self.base_timeout)

    def is_device_running(self) -> bool:
        """Return whether the ``_device_run`` flag (from the base adapter) is set."""
        return self._device_run.is_set()

    def is_device_connected(self) -> bool:
        """Return True when the device runs, is flagged connected and the serial port is open."""
        return bool(
            self.is_device_running()
            and self._device_connected.is_set()
            and self._serial_connection
            and self._serial_connection.is_open
        )

    def _read_device_output(self) -> bytes:
        """Read one line from the serial port; return ``b''`` if reading fails."""
        try:
            output = self._readline_serial()
        except (serial.SerialException, TypeError, IOError):
            # serial was probably disconnected
            output = b''
        return output

    def _readline_serial(self) -> bytes:
        """
        Read one newline-terminated line from the serial connection.

        This method avoids the PySerial built-in ``readline``, which blocks the
        reader thread even when there is no data to read. Instead, data is read
        only when it is available (polling ``in_waiting`` every 50 ms) and
        collected in an internal buffer until a full line is present.
        Returns ``b''`` when the connection is missing or closed. Inspiration
        for this code was taken from this comment:
        https://github.com/pyserial/pyserial/issues/216#issuecomment-369414522
        """
        line = self._readline_from_serial_buffer()
        if line is not None:
            return line
        while True:
            if self._serial_connection is None or not self._serial_connection.is_open:
                return b''
            elif self._serial_connection.in_waiting == 0:
                time.sleep(0.05)
                continue
            else:
                bytes_to_read = max(1, min(2048, self._serial_connection.in_waiting))
                output = self._serial_connection.read(bytes_to_read)
                self._serial_buffer.extend(output)
                line = self._readline_from_serial_buffer()
                if line is not None:
                    return line

    def _readline_from_serial_buffer(self) -> bytes | None:
        """Pop the first complete line (including ``\\n``) from the buffer, or return None."""
        idx = self._serial_buffer.find(b"\n")
        if idx < 0:
            return None
        line = bytes(self._serial_buffer[:idx + 1])
        # Remove the consumed line in place instead of reallocating the buffer.
        del self._serial_buffer[:idx + 1]
        return line

    def _write_to_device(self, data: bytes) -> None:
        """Write raw bytes to the serial connection."""
        self._serial_connection.write(data)

    def _flush_device_output(self) -> None:
        """Flush the serial port and drop pending input when connected."""
        if self.is_device_connected():
            self._serial_connection.flush()
            self._serial_connection.reset_input_buffer()

    def _clear_internal_resources(self) -> None:
        """Reset serial/pty handles and the line buffer after the base cleanup."""
        super()._clear_internal_resources()
        self._serial_connection = None
        self._serial_pty_proc = None
        self._serial_buffer.clear()

    @staticmethod
    def _run_custom_script(script_path: str | Path, timeout: float) -> None:
        """Run a user script, log its stdout and fail on a non-zero exit or timeout.

        :raises TwisterHarnessException: if the script exits with a non-zero code.
        :raises TwisterHarnessTimeoutException: if it does not finish in ``timeout`` seconds.
        """
        with subprocess.Popen(str(script_path), stderr=subprocess.PIPE, stdout=subprocess.PIPE) as proc:
            try:
                stdout, stderr = proc.communicate(timeout=timeout)
                # Scripts may emit non-UTF-8 bytes; never let logging fail on them.
                logger.debug(stdout.decode(errors='ignore'))
                if proc.returncode != 0:
                    msg = f'Custom script failure: \n{stderr.decode(errors="ignore")}'
                    logger.error(msg)
                    raise TwisterHarnessException(msg)

            except subprocess.TimeoutExpired as exc:
                terminate_process(proc)
                proc.communicate(timeout=timeout)
                msg = f'Timeout occurred ({timeout}s) during execution custom script: {script_path}'
                logger.error(msg)
                raise TwisterHarnessTimeoutException(msg) from exc