1#!/usr/bin/env python3
2# vim: set syntax=python ts=4 :
3#
4# Copyright (c) 2022 Intel Corporation
5# SPDX-License-Identifier: Apache-2.0
6
7import os
8from multiprocessing import Lock, Value
9import re
10
11import platform
12import yaml
13import scl
14import logging
15from pathlib import Path
16from natsort import natsorted
17
18from twisterlib.environment import ZEPHYR_BASE
19
20try:
21    # Use the C LibYAML parser if available, rather than the Python parser.
22    # It's much faster.
23    from yaml import CSafeLoader as SafeLoader
24    from yaml import CDumper as Dumper
25except ImportError:
26    from yaml import SafeLoader, Dumper
27
28try:
29    from tabulate import tabulate
30except ImportError:
31    print("Install tabulate python module with pip to use --device-testing option.")
32
# Module-wide logger registered under the 'twister' name. Its own threshold is
# lowered to DEBUG so the logger itself does not drop records of any level
# from DEBUG upwards.
logger = logging.getLogger('twister')
logger.setLevel(level=logging.DEBUG)
35
36
class DUT(object):
    """A single device under test (DUT) and its connection / flashing settings.

    ``counter`` and ``available`` are stored in ``multiprocessing.Value``
    objects and are only read or written while holding that value's own lock
    (presumably so they can be shared with other processes).
    """

    def __init__(self,
                 id=None,
                 serial=None,
                 serial_baud=None,
                 platform=None,
                 product=None,
                 serial_pty=None,
                 connected=False,
                 runner_params=None,
                 pre_script=None,
                 post_script=None,
                 post_flash_script=None,
                 runner=None,
                 flash_timeout=60,
                 flash_with_test=False,
                 flash_before=False):
        """Create a DUT description.

        ``serial_baud`` falls back to 115200 when not given (or falsy).
        ``fixtures`` always starts empty; callers fill it in afterwards.
        """
        self.serial = serial
        self.baud = serial_baud or 115200
        self.platform = platform
        self.serial_pty = serial_pty
        # Shared integers, accessed through the locked properties below.
        self._counter = Value("i", 0)
        self._available = Value("i", 1)
        self.connected = connected
        self.pre_script = pre_script
        self.id = id
        self.product = product
        self.runner = runner
        self.runner_params = runner_params
        self.flash_before = flash_before
        self.fixtures = []
        self.post_flash_script = post_flash_script
        self.post_script = post_script
        self.probe_id = None
        self.notes = None
        self.lock = Lock()
        # Set by HardwareMap.save() once a detected device has been matched
        # to an entry of an existing hardware map file.
        self.match = False
        self.flash_timeout = flash_timeout
        self.flash_with_test = flash_with_test

    @property
    def available(self):
        """Availability flag (int), read under the shared value's lock."""
        with self._available.get_lock():
            return self._available.value

    @available.setter
    def available(self, value):
        with self._available.get_lock():
            self._available.value = value

    @property
    def counter(self):
        """Counter (int), read under the shared value's lock."""
        with self._counter.get_lock():
            return self._counter.value

    @counter.setter
    def counter(self, value):
        with self._counter.get_lock():
            self._counter.value = value

    def to_dict(self):
        """Return the DUT's non-empty settings as a plain dict.

        HardwareMap.save() uses this to write newly detected devices to a
        hardware map file. Attributes with falsy values are dropped. So is
        runtime-only state: the shared ``_counter`` / ``_available`` values,
        the ``match`` flag and the multiprocessing ``lock``. The lock used to
        leak through because it is always truthy.
        """
        exclude = ('_available', '_counter', 'match', 'lock')
        return {k: v for k, v in vars(self).items() if k not in exclude and v}

    def __repr__(self):
        return f"<{self.platform} ({self.product}) on {self.serial}>"
111
class HardwareMap:
    """Collection of devices under test (DUTs) used for device testing.

    Devices come from one of three sources:
    - a serial-port scan (``scan``), stored in ``self.detected``;
    - a YAML hardware map file (``load``), stored in ``self.duts``;
    - individual options (``add_device``), also stored in ``self.duts``.
    """

    # Schema used by load() to validate hardware map files.
    schema_path = os.path.join(ZEPHYR_BASE, "scripts", "schemas", "twister", "hwmap-schema.yaml")

    # USB manufacturer strings accepted by scan(), compared case-insensitively.
    # Ports from any other manufacturer are logged as unsupported.
    manufacturer = [
        'ARM',
        'SEGGER',
        'MBED',
        'STMicroelectronics',
        'Atmel Corp.',
        'Texas Instruments',
        'Silicon Labs',
        'NXP',
        'NXP Semiconductors',
        'Microchip Technology Inc.',
        'FTDI',
        'Digilent',
        'Microsoft'
    ]

    # Runner name -> USB product strings that identify it. scan() tries each
    # entry both as an exact string and as a regex passed to re.match
    # (anchored at the start of the product string).
    runner_mapping = {
        'pyocd': [
            'DAPLink CMSIS-DAP',
            'MBED CMSIS-DAP'
        ],
        'jlink': [
            'J-Link',
            'J-Link OB'
        ],
        'openocd': [
            'STM32 STLink', '^XDS110.*', 'STLINK-V3'
        ],
        'dediprog': [
            'TTL232R-3V3',
            'MCP2200 USB Serial Port Emulator'
        ]
    }

    def __init__(self, env=None):
        """Create an empty map driven by ``env.options``."""
        self.detected = []  # DUTs found by scan()
        self.duts = []      # DUTs from load() / add_device()
        self.options = env.options

    def discover(self):
        """Populate the map according to ``self.options``.

        Returns:
            0 after generating a hardware map (``generate_hardware_map``), or
            after listing the devices of ``hardware_map`` when
            ``device_testing`` is not set; 1 otherwise.
        """
        if self.options.generate_hardware_map:
            self.scan(persistent=self.options.persistent_hardware_map)
            self.save(self.options.generate_hardware_map)
            return 0

        if not self.options.device_testing and self.options.hardware_map:
            self.load(self.options.hardware_map)
            logger.info("Available devices:")
            self.dump(connected_only=True)
            return 0

        if self.options.device_testing:
            if self.options.hardware_map:
                self.load(self.options.hardware_map)
                # With no explicit platform list, test on every connected DUT
                # whose platform is known.
                if not self.options.platform:
                    self.options.platform = [
                        d.platform for d in self.duts
                        if d.connected and d.platform != 'unknown'
                    ]

            elif self.options.device_serial:
                self.add_device(self.options.device_serial,
                                self.options.platform[0],
                                self.options.pre_script,
                                False,
                                baud=self.options.device_serial_baud,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test,
                                flash_before=self.options.flash_before,
                                )

            elif self.options.device_serial_pty:
                self.add_device(self.options.device_serial_pty,
                                self.options.platform[0],
                                self.options.pre_script,
                                True,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test,
                                flash_before=False,
                                )

            # the fixtures given by twister command explicitly should be assigned to each DUT
            if self.options.fixture:
                for d in self.duts:
                    d.fixtures.extend(self.options.fixture)
        return 1

    def summary(self, selected_platforms):
        """Print a table of connected DUTs on ``selected_platforms`` with their counters."""
        print("\nHardware distribution summary:\n")
        header = ['Board', 'ID', 'Counter']
        table = [[d.platform, d.id, d.counter] for d in self.duts
                 if d.connected and d.platform in selected_platforms]
        print(tabulate(table, headers=header, tablefmt="github"))

    def add_device(self, serial, platform, pre_script, is_pty, baud=None, flash_timeout=60, flash_with_test=False, flash_before=False):
        """Append a connected DUT for ``platform`` to ``self.duts``.

        ``serial`` is stored as ``serial_pty`` when ``is_pty`` is true and as
        ``serial`` otherwise.
        """
        device = DUT(platform=platform, connected=True, pre_script=pre_script, serial_baud=baud,
                     flash_timeout=flash_timeout, flash_with_test=flash_with_test, flash_before=flash_before
                    )
        if is_pty:
            device.serial_pty = serial
        else:
            device.serial = serial

        self.duts.append(device)

    def load(self, map_file):
        """Load the connected DUTs listed in ``map_file`` into ``self.duts``.

        The file is validated against ``schema_path``. An entry is skipped
        unless it is marked ``connected`` and has ``serial`` or
        ``serial_pty``. Missing ``flash_timeout``, ``flash_with_test`` and
        ``flash_before`` values fall back to the corresponding options.
        """
        hwm_schema = scl.yaml_load(self.schema_path)
        duts = scl.yaml_load_verify(map_file, hwm_schema)
        for dut in duts:
            # Read the serial fields first: the flash_before default below
            # depends on serial_pty. It used to be read before assignment,
            # which raised UnboundLocalError or reused a previous entry's value.
            serial = dut.get('serial')
            serial_pty = dut.get('serial_pty')
            connected = dut.get('connected') and ((serial or serial_pty) is not None)
            if not connected:
                continue

            flash_timeout = dut.get('flash_timeout') or self.options.device_flash_timeout
            flash_with_test = dut.get('flash_with_test')
            if flash_with_test is None:
                flash_with_test = self.options.device_flash_with_test
            flash_before = dut.get('flash_before')
            if flash_before is None:
                # The flash_before option does not apply to flash-with-test
                # or serial_pty devices.
                flash_before = self.options.flash_before and (not (flash_with_test or serial_pty))

            new_dut = DUT(platform=dut.get('platform'),
                          product=dut.get('product'),
                          runner=dut.get('runner'),
                          runner_params=dut.get('runner_params'),
                          id=dut.get('id'),
                          serial_pty=serial_pty,
                          serial=serial,
                          serial_baud=dut.get('baud', None),
                          connected=connected,
                          pre_script=dut.get('pre_script'),
                          flash_before=flash_before,
                          post_script=dut.get('post_script'),
                          post_flash_script=dut.get('post_flash_script'),
                          flash_timeout=flash_timeout,
                          flash_with_test=flash_with_test)
            new_dut.fixtures = dut.get('fixtures', [])
            new_dut.counter = 0
            self.duts.append(new_dut)

    def scan(self, persistent=False):
        """Detect supported serial devices and append them to ``self.detected``.

        Args:
            persistent: on Linux, record each device's /dev/serial/by-id
                symlink (when one exists) instead of its device node.
        """
        from serial.tools import list_ports

        persistent_map = {}
        if persistent and platform.system() == 'Linux':
            # On Linux, /dev/serial/by-id provides symlinks to
            # '/dev/ttyACMx' nodes using names which are unique as
            # long as manufacturers fill out USB metadata nicely.
            #
            # This creates a map from '/dev/ttyACMx' device nodes
            # to '/dev/serial/by-id/usb-...' symlinks. The symlinks
            # go into the hardware map because they stay the same
            # even when the user unplugs / replugs the device.
            #
            # Some inexpensive USB/serial adapters don't result
            # in unique names here, though, so use of this feature
            # requires explicitly setting persistent=True.
            by_id = Path('/dev/serial/by-id')
            if by_id.exists():
                persistent_map = {str(link.resolve()): str(link)
                                  for link in by_id.iterdir()}

        supported = {m.casefold() for m in self.manufacturer}
        serial_devices = list_ports.comports()
        logger.info("Scanning connected hardware...")
        for d in serial_devices:
            if not (d.manufacturer and d.manufacturer.casefold() in supported):
                logger.warning("Unsupported device (%s): %s", d.manufacturer, d)
                continue

            # TI XDS110 can have multiple serial devices for a single board
            # assume endpoint 0 is the serial, skip all others
            if d.manufacturer == 'Texas Instruments' and not d.location.endswith('0'):
                continue

            if d.product is None:
                d.product = 'unknown'

            s_dev = DUT(platform="unknown",
                        id=d.serial_number,
                        serial=persistent_map.get(d.device, d.device),
                        product=d.product,
                        runner='unknown',
                        connected=True)

            # Later runners in runner_mapping win if several match.
            for runner, products in self.runner_mapping.items():
                if d.product in products:
                    s_dev.runner = runner
                    continue
                # Try regex matching
                if any(re.match(p, d.product) for p in products):
                    s_dev.runner = runner

            s_dev.connected = True
            # Cleared so that to_dict() (used by save()) drops it; presumably
            # because the lock is not meant to be serialized.
            s_dev.lock = None
            self.detected.append(s_dev)

    def save(self, hwm_file):
        """Write the detected devices to ``hwm_file``.

        If the file already exists:
        - every entry is first marked disconnected;
        - an entry is reconnected (with the detected serial port) when a
          detected device has the same ``id`` and ``product``;
        - detected devices that matched no entry are appended;
        - the merged map is written back and loaded into ``self.duts``.

        Otherwise a new file listing the detected devices is created.
        """
        self.detected = natsorted(self.detected, key=lambda x: x.serial or '')
        if os.path.exists(hwm_file):
            # use existing map
            with open(hwm_file, 'r') as yaml_file:
                hwm = yaml.load(yaml_file, Loader=SafeLoader)

            if hwm:
                # `or ''` also covers an explicit `id: None`, which the
                # new-file branch below can write and which cannot be
                # compared with strings.
                hwm.sort(key=lambda x: x.get('id') or '')

                # disconnect everything
                for h in hwm:
                    h['connected'] = False
                    h['serial'] = None

                # Entries written via DUT.to_dict() may lack falsy keys such
                # as 'id', so use .get() rather than indexing.
                for _detected in self.detected:
                    for h in hwm:
                        if (_detected.id == h.get('id')
                                and _detected.product == h.get('product')
                                and _detected.match is False
                                and h['connected'] is False):
                            h['connected'] = True
                            h['serial'] = _detected.serial
                            _detected.match = True
                            break

            new = [d.to_dict() for d in self.detected if not d.match]
            hwm = hwm + new if hwm else new

            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(hwm, yaml_file, Dumper=Dumper, default_flow_style=False)

            self.load(hwm_file)
            logger.info("Registered devices:")
            self.dump()

        else:
            # create new file
            dl = [
                {
                    'platform': dut.platform,
                    'id': dut.id,
                    'runner': dut.runner,
                    'serial': dut.serial,
                    'product': dut.product,
                    'connected': dut.connected,
                }
                for dut in self.detected
            ]
            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(dl, yaml_file, Dumper=Dumper, default_flow_style=False)
            logger.info("Detected devices:")
            self.dump(detected=True)

    def dump(self, filtered=None, header=None, connected_only=False, detected=False):
        """Print a table of DUTs.

        Args:
            filtered: if non-empty, only show DUTs whose platform is in it.
            header: column titles; defaults to Platform / ID / Serial device.
            connected_only: only show DUTs marked connected.
            detected: show ``self.detected`` instead of ``self.duts``.
        """
        print("")
        to_show = self.detected if detected else self.duts
        if not header:
            header = ["Platform", "ID", "Serial device"]

        table = []
        for p in to_show:
            if filtered and p.platform not in filtered:
                continue
            if not connected_only or p.connected:
                table.append([p.platform, p.id, p.serial])

        print(tabulate(table, headers=header, tablefmt="github"))
422