1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4
5from __future__ import annotations
6
7import logging
8import re
9import time
10
11from dataclasses import dataclass, field
12from inspect import signature
13
14from twister_harness.device.device_adapter import DeviceAdapter
15from twister_harness.exceptions import TwisterHarnessTimeoutException
16
# Module-level logger; Shell.wait_for_prompt uses it to report (at DEBUG level) that the prompt was seen.
logger = logging.getLogger(__name__)
18
19
class Shell:
    """
    Helper class that provides methods used to interact with shell application.

    :param device: device adapter used to write to and read from the device.
    :param prompt: shell prompt string expected in the device output.
    :param timeout: default timeout in seconds for methods accepting a timeout;
        ``device.base_timeout`` is used when it is not given (or is falsy).
    """

    def __init__(
        self, device: DeviceAdapter, prompt: str = 'uart:~$', timeout: float | None = None
    ) -> None:
        self._device: DeviceAdapter = device
        self.prompt: str = prompt
        # ``or`` (rather than ``is None``) is kept on purpose: 0 also falls back
        self.base_timeout: float = timeout or device.base_timeout

    def wait_for_prompt(self, timeout: float | None = None) -> bool:
        """
        Repeatedly send "enter" to the device and read one line after each
        (with a 0.5 second read timeout) until a line containing the shell
        prompt shows up (return True) or the timeout expires (return False).

        :param timeout: time limit in seconds; ``self.base_timeout`` is used
            when it is not given (or is falsy).
        :return: True if the prompt was seen before the deadline, else False.
        """
        timeout = timeout or self.base_timeout
        # Monotonic clock: wall-clock adjustments (NTP, manual changes) must
        # not shorten or extend the wait.
        deadline = time.monotonic() + timeout
        self._device.clear_buffer()
        while time.monotonic() < deadline:
            self._device.write(b'\n')
            try:
                line = self._device.readline(timeout=0.5, print_output=False)
            except TwisterHarnessTimeoutException:
                # ignore read timeout and try to send enter once again
                continue
            if self.prompt in line:
                logger.debug('Got prompt')
                return True
        return False

    def exec_command(
        self, command: str, timeout: float | None = None, print_output: bool = True
    ) -> list[str]:
        """
        Send a shell command to the device and return its response.

        The command is extended with two newline characters: the first one
        executes the command on the device, the second one makes the shell
        print the next prompt, which signals that execution has finished.

        :param command: shell command to execute.
        :param timeout: time limit in seconds for the command to finish;
            ``self.base_timeout`` is used when it is not given (or is falsy).
        :param print_output: passed through to the device adapter's
            ``readlines_until``.
        :return: lines read from the device while waiting first for the
            command echo and then for the next prompt.
        """
        timeout = timeout or self.base_timeout
        command_ext = f'{command}\n\n'
        regex_prompt = re.escape(self.prompt)
        regex_command = f'.*{re.escape(command)}'
        self._device.clear_buffer()
        self._device.write(command_ext.encode())
        lines: list[str] = []
        # wait for device command print - it should be done immediately after sending command to device
        lines.extend(
            self._device.readlines_until(
                regex=regex_command, timeout=1.0, print_output=print_output
            )
        )
        # wait for device command execution
        lines.extend(
            self._device.readlines_until(
                regex=regex_prompt, timeout=timeout, print_output=print_output
            )
        )
        return lines

    def get_filtered_output(self, command_lines: list[str]) -> list[str]:
        """
        Filter out prompts and log messages.

        Take the output of exec_command, which can contain log messages and command prompts,
        and filter it to obtain only the command output. A line is dropped when it contains
        the prompt or one of the log level tags ``<dbg>``, ``<inf>``, ``<wrn>``, ``<err>``.

        Example:
            >>> # equivalent to `lines = shell.exec_command("kernel version")`
            >>> lines = [
            ...    'uart:~$',                    # filter prompts
            ...    'Zephyr version 3.6.0',       # keep this line
            ...    'uart:~$ <dbg> debug message' # filter log messages
            ... ]
            >>> filtered_output = shell.get_filtered_output(lines)
            >>> filtered_output
            ['Zephyr version 3.6.0']

        :param command_lines: List of strings i.e. the output of `exec_command`.
        :return: A list of strings containing the command output, excluding prompts and
            log messages (original order preserved).
        """
        regex_filter = re.compile(
            '|'.join([re.escape(self.prompt), '<dbg>', '<inf>', '<wrn>', '<err>'])
        )
        return [line for line in command_lines if not regex_filter.search(line)]
108
109
@dataclass
class ShellMCUbootArea:
    """
    Data of a single image area parsed from the output of the `mcuboot` shell command.
    """

    name: str
    version: str
    image_size: str
    magic: str = 'unset'
    swap_type: str = 'none'
    copy_done: str = 'unset'
    image_ok: str = 'unset'

    @classmethod
    def from_kwargs(cls, **kwargs) -> ShellMCUbootArea:
        """
        Create an instance from keyword arguments, silently ignoring those that
        are not parameters of the dataclass constructor.

        :raises TypeError: if a field without a default (``name``, ``version``,
            ``image_size``) is missing from ``kwargs``.
        """
        # parameter names accepted by the dataclass-generated __init__
        accepted = signature(cls).parameters
        return cls(**{key: value for key, value in kwargs.items() if key in accepted})
128
129
@dataclass
class ShellMCUbootCommandParsed:
    """
    Helper class to keep data from `mcuboot` shell command.
    """

    areas: list[ShellMCUbootArea] = field(default_factory=list)

    @classmethod
    def create_from_cmd_output(cls, cmd_output: list[str]) -> ShellMCUbootCommandParsed:
        """
        Factory to create class from the output of `mcuboot` shell command.

        A line of the form ``<something> area<...>:`` (nothing but whitespace
        after the final colon) starts a new area. Each following ``key: value``
        line is attached to the most recent area, with spaces in the key
        replaced by underscores (e.g. ``image size`` -> ``image_size``). Lines
        before the first area header are ignored. Keys that are not fields of
        ``ShellMCUbootArea`` are dropped by ``ShellMCUbootArea.from_kwargs``.

        :param cmd_output: lines printed by the `mcuboot` shell command.
        :return: parsed data with one entry in ``areas`` per area header found.
        :raises TypeError: if an area lacks a required field of ``ShellMCUbootArea``.
        """
        areas: list[dict[str, str]] = []
        re_area = re.compile(r'(.+ area.*):\s*$')
        # Non-greedy key: split at the FIRST colon so that a value which itself
        # contains ':' is kept intact instead of being truncated.
        re_key = re.compile(r'(?P<key>.+?):(?P<val>.+)')
        for line in cmd_output:
            if m := re_area.search(line):
                areas.append({'name': m.group(1)})
            elif areas and (m := re_key.search(line)):
                key = m.group('key').strip().replace(' ', '_')
                areas[-1][key] = m.group('val').strip()
        return cls([ShellMCUbootArea.from_kwargs(**area) for area in areas])
157