1#!/usr/bin/env python3
2# vim: set syntax=python ts=4 :
3#
4# Copyright (c) 2022 Intel Corporation
5# SPDX-License-Identifier: Apache-2.0
6
7import logging
8import os
9import platform
10import re
11from multiprocessing import Lock, Value
12from pathlib import Path
13
14import scl
15import yaml
16from natsort import natsorted
17from twisterlib.environment import ZEPHYR_BASE
18
19try:
20    # Use the C LibYAML parser if available, rather than the Python parser.
21    # It's much faster.
22    from yaml import CDumper as Dumper
23    from yaml import CSafeLoader as SafeLoader
24except ImportError:
25    from yaml import Dumper, SafeLoader
26
27try:
28    from tabulate import tabulate
29except ImportError:
30    print("Install tabulate python module with pip to use --device-testing option.")
31
# Logger shared under the 'twister' name. Its level is set to DEBUG so this
# logger passes every record on; any further filtering presumably happens in
# handlers configured elsewhere (TODO confirm).
logger = logging.getLogger('twister')
logger.setLevel(logging.DEBUG)
34
35
class DUT:
    """A single device under test (DUT) taken from a hardware map or options.

    The run counter, failure counter and availability flag live in
    ``multiprocessing.Value`` objects and are read/written under their own
    locks through the properties below.
    """

    def __init__(self,
                 id=None,
                 serial=None,
                 serial_baud=None,
                 platform=None,
                 product=None,
                 serial_pty=None,
                 connected=False,
                 runner_params=None,
                 pre_script=None,
                 post_script=None,
                 post_flash_script=None,
                 script_param=None,
                 runner=None,
                 flash_timeout=60,
                 flash_with_test=False,
                 flash_before=False):

        self.serial = serial
        # A missing (or falsy) baud rate falls back to 115200.
        self.baud = serial_baud or 115200
        self.platform = platform
        self.serial_pty = serial_pty
        # Lock-protected shared state; use the properties, not these fields.
        self._counter = Value("i", 0)
        self._available = Value("i", 1)
        self._failures = Value("i", 0)
        self.connected = connected
        self.pre_script = pre_script
        self.id = id
        self.product = product
        self.runner = runner
        self.runner_params = runner_params
        self.flash_before = flash_before
        self.fixtures = []
        self.post_flash_script = post_flash_script
        self.post_script = post_script
        self.script_param = script_param
        self.probe_id = None
        self.notes = None
        self.lock = Lock()
        # Set by HardwareMap.save() once this device matched a map entry.
        self.match = False
        self.flash_timeout = flash_timeout
        self.flash_with_test = flash_with_test

    @property
    def available(self):
        """Availability flag (1 when free, as initialised)."""
        with self._available.get_lock():
            return self._available.value

    @available.setter
    def available(self, value):
        with self._available.get_lock():
            self._available.value = value

    @property
    def counter(self):
        """Number of runs recorded on this DUT."""
        with self._counter.get_lock():
            return self._counter.value

    @counter.setter
    def counter(self, value):
        with self._counter.get_lock():
            self._counter.value = value

    def counter_increment(self, value=1):
        """Atomically add ``value`` to the run counter."""
        with self._counter.get_lock():
            self._counter.value += value

    @property
    def failures(self):
        """Number of failures recorded on this DUT."""
        with self._failures.get_lock():
            return self._failures.value

    @failures.setter
    def failures(self, value):
        with self._failures.get_lock():
            self._failures.value = value

    def failures_increment(self, value=1):
        """Atomically add ``value`` to the failure counter."""
        with self._failures.get_lock():
            self._failures.value += value

    def to_dict(self):
        """Return this DUT's truthy public attributes as a plain dict.

        Shared counters, the ``match`` flag and the process ``lock`` are
        left out (a multiprocessing lock cannot be serialised, e.g. by
        ``yaml.dump``), as is every attribute whose value is falsy
        (``None``, ``False``, ``0``, ``''``, ``[]``).
        """
        exclude = {'_available', '_counter', '_failures', 'match', 'lock'}
        return {k: v for k, v in vars(self).items() if k not in exclude and v}

    def __repr__(self):
        return f"<{self.platform} ({self.product}) on {self.serial}>"
131
class HardwareMap:
    """Registry of the devices under test (DUTs) available to twister.

    Devices are added from a hardware map YAML file (:meth:`load`), from the
    single-device command-line options (:meth:`add_device`), or detected by
    scanning the host's serial ports (:meth:`scan`, then :meth:`save`).
    """

    # Schema used by load() to validate hardware map files.
    schema_path = os.path.join(ZEPHYR_BASE, "scripts", "schemas", "twister", "hwmap-schema.yaml")

    # USB manufacturer strings accepted by scan() (compared case-insensitively).
    manufacturer = [
        'ARM',
        'SEGGER',
        'MBED',
        'STMicroelectronics',
        'Atmel Corp.',
        'Texas Instruments',
        'Silicon Labs',
        'NXP',
        'NXP Semiconductors',
        'Microchip Technology Inc.',
        'FTDI',
        'Digilent',
        'Microsoft',
        'Nuvoton',
        'Espressif',
    ]

    # Runner name -> USB product strings. scan() assigns the runner when the
    # product equals an entry or matches it as a regex via re.match().
    runner_mapping = {
        'pyocd': [
            'DAPLink CMSIS-DAP',
            'MBED CMSIS-DAP'
        ],
        'jlink': [
            'J-Link',
            'J-Link OB'
        ],
        'openocd': [
            'STM32 STLink', '^XDS110.*', 'STLINK-V3'
        ],
        'dediprog': [
            'TTL232R-3V3',
            'MCP2200 USB Serial Port Emulator'
        ]
    }

    def __init__(self, env=None):
        """Create an empty map that reads its settings from ``env.options``.

        Args:
            env: twister environment object. Despite the ``None`` default it
                must be given, because ``env.options`` is read here.
        """
        self.detected = []  # DUTs found by scan()
        self.duts = []      # DUTs from load() / add_device()
        self.options = env.options

    def discover(self):
        """Populate the map according to the hardware-related options.

        Returns:
            0 after generating a hardware map, or after listing the devices
            of ``hardware_map`` when device testing is off; 1 otherwise.
        """
        opts = self.options

        if opts.generate_hardware_map:
            self.scan(persistent=opts.persistent_hardware_map)
            self.save(opts.generate_hardware_map)
            return 0

        if not opts.device_testing and opts.hardware_map:
            self.load(opts.hardware_map)
            logger.info("Available devices:")
            self.dump(connected_only=True)
            return 0

        if opts.device_testing:
            if opts.hardware_map:
                self.load(opts.hardware_map)
                if not opts.platform:
                    # No platforms requested: use every connected board
                    # whose platform is known.
                    opts.platform = [
                        d.platform for d in self.duts
                        if d.connected and d.platform != 'unknown'
                    ]

            elif opts.device_serial:
                self.add_device(opts.device_serial,
                                opts.platform[0],
                                opts.pre_script,
                                False,
                                baud=opts.device_serial_baud,
                                flash_timeout=opts.device_flash_timeout,
                                flash_with_test=opts.device_flash_with_test,
                                flash_before=opts.flash_before,
                                )

            elif opts.device_serial_pty:
                self.add_device(opts.device_serial_pty,
                                opts.platform[0],
                                opts.pre_script,
                                True,
                                flash_timeout=opts.device_flash_timeout,
                                flash_with_test=opts.device_flash_with_test,
                                flash_before=False,
                                )

            # Fixtures given explicitly on the command line apply to every DUT.
            if opts.fixture:
                for d in self.duts:
                    d.fixtures.extend(opts.fixture)
        return 1

    def summary(self, selected_platforms):
        """Print per-board counters for connected DUTs in ``selected_platforms``."""
        print("\nHardware distribution summary:\n")
        header = ['Board', 'ID', 'Counter', 'Failures']
        table = [
            [d.platform, d.id, d.counter, d.failures]
            for d in self.duts
            if d.connected and d.platform in selected_platforms
        ]
        print(tabulate(table, headers=header, tablefmt="github"))

    def add_device(
        self,
        serial,
        platform,
        pre_script,
        is_pty,
        baud=None,
        flash_timeout=60,
        flash_with_test=False,
        flash_before=False
    ):
        """Register one connected DUT from command-line options.

        Args:
            serial: serial device path, or a PTY command when ``is_pty``.
            platform: platform name for the DUT.
            pre_script: script passed through to the DUT.
            is_pty: store ``serial`` as ``serial_pty`` instead of ``serial``.
        """
        device = DUT(
            platform=platform,
            connected=True,
            pre_script=pre_script,
            serial_baud=baud,
            flash_timeout=flash_timeout,
            flash_with_test=flash_with_test,
            flash_before=flash_before
        )
        if is_pty:
            device.serial_pty = serial
        else:
            device.serial = serial

        self.duts.append(device)

    def load(self, map_file):
        """Load and validate ``map_file``, adding one DUT per listed platform.

        An entry's ``platform`` may be a whitespace-separated string or a
        list. Entries that are not ``connected`` or have neither ``serial``
        nor ``serial_pty`` are skipped. Missing ``flash_timeout``,
        ``flash_with_test`` and ``flash_before`` fall back to the options.

        Raises:
            ValueError: if an entry's ``platform`` is neither a str nor a list.
        """
        hwm_schema = scl.yaml_load(self.schema_path)
        duts = scl.yaml_load_verify(map_file, hwm_schema)
        for dut in duts:
            pre_script = dut.get('pre_script')
            script_param = dut.get('script_param')
            post_script = dut.get('post_script')
            post_flash_script = dut.get('post_flash_script')
            flash_timeout = dut.get('flash_timeout') or self.options.device_flash_timeout
            flash_with_test = dut.get('flash_with_test')
            if flash_with_test is None:
                flash_with_test = self.options.device_flash_with_test
            serial_pty = dut.get('serial_pty')
            flash_before = dut.get('flash_before')
            if flash_before is None:
                # The option only applies when neither flash_with_test nor a
                # PTY is in use for this entry.
                flash_before = self.options.flash_before and (not (flash_with_test or serial_pty))
            platform_value = dut.get('platform')
            if isinstance(platform_value, str):
                platforms = platform_value.split()
            elif isinstance(platform_value, list):
                platforms = platform_value
            else:
                raise ValueError(f"Invalid platform value: {platform_value}")
            dut_id = dut.get('id')
            runner = dut.get('runner')
            runner_params = dut.get('runner_params')
            serial = dut.get('serial')
            baud = dut.get('baud', None)
            product = dut.get('product')
            # A null "fixtures:" value is treated as no fixtures.
            fixtures = dut.get('fixtures') or []
            connected = dut.get('connected') and ((serial or serial_pty) is not None)
            if not connected:
                continue
            for plat in platforms:
                new_dut = DUT(platform=plat,
                              product=product,
                              runner=runner,
                              runner_params=runner_params,
                              id=dut_id,
                              serial_pty=serial_pty,
                              serial=serial,
                              serial_baud=baud,
                              connected=connected,
                              pre_script=pre_script,
                              flash_before=flash_before,
                              post_script=post_script,
                              post_flash_script=post_flash_script,
                              script_param=script_param,
                              flash_timeout=flash_timeout,
                              flash_with_test=flash_with_test)
                # Give each DUT its own list: discover() extends d.fixtures in
                # place for every DUT, so a shared list would accumulate the
                # command-line fixtures once per platform of this entry.
                new_dut.fixtures = list(fixtures)
                new_dut.counter = 0
                self.duts.append(new_dut)

    def scan(self, persistent=False):
        """Detect supported USB serial devices and append them to ``self.detected``.

        Args:
            persistent: on Linux, record ``/dev/serial/by-id`` symlinks
                instead of the resolved device nodes.
        """
        from serial.tools import list_ports

        persistent_map = {}
        if persistent and platform.system() == 'Linux':
            # On Linux, /dev/serial/by-id provides symlinks to
            # '/dev/ttyACMx' nodes using names which are unique as
            # long as manufacturers fill out USB metadata nicely.
            #
            # This creates a map from '/dev/ttyACMx' device nodes
            # to '/dev/serial/by-id/usb-...' symlinks. The symlinks
            # go into the hardware map because they stay the same
            # even when the user unplugs / replugs the device.
            #
            # Some inexpensive USB/serial adapters don't result
            # in unique names here, though, so use of this feature
            # requires explicitly setting persistent=True.
            by_id = Path('/dev/serial/by-id')
            if by_id.exists():
                persistent_map = {str((by_id / link).resolve()): str(link)
                                  for link in by_id.iterdir()}

        serial_devices = list_ports.comports()
        logger.info("Scanning connected hardware...")
        supported = {m.casefold() for m in self.manufacturer}
        for d in serial_devices:
            if not d.manufacturer or d.manufacturer.casefold() not in supported:
                logger.warning(f"Unsupported device ({d.manufacturer}): {d}")
                continue

            # TI XDS110 can have multiple serial devices for a single board;
            # assume endpoint 0 is the serial and skip all others. Ports that
            # report no location are kept instead of raising AttributeError.
            if (d.manufacturer == 'Texas Instruments'
                    and d.location is not None
                    and not d.location.endswith('0')):
                continue

            if d.product is None:
                d.product = 'unknown'

            s_dev = DUT(platform="unknown",
                        id=d.serial_number,
                        serial=persistent_map.get(d.device, d.device),
                        product=d.product,
                        runner='unknown',
                        connected=True)

            # Exact or regex match; if several runners match, the last wins.
            for runner, products in self.runner_mapping.items():
                if d.product in products or any(re.match(p, d.product) for p in products):
                    s_dev.runner = runner

            # A falsy lock is skipped by DUT.to_dict(), which save() uses to
            # serialise newly detected devices.
            s_dev.lock = None
            self.detected.append(s_dev)

    def save(self, hwm_file):
        """Write the detected devices to ``hwm_file`` and print the result.

        If the file exists, all its entries are marked disconnected; each
        detected device matching an entry by ``id`` and ``product`` reconnects
        that entry with its current serial port, and unmatched devices are
        appended. The file is then reloaded. Otherwise a new file is written
        from the detected devices.
        """
        self.detected = natsorted(self.detected, key=lambda x: x.serial or '')
        if os.path.exists(hwm_file):
            with open(hwm_file) as yaml_file:
                hwm = yaml.load(yaml_file, Loader=SafeLoader)

            if hwm:
                # Entries written through DUT.to_dict() omit falsy fields, so
                # 'id' / 'product' may be missing or null here.
                hwm.sort(key=lambda x: x.get('id') or '')

                # disconnect everything
                for h in hwm:
                    h['connected'] = False
                    h['serial'] = None

                for dev in self.detected:
                    for h in hwm:
                        if (dev.match is False
                                and h['connected'] is False
                                and dev.id == h.get('id')
                                and dev.product == h.get('product')):
                            h['connected'] = True
                            h['serial'] = dev.serial
                            dev.match = True
                            break

            new = [dev.to_dict() for dev in self.detected if not dev.match]
            hwm = hwm + new if hwm else new

            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(hwm, yaml_file, Dumper=Dumper, default_flow_style=False)

            self.load(hwm_file)
            logger.info("Registered devices:")
            self.dump()

        else:
            # create new file
            dl = [
                {
                    'platform': dev.platform,
                    'id': dev.id,
                    'runner': dev.runner,
                    'serial': dev.serial,
                    'product': dev.product,
                    'connected': dev.connected
                }
                for dev in self.detected
            ]
            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(dl, yaml_file, Dumper=Dumper, default_flow_style=False)
            logger.info("Detected devices:")
            self.dump(detected=True)

    def dump(self, filtered=None, header=None, connected_only=False, detected=False):
        """Print a table of DUTs with their platform, ID and serial device.

        Args:
            filtered: if non-empty, only show DUTs whose platform is listed.
            header: column headers; defaults to Platform / ID / Serial device.
            connected_only: only show connected DUTs.
            detected: show ``self.detected`` instead of ``self.duts``.
        """
        filtered = filtered or []
        header = header or ["Platform", "ID", "Serial device"]
        print("")
        to_show = self.detected if detected else self.duts
        table = [
            [p.platform, p.id, p.serial]
            for p in to_show
            if not (filtered and p.platform not in filtered)
            and (not connected_only or p.connected)
        ]
        print(tabulate(table, headers=header, tablefmt="github"))
476