1#!/usr/bin/env python3
2# vim: set syntax=python ts=4 :
3#
4# Copyright (c) 2022 Intel Corporation
5# SPDX-License-Identifier: Apache-2.0
6
7import os
8from multiprocessing import Lock, Value
9import re
10
11import platform
12import yaml
13import scl
14import logging
15from pathlib import Path
16from natsort import natsorted
17
18from twisterlib.environment import ZEPHYR_BASE
19
20try:
21    # Use the C LibYAML parser if available, rather than the Python parser.
22    # It's much faster.
23    from yaml import CSafeLoader as SafeLoader
24    from yaml import CDumper as Dumper
25except ImportError:
26    from yaml import SafeLoader, Dumper
27
28try:
29    from tabulate import tabulate
30except ImportError:
31    print("Install tabulate python module with pip to use --device-testing option.")
32
# Named 'twister' logger (the same instance every getLogger('twister') call
# returns); this module raises its level to DEBUG when imported.
logger = logging.getLogger(name='twister')
logger.setLevel(level=logging.DEBUG)
35
36
class DUT(object):
    """A single device under test (DUT) and its connection/flashing settings.

    ``counter`` and ``available`` are properties backed by process-shared
    ``multiprocessing.Value`` integers; every read and write goes through the
    value's own lock.  ``lock`` is a separate ``multiprocessing.Lock`` that is
    exposed to callers.

    Args:
        id: device identifier (e.g. a USB serial number).
        serial: serial device path used to talk to the board.
        serial_baud: baud rate; a falsy value falls back to 115200.
        platform: board/platform name.
        product: product string reported for the device.
        serial_pty: alternative to ``serial`` (stored as-is).
        connected: whether the device is currently connected.
        runner_params: extra runner parameters (stored as-is).
        pre_script: script reference stored as ``pre_script``.
        post_script: script reference stored as ``post_script``.
        post_flash_script: script reference stored as ``post_flash_script``.
        runner: runner name used for flashing.
        flash_timeout: flash timeout (presumably seconds — confirm with the
            consumer of this attribute); defaults to 60.
        flash_with_test: flag stored as ``flash_with_test``.
    """

    def __init__(self,
                 id=None,
                 serial=None,
                 serial_baud=None,
                 platform=None,
                 product=None,
                 serial_pty=None,
                 connected=False,
                 runner_params=None,
                 pre_script=None,
                 post_script=None,
                 post_flash_script=None,
                 runner=None,
                 flash_timeout=60,
                 flash_with_test=False):

        self.serial = serial
        self.baud = serial_baud or 115200
        self.platform = platform
        self.serial_pty = serial_pty
        # Shared integers; access them through the properties below.
        self._counter = Value("i", 0)
        self._available = Value("i", 1)
        self.connected = connected
        self.pre_script = pre_script
        self.id = id
        self.product = product
        self.runner = runner
        self.runner_params = runner_params
        self.fixtures = []
        self.post_flash_script = post_flash_script
        self.post_script = post_script
        self.probe_id = None
        self.notes = None
        self.lock = Lock()
        # Set by HardwareMap.save() when a scanned device matches a map entry.
        self.match = False
        self.flash_timeout = flash_timeout
        self.flash_with_test = flash_with_test

    @property
    def available(self):
        """Availability flag (1 when available, as initialised), read under its lock."""
        with self._available.get_lock():
            return self._available.value

    @available.setter
    def available(self, value):
        with self._available.get_lock():
            self._available.value = value

    @property
    def counter(self):
        """Integer counter, read under its lock."""
        with self._counter.get_lock():
            return self._counter.value

    @counter.setter
    def counter(self, value):
        with self._counter.get_lock():
            self._counter.value = value

    def to_dict(self):
        """Return the DUT's truthy attributes as a plain dict.

        Falsy values (None, False, 0, empty lists) are omitted.  The shared
        ``Value`` backing fields, the ``match`` flag and the ``lock`` object
        are always excluded: they are runtime state, and a
        ``multiprocessing.Lock`` cannot be written to a YAML hardware map.
        """
        exclude = {'_available', '_counter', 'match', 'lock'}
        return {key: value for key, value in vars(self).items()
                if key not in exclude and value}

    def __repr__(self):
        return f"<{self.platform} ({self.product}) on {self.serial}>"
109
class HardwareMap:
    """Registry of devices under test (DUTs) for twister device testing.

    ``self.duts`` is filled from a hardware map file (``load``) or from
    explicit options (``add_device``).  ``scan`` fills ``self.detected``
    from the host's serial ports, and ``save`` writes the detected devices
    to a hardware map file.
    """

    # Schema used by load() to validate hardware map files.
    schema_path = os.path.join(ZEPHYR_BASE, "scripts", "schemas", "twister", "hwmap-schema.yaml")

    # USB manufacturer strings accepted by scan() (compared case-insensitively).
    manufacturer = [
        'ARM',
        'SEGGER',
        'MBED',
        'STMicroelectronics',
        'Atmel Corp.',
        'Texas Instruments',
        'Silicon Labs',
        'NXP Semiconductors',
        'Microchip Technology Inc.',
        'FTDI',
        'Digilent',
        'Microsoft'
    ]

    # Runner name -> USB product strings that scan() assigns to that runner.
    # Each entry is tried both as an exact product name and as a regular
    # expression matched at the start of the product string (re.match).
    runner_mapping = {
        'pyocd': [
            'DAPLink CMSIS-DAP',
            'MBED CMSIS-DAP'
        ],
        'jlink': [
            'J-Link',
            'J-Link OB'
        ],
        'openocd': [
            'STM32 STLink', '^XDS110.*', 'STLINK-V3'
        ],
        'dediprog': [
            'TTL232R-3V3',
            'MCP2200 USB Serial Port Emulator'
        ]
    }

    def __init__(self, env=None):
        """Create an empty map bound to ``env.options``.

        NOTE(review): ``env`` defaults to None but is dereferenced
        unconditionally, so callers must always pass an environment object
        (presumably one holding the parsed command-line options).
        """
        self.detected = []  # devices found by scan()
        self.duts = []      # devices available for testing
        self.options = env.options

    def discover(self):
        """Populate the device list according to ``self.options``.

        * ``generate_hardware_map`` set: scan serial ports (optionally with
          persistent names), save the map to that path and return 0.
        * ``hardware_map`` set without ``device_testing``: load the map,
          print its connected devices and return 0.
        * ``device_testing`` set: load the map, or register a single device
          from ``device_serial`` / ``device_serial_pty``; when no platform
          was requested, use the platforms of all connected, known DUTs.
          Any ``fixture`` options are appended to every DUT.

        Returns:
            0 when a map was only generated or listed, 1 otherwise.
        """
        if self.options.generate_hardware_map:
            self.scan(persistent=self.options.persistent_hardware_map)
            self.save(self.options.generate_hardware_map)
            return 0

        if not self.options.device_testing and self.options.hardware_map:
            self.load(self.options.hardware_map)
            logger.info("Available devices:")
            self.dump(connected_only=True)
            return 0

        if self.options.device_testing:
            if self.options.hardware_map:
                self.load(self.options.hardware_map)
                if not self.options.platform:
                    self.options.platform = []
                    for d in self.duts:
                        if d.connected and d.platform != 'unknown':
                            self.options.platform.append(d.platform)

            # NOTE(review): the two branches below assume options.platform is a
            # non-empty list; confirm the argument parser enforces this.
            elif self.options.device_serial:
                self.add_device(self.options.device_serial,
                                self.options.platform[0],
                                self.options.pre_script,
                                False,
                                baud=self.options.device_serial_baud,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test
                                )

            elif self.options.device_serial_pty:
                self.add_device(self.options.device_serial_pty,
                                self.options.platform[0],
                                self.options.pre_script,
                                True,
                                flash_timeout=self.options.device_flash_timeout,
                                flash_with_test=self.options.device_flash_with_test
                                )

            # the fixtures given by twister command explicitly should be assigned to each DUT
            if self.options.fixture:
                for d in self.duts:
                    d.fixtures.extend(self.options.fixture)
        return 1

    def summary(self, selected_platforms):
        """Print board, ID and counter for connected DUTs in ``selected_platforms``."""
        print("\nHardware distribution summary:\n")
        header = ['Board', 'ID', 'Counter']
        table = [[d.platform, d.id, d.counter]
                 for d in self.duts
                 if d.connected and d.platform in selected_platforms]
        print(tabulate(table, headers=header, tablefmt="github"))

    def add_device(self, serial, platform, pre_script, is_pty, baud=None, flash_timeout=60, flash_with_test=False):
        """Register one connected DUT given explicitly (not from a map file).

        ``serial`` is stored as ``serial_pty`` when ``is_pty`` is true and as
        ``serial`` otherwise.
        """
        device = DUT(platform=platform, connected=True, pre_script=pre_script, serial_baud=baud,
                     flash_timeout=flash_timeout, flash_with_test=flash_with_test
                    )
        if is_pty:
            device.serial_pty = serial
        else:
            device.serial = serial

        self.duts.append(device)

    def load(self, map_file):
        """Append the connected devices listed in ``map_file`` to ``self.duts``.

        The file is validated against ``schema_path``.  Entries that are not
        marked ``connected``, or that have neither ``serial`` nor
        ``serial_pty``, are skipped.  A falsy ``flash_timeout`` falls back
        to ``options.device_flash_timeout``; a missing ``flash_with_test``
        falls back to ``options.device_flash_with_test``.
        """
        hwm_schema = scl.yaml_load(self.schema_path)
        duts = scl.yaml_load_verify(map_file, hwm_schema)
        # Guard against the loader returning None (e.g. for an empty file).
        for dut in duts or []:
            serial = dut.get('serial')
            serial_pty = dut.get('serial_pty')
            connected = dut.get('connected') and ((serial or serial_pty) is not None)
            if not connected:
                continue

            flash_timeout = dut.get('flash_timeout') or self.options.device_flash_timeout
            # Only a missing value falls back, so an explicit False is honoured.
            flash_with_test = dut.get('flash_with_test')
            if flash_with_test is None:
                flash_with_test = self.options.device_flash_with_test

            new_dut = DUT(platform=dut.get('platform'),
                          product=dut.get('product'),
                          runner=dut.get('runner'),
                          runner_params=dut.get('runner_params'),
                          id=dut.get('id'),
                          serial_pty=serial_pty,
                          serial=serial,
                          serial_baud=dut.get('baud'),
                          connected=connected,
                          pre_script=dut.get('pre_script'),
                          post_script=dut.get('post_script'),
                          post_flash_script=dut.get('post_flash_script'),
                          flash_timeout=flash_timeout,
                          flash_with_test=flash_with_test)
            new_dut.fixtures = dut.get('fixtures', [])
            new_dut.counter = 0
            self.duts.append(new_dut)

    def scan(self, persistent=False):
        """Detect supported serial devices and append them to ``self.detected``.

        Args:
            persistent: on Linux, record each device by its
                ``/dev/serial/by-id`` symlink instead of its device node.
        """
        from serial.tools import list_ports

        persistent_map = {}
        if persistent and platform.system() == 'Linux':
            # On Linux, /dev/serial/by-id provides symlinks to
            # '/dev/ttyACMx' nodes using names which are unique as
            # long as manufacturers fill out USB metadata nicely.
            #
            # This creates a map from '/dev/ttyACMx' device nodes
            # to '/dev/serial/by-id/usb-...' symlinks. The symlinks
            # go into the hardware map because they stay the same
            # even when the user unplugs / replugs the device.
            #
            # Some inexpensive USB/serial adapters don't result
            # in unique names here, though, so use of this feature
            # requires explicitly setting persistent=True.
            by_id = Path('/dev/serial/by-id')
            # The directory may be absent (e.g. when no USB serial device is
            # attached); iterdir() would then raise FileNotFoundError.
            if by_id.is_dir():
                persistent_map = {str((by_id / link).resolve()): str(link)
                                  for link in by_id.iterdir()}
            else:
                logger.warning("%s not found, using device node names", by_id)

        supported = {m.casefold() for m in self.manufacturer}
        serial_devices = list_ports.comports()
        logger.info("Scanning connected hardware...")
        for d in serial_devices:
            if not (d.manufacturer and d.manufacturer.casefold() in supported):
                logger.warning("Unsupported device (%s): %s", d.manufacturer, d)
                continue

            # TI XDS110 can have multiple serial devices for a single board
            # assume endpoint 0 is the serial, skip all others.  A port with
            # no reported location cannot be classified and is kept.
            if (d.manufacturer.casefold() == 'texas instruments'
                    and d.location and not d.location.endswith('0')):
                continue

            product = d.product if d.product is not None else 'unknown'

            s_dev = DUT(platform="unknown",
                        id=d.serial_number,
                        serial=persistent_map.get(d.device, d.device),
                        product=product,
                        runner='unknown',
                        connected=True)

            # Exact name or regex match; a later runner overrides an earlier one.
            for runner, products in self.runner_mapping.items():
                if product in products or any(re.match(p, product) for p in products):
                    s_dev.runner = runner

            # Clear the lock so DUT.to_dict() (which skips falsy values) does
            # not include it when save() serializes new devices.
            s_dev.lock = None
            self.detected.append(s_dev)

    def save(self, hwm_file):
        """Write the scanned devices (``self.detected``) to ``hwm_file``.

        If the file exists, all of its entries are marked disconnected, then
        each detected device reconnects the first entry with the same ``id``
        and ``product``; unmatched detected devices are appended.  The result
        is written back, loaded into ``self.duts`` and printed.  Otherwise a
        new file is created from the detected devices and they are printed.
        """
        self.detected = natsorted(self.detected, key=lambda x: x.serial or '')
        if os.path.exists(hwm_file):
            # use existing map
            with open(hwm_file, 'r') as yaml_file:
                hwm = yaml.load(yaml_file, Loader=SafeLoader)
            if hwm:
                # A missing or null 'id' sorts as '' instead of breaking the sort.
                hwm.sort(key=lambda x: x.get('id') or '')

                # disconnect everything
                for h in hwm:
                    h['connected'] = False
                    h['serial'] = None

                for _detected in self.detected:
                    for h in hwm:
                        if (_detected.id == h.get('id')
                                and _detected.product == h.get('product')
                                and _detected.match is False
                                and h['connected'] is False):
                            h['connected'] = True
                            h['serial'] = _detected.serial
                            _detected.match = True
                            break

            new = [d.to_dict() for d in self.detected if not d.match]
            hwm = (hwm or []) + new

            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(hwm, yaml_file, Dumper=Dumper, default_flow_style=False)

            self.load(hwm_file)
            logger.info("Registered devices:")
            self.dump()

        else:
            # create new file
            dl = [{
                'platform': dev.platform,
                'id': dev.id,
                'runner': dev.runner,
                'serial': dev.serial,
                'product': dev.product,
                'connected': dev.connected
            } for dev in self.detected]
            with open(hwm_file, 'w') as yaml_file:
                yaml.dump(dl, yaml_file, Dumper=Dumper, default_flow_style=False)
            logger.info("Detected devices:")
            self.dump(detected=True)

    def dump(self, filtered=None, header=None, connected_only=False, detected=False):
        """Print a platform/ID/serial table of devices.

        Args:
            filtered: if non-empty, only show these platforms.
            header: column headers; defaults to platform, ID and serial device.
            connected_only: only show connected devices.
            detected: show ``self.detected`` instead of ``self.duts``.
        """
        print("")
        to_show = self.detected if detected else self.duts
        if not header:
            header = ["Platform", "ID", "Serial device"]
        table = [[p.platform, p.id, p.serial]
                 for p in to_show
                 if (not filtered or p.platform in filtered)
                 and (not connected_only or p.connected)]
        print(tabulate(table, headers=header, tablefmt="github"))
410