1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
5
6import logging
7import re
8import shlex
9from dataclasses import dataclass
10from pathlib import Path
11from subprocess import check_output, getstatusoutput
12
# Module-level logger named after this module; MCUmgr uses it to trace the
# mcumgr command lines it runs and to report warnings.
logger = logging.getLogger(__name__)
14
15
class MCUmgrException(Exception):
    """Raised when an MCUmgr operation cannot be completed.

    Serves as the general-purpose error type of this module.
    """
18
19
@dataclass()
class MCUmgrImage:
    """A single image/slot entry, as parsed from ``mcumgr image list``.

    Only ``image`` and ``slot`` are required; the descriptive text fields
    default to empty strings and are filled in as the parser finds them.
    """
    image: int  # image number (the ``image=<n>`` value)
    slot: int  # slot index (the ``slot=<m>`` value)
    version: str = ''  # version string, '' if not reported
    flags: str = ''  # raw flags text, '' if none
    hash: str = ''  # image hash, '' if not reported
27
28
class MCUmgr:
    """Sample wrapper for mcumgr command-line tool.

    Every operation shells out to the executable named by ``mcumgr_exec``,
    passing the connection options supplied at construction time. A failing
    mcumgr invocation surfaces as ``subprocess.CalledProcessError`` raised
    from :meth:`run_command`.
    """
    mcumgr_exec = 'mcumgr'

    # Patterns for the text printed by ``mcumgr image list``. They are
    # compiled once here instead of on every parse.
    _RE_IMAGE = re.compile(r'image=(\d+)\s+slot=(\d+)')
    _RE_VERSION = re.compile(r'version:\s+(\S+)')
    _RE_FLAGS = re.compile(r'flags:\s+(.+)')
    _RE_HASH = re.compile(r'hash:\s+(\w+)')

    def __init__(self, connection_options: str):
        """Store the mcumgr connection options.

        :param connection_options: connection arguments inserted verbatim
            into every command line. The command line is tokenised with
            shell-like (``shlex``) rules, so values containing whitespace
            must be quoted.
        """
        self.conn_opts = connection_options

    @classmethod
    def create_for_serial(cls, serial_port: str) -> MCUmgr:
        """Create a wrapper that talks to the device over *serial_port*."""
        # Quote the port so that a value containing spaces or backslashes stays
        # a single argument after shlex.split() in run_command(). Ordinary
        # device paths such as /dev/ttyACM0 are returned unchanged by quote().
        return cls(
            connection_options=f'--conntype serial --connstring={shlex.quote(serial_port)}'
        )

    @classmethod
    def is_available(cls) -> bool:
        """Return True if ``<mcumgr_exec> version`` exits successfully.

        On a non-zero exit status, the tool's combined output is logged as a
        warning and False is returned.
        """
        exitcode, output = getstatusoutput(f'{cls.mcumgr_exec} version')
        if exitcode != 0:
            logger.warning('mcumgr tool not available: %s', output)
            return False
        return True

    def run_command(self, cmd: str) -> str:
        """Run mcumgr with the connection options and *cmd*; return stdout.

        The full command line is tokenised with :func:`shlex.split` and
        executed without a shell.

        :raises subprocess.CalledProcessError: if mcumgr exits non-zero.
        """
        command = f'{self.mcumgr_exec} {self.conn_opts} {cmd}'
        logger.info('CMD: %s', command)
        return check_output(shlex.split(command), text=True)

    def reset_device(self):
        """Send ``reset`` to the device through mcumgr."""
        self.run_command('reset')

    def image_upload(self, image: Path | str, slot: int | None = None, timeout: int = 30):
        """Upload the image file *image* to the device.

        :param image: path of the image file to upload.
        :param slot: when not None, ``-e -n <slot>`` is appended to the
            upload command. NOTE(review): ``-n`` selects the image number in
            the mcumgr CLI; confirm that this is the intended meaning of *slot*.
        :param timeout: value passed to mcumgr's ``-t`` option.
        :raises subprocess.CalledProcessError: if the upload command fails.
        """
        # Quote the path so that file names with spaces (or Windows
        # backslashes) survive shlex.split() in run_command() as one argument.
        command = f'-t {timeout} image upload {shlex.quote(str(image))}'
        if slot is not None:
            command += f' -e -n {slot}'
        self.run_command(command)
        logger.info('Image successfully uploaded')

    def get_image_list(self) -> list[MCUmgrImage]:
        """Return the images reported by ``mcumgr image list``."""
        output = self.run_command('image list')
        return self._parse_image_list(output)

    @classmethod
    def _parse_image_list(cls, cmd_output: str) -> list[MCUmgrImage]:
        """Parse the text printed by ``mcumgr image list``.

        A line containing ``image=<n> slot=<m>`` starts a new entry. Later
        ``version:``, ``flags:`` and ``hash:`` lines fill in the most recent
        entry. Lines before the first entry, and lines that match none of
        these patterns, are ignored.
        """
        image_list: list[MCUmgrImage] = []
        for line in cmd_output.splitlines():
            if m := cls._RE_IMAGE.search(line):
                image_list.append(
                    MCUmgrImage(image=int(m.group(1)), slot=int(m.group(2)))
                )
                continue
            if not image_list:
                # Header text before the first image entry.
                continue
            current = image_list[-1]
            if m := cls._RE_VERSION.search(line):
                current.version = m.group(1)
            elif m := cls._RE_FLAGS.search(line):
                # Strip so that trailing whitespace is not kept as part of the flags.
                current.flags = m.group(1).strip()
            elif m := cls._RE_HASH.search(line):
                current.hash = m.group(1)
        return image_list

    def _get_hash_without_flag(self, flag: str) -> str:
        """Return the hash of the first listed image whose flags lack *flag*.

        :raises MCUmgrException: if every listed image carries *flag*, or if
            no images are listed.
        """
        image_list = self.get_image_list()
        for image in image_list:
            if flag not in image.flags:
                return image.hash
        logger.warning('Images returned by mcumgr (no not %s):\n%s', flag, image_list)
        raise MCUmgrException(f'No not {flag} image found')

    def get_hash_to_test(self) -> str:
        """Return the hash of the first image not flagged ``active``.

        :raises MCUmgrException: if no such image exists.
        """
        return self._get_hash_without_flag('active')

    def get_hash_to_confirm(self) -> str:
        """Return the hash of the first image not flagged ``confirmed``.

        :raises MCUmgrException: if no such image exists.
        """
        return self._get_hash_without_flag('confirmed')

    def image_test(self, hash: str | None = None):
        """Run ``mcumgr image test`` on *hash*.

        When *hash* is empty or None, the hash from get_hash_to_test() is used.
        """
        if not hash:
            hash = self.get_hash_to_test()
        self.run_command(f'image test {hash}')

    def image_confirm(self, hash: str | None = None):
        """Run ``mcumgr image confirm`` on *hash*.

        When *hash* is empty or None, the hash from get_hash_to_confirm() is
        used.
        """
        if not hash:
            hash = self.get_hash_to_confirm()
        self.run_command(f'image confirm {hash}')
116