1# Copyright (c) 2023 Nordic Semiconductor ASA
2#
3# SPDX-License-Identifier: Apache-2.0
4
5from __future__ import annotations
6
7import logging
8from dataclasses import dataclass, field
9from pathlib import Path
10from twister_harness.helpers.domains_helper import get_default_domain_name
11
12import pytest
13
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)
15
16
@dataclass
class DeviceConfig:
    """Configuration of a single device used by the Twister harness.

    Instances are typically built from command-line options (see
    ``TwisterHarnessConfig.create``). ``app_build_dir`` is derived from
    ``build_dir`` in ``__post_init__``.
    """
    type: str
    build_dir: Path
    base_timeout: float = 60.0  # [s]
    flash_timeout: float = 60.0  # [s]
    platform: str = ''
    serial: str = ''
    baud: int = 115200
    runner: str = ''
    runner_params: list[str] = field(default_factory=list, repr=False)
    id: str = ''
    product: str = ''
    serial_pty: str = ''
    flash_before: bool = False
    west_flash_extra_args: list[str] = field(default_factory=list, repr=False)
    name: str = ''
    pre_script: Path | None = None
    post_script: Path | None = None
    post_flash_script: Path | None = None
    # Defaults to None (not an empty list), so the annotation must allow None.
    fixtures: list[str] | None = None
    # Always recomputed in __post_init__ from build_dir; a value passed to the
    # constructor is overwritten. This keeps it in sync with build_dir, for
    # example when a copy is made with dataclasses.replace().
    app_build_dir: Path | None = None
    extra_test_args: str = ''

    def __post_init__(self):
        """Derive ``app_build_dir`` from ``build_dir``.

        If ``build_dir`` contains a ``domains.yaml`` file, ``app_build_dir``
        becomes the subdirectory named after the default domain reported by
        ``get_default_domain_name``. Otherwise it is ``build_dir`` itself.
        """
        domains = self.build_dir / 'domains.yaml'
        if domains.exists():
            # Multi-domain build: the application image lives in the
            # default domain's subdirectory.
            self.app_build_dir = self.build_dir / get_default_domain_name(domains)
        else:
            self.app_build_dir = self.build_dir
47
48
@dataclass
class TwisterHarnessConfig:
    """Store Twister harness configuration to have easy access in test."""
    devices: list[DeviceConfig] = field(default_factory=list, repr=False)

    @classmethod
    def create(cls, config: pytest.Config) -> TwisterHarnessConfig:
        """Create new instance from pytest.Config.

        A single :class:`DeviceConfig` is built from the options stored in
        ``config.option`` and becomes the only entry of ``devices``.

        :param config: pytest configuration holding the parsed CLI options.
        :return: new harness configuration containing one device.
        """
        # ``west_flash_extra_args`` arrives as one comma-separated string.
        # Items are stripped, and empty ones (e.g. from a trailing comma or
        # ",,") are dropped so they are not forwarded as empty arguments.
        west_flash_extra_args: list[str] = []
        if config.option.west_flash_extra_args:
            stripped = (w.strip() for w in config.option.west_flash_extra_args.split(','))
            west_flash_extra_args = [w for w in stripped if w]

        # ``runner_params`` is already a sequence; apply the same cleanup.
        runner_params: list[str] = []
        if config.option.runner_params:
            stripped = (w.strip() for w in config.option.runner_params)
            runner_params = [w for w in stripped if w]

        device_from_cli = DeviceConfig(
            type=config.option.device_type,
            build_dir=_cast_to_path(config.option.build_dir),
            base_timeout=config.option.base_timeout,
            flash_timeout=config.option.flash_timeout,
            platform=config.option.platform,
            serial=config.option.device_serial,
            baud=config.option.device_serial_baud,
            runner=config.option.runner,
            runner_params=runner_params,
            id=config.option.device_id,
            product=config.option.device_product,
            serial_pty=config.option.device_serial_pty,
            flash_before=bool(config.option.flash_before),
            west_flash_extra_args=west_flash_extra_args,
            pre_script=_cast_to_path(config.option.pre_script),
            post_script=_cast_to_path(config.option.post_script),
            post_flash_script=_cast_to_path(config.option.post_flash_script),
            fixtures=config.option.fixtures,
            extra_test_args=config.option.extra_test_args,
        )

        return cls(devices=[device_from_cli])
93
94
95def _cast_to_path(path: str | None) -> Path | None:
96    if path is None:
97        return None
98    return Path(path)
99