1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
5
6import logging
7import re
8import shlex
9
10from subprocess import check_output, getstatusoutput
11from pathlib import Path
12from dataclasses import dataclass
13
# Module-wide logger: records every executed mcumgr command line and any
# warnings raised while talking to the tool.
logger = logging.getLogger(name=__name__)
15
16
class MCUmgrException(Exception):
    """Error raised when an mcumgr operation yields an unexpected result."""

    pass
19
20
@dataclass
class MCUmgrImage:
    """A single image/slot entry reported by ``mcumgr image list``."""

    # Numeric values taken from the "image=<n> slot=<n>" header line.
    image: int
    slot: int
    # Optional details; they stay as empty strings unless a matching line
    # is found in the command output.
    version: str = ''
    flags: str = ''
    hash: str = ''
28
29
class MCUmgr:
    """Sample wrapper for mcumgr command-line tool.

    Every command is executed as ``mcumgr <connection options> <cmd>``. That
    string is tokenized with :func:`shlex.split` (POSIX mode), so any free-form
    value interpolated into it (file paths, serial port names) is quoted with
    :func:`shlex.quote` first.
    """
    mcumgr_exec = 'mcumgr'

    def __init__(self, connection_options: str):
        """Store the connection options used for every mcumgr invocation.

        :param connection_options: option string inserted between the mcumgr
            executable and each sub-command, e.g.
            ``--conntype serial --connstring=/dev/ttyACM0``.
        """
        self.conn_opts = connection_options

    @classmethod
    def create_for_serial(cls, serial_port: str) -> MCUmgr:
        """Create a wrapper that connects to the device over a serial port.

        The port name is shell-quoted so that names with spaces or backslashes
        survive the ``shlex.split`` done in :meth:`run_command`. Ordinary names
        such as ``/dev/ttyACM0`` are left unchanged.
        """
        port = shlex.quote(serial_port)
        return cls(connection_options=f'--conntype serial --connstring={port}')

    @classmethod
    def is_available(cls) -> bool:
        """Return True if ``mcumgr version`` runs successfully on this host.

        On a non-zero exit status, a warning containing the command output
        is logged and False is returned.
        """
        exitcode, output = getstatusoutput(f'{cls.mcumgr_exec} version')
        if exitcode != 0:
            logger.warning(f'mcumgr tool not available: {output}')
            return False
        return True

    def run_command(self, cmd: str) -> str:
        """Run ``mcumgr <conn_opts> <cmd>`` and return its standard output.

        :param cmd: mcumgr sub-command and its arguments. Values containing
            whitespace or backslashes must already be ``shlex.quote``-d,
            because the full command line is split with ``shlex.split``.
        :raises subprocess.CalledProcessError: if mcumgr exits non-zero.
        """
        command = f'{self.mcumgr_exec} {self.conn_opts} {cmd}'
        logger.info(f'CMD: {command}')
        return check_output(shlex.split(command), text=True)

    def reset_device(self):
        """Send the ``reset`` command to the device."""
        self.run_command('reset')

    def image_upload(self, image: Path | str, timeout: int = 30):
        """Upload an image file to the device.

        :param image: path of the image file. It is quoted so that paths with
            spaces, or Windows paths with backslashes, reach mcumgr intact
            instead of being re-split or unescaped by ``shlex.split``.
        :param timeout: value passed to mcumgr's ``-t`` option.
        """
        self.run_command(f'-t {timeout} image upload {shlex.quote(str(image))}')
        logger.info('Image successfully uploaded')

    def get_image_list(self) -> list[MCUmgrImage]:
        """Query the device with ``image list`` and return the parsed entries."""
        output = self.run_command('image list')
        return self._parse_image_list(output)

    @staticmethod
    def _parse_image_list(cmd_output: str) -> list[MCUmgrImage]:
        """Parse the text printed by ``mcumgr image list``.

        A line containing ``image=<n> slot=<n>`` starts a new entry. Later
        ``version:``, ``flags:`` and ``hash:`` lines fill in the most recent
        entry. Lines that appear before the first header are ignored.
        """
        image_list = []
        re_image = re.compile(r'image=(\d+)\s+slot=(\d+)')
        re_version = re.compile(r'version:\s+(\S+)')
        re_flags = re.compile(r'flags:\s+(.+)')
        re_hash = re.compile(r'hash:\s+(\w+)')
        for line in cmd_output.splitlines():
            if m := re_image.search(line):
                image_list.append(
                    MCUmgrImage(
                        image=int(m.group(1)),
                        slot=int(m.group(2))
                    )
                )
            elif image_list:
                if m := re_version.search(line):
                    image_list[-1].version = m.group(1)
                elif m := re_flags.search(line):
                    # The greedy group also captures trailing whitespace.
                    # Strip it so a blank flags line yields an empty string.
                    image_list[-1].flags = m.group(1).strip()
                elif m := re_hash.search(line):
                    image_list[-1].hash = m.group(1)
        return image_list

    def get_hash_to_test(self) -> str:
        """Return the hash of the second entry in the device's image list.

        NOTE(review): presumably the second entry is the newly uploaded
        (secondary slot) image; confirm against the device's slot layout.

        :raises MCUmgrException: if fewer than two images are reported.
        """
        image_list = self.get_image_list()
        if len(image_list) < 2:
            logger.warning(f'Unexpected image list returned by mcumgr: {image_list}')
            raise MCUmgrException('Please check image list returned by mcumgr')
        return image_list[1].hash

    def image_test(self, hash: str | None = None):
        """Run ``mcumgr image test`` for the given image hash.

        :param hash: image hash; when empty or None, the hash returned by
            :meth:`get_hash_to_test` is used.
        :raises MCUmgrException: propagated from :meth:`get_hash_to_test`.
        """
        if not hash:
            hash = self.get_hash_to_test()
        self.run_command(f'image test {hash}')

    def image_confirm(self, hash: str | None = None):
        """Run ``mcumgr image confirm`` for the given image hash.

        :param hash: image hash; when empty or None, the hash of the first
            entry in the device's image list is used.
        :raises MCUmgrException: if no hash is given and the device reports
            no images (previously this surfaced as a bare IndexError).
        """
        if not hash:
            image_list = self.get_image_list()
            if not image_list:
                logger.warning('mcumgr returned an empty image list')
                raise MCUmgrException('Please check image list returned by mcumgr')
            hash = image_list[0].hash
        self.run_command(f'image confirm {hash}')
106