1#!/usr/bin/env python3
2#
3#  Copyright (c) 2019, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import binascii
31import json
32import logging
33import os
34import signal
35import subprocess
36import sys
37import time
38import traceback
39import unittest
40from typing import Optional, Callable
41
42import config
43import debug
44from node import Node, OtbrNode, HostNode
45from pktverify import utils as pvutils
46
# Non-zero enables pcap-based packet verification for test cases that define a
# `verify()` method (see TestCase._do_packet_verification).
PACKET_VERIFICATION = int(os.getenv('PACKET_VERIFICATION', 0))

if PACKET_VERIFICATION:
    # Imported lazily: pktverify address/verifier helpers are only needed when
    # packet verification is requested, and depend on a tshark installation.
    from pktverify.addrs import ExtAddr, EthAddr
    from pktverify.packet_verifier import PacketVerifier

# Offset added to simulated node ports/files so parallel test runs don't collide.
PORT_OFFSET = int(os.getenv('PORT_OFFSET', "0"))

# Thread protocol version under test; '1.1' unless overridden by the environment.
ENV_THREAD_VERSION = os.getenv('THREAD_VERSION', '1.1')

DEFAULT_PARAMS = {
    'is_mtd': False,
    'is_ftd': False,
    'is_bbr': False,
    'is_otbr': False,
    'is_host': False,
    'mode': 'rdn',
    'allowlist': None,
    'version': ENV_THREAD_VERSION,
    'panid': 0xface,
}
"""Default configurations when creating nodes."""

# Extra defaults applied on top of DEFAULT_PARAMS for full thread devices (FTDs).
FTD_DEFAULT_PARAMS = {
    'is_ftd': True,
    'router_selection_jitter': config.DEFAULT_ROUTER_SELECTION_JITTER,
}

EXTENDED_ADDRESS_BASE = 0x166e0a0000000000
"""Extended address base to keep U/L bit 1. The value is borrowed from Thread Test Harness."""
77
78
class NcpSupportMixin():
    """Mixin that skips test cases which do not support the NCP node type.

    When the environment selects `ncp-sim` nodes and the test case declares
    `SUPPORT_NCP = False`, the process exits with status 77, which automake
    interprets as "test skipped".
    """

    SUPPORT_NCP = True

    def __init__(self, *args, **kwargs):
        running_against_ncp = os.getenv('NODE_TYPE', 'sim') == 'ncp-sim'
        if running_against_ncp and not self.SUPPORT_NCP:
            # 77 means skip this test case in automake tests
            sys.exit(77)

        super().__init__(*args, **kwargs)
91
92
class TestCase(NcpSupportMixin, unittest.TestCase):
    """The base class for all thread certification test cases.

    The `topology` member of sub-class is used to create test topology.
    """

    # When False, the simulator does not decode captured frames into message objects.
    USE_MESSAGE_FACTORY = True
    # Mapping of node id -> node parameter dict; sub-classes must define it.
    TOPOLOGY = None
    # Optional extra Wireshark preferences passed to the PacketVerifier.
    CASE_WIRESHARK_PREFS = None
    # Set to False in sub-classes that require Thread >= 1.2 features.
    SUPPORT_THREAD_1_1 = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

        self._start_time = None
        # Packet verification runs only when globally enabled AND the test
        # case actually defines a `verify()` method.
        self._do_packet_verification = PACKET_VERIFICATION and hasattr(self, 'verify')

    def setUp(self):
        if ENV_THREAD_VERSION == '1.1' and not self.SUPPORT_THREAD_1_1:
            self.skipTest('Thread 1.1 not supported.')

        try:
            self._setUp()
        except:
            traceback.print_exc()
            # Best-effort cleanup of any nodes created before the failure so
            # their processes/files don't leak into subsequent test cases.
            for node in list(self.nodes.values()):
                try:
                    node.destroy()
                except Exception:
                    traceback.print_exc()

            raise

    def _setUp(self):
        """Create simulator, nodes and apply configurations.
        """
        self._clean_up_tmp()

        self.simulator = config.create_default_simulator(use_message_factory=self.USE_MESSAGE_FACTORY)
        self.nodes = {}

        # Location of the thread-enabled wireshark shared libraries used by
        # the backbone sniffer tooling.
        os.environ['LD_LIBRARY_PATH'] = '/tmp/thread-wireshark'

        if self._has_backbone_traffic():
            self._prepare_backbone_network()
            self._start_backbone_sniffer()

        self._initial_topology = initial_topology = {}

        for i, params in self.TOPOLOGY.items():
            params = self._parse_params(params)
            initial_topology[i] = params

            logging.info("Creating node %d: %r", i, params)

            if params['is_otbr']:
                nodeclass = OtbrNode
            elif params['is_host']:
                nodeclass = HostNode
            else:
                nodeclass = Node

            node = nodeclass(
                i,
                is_mtd=params['is_mtd'],
                simulator=self.simulator,
                name=params.get('name'),
                version=params['version'],
                is_bbr=params['is_bbr'],
            )

            self.nodes[i] = node

            # Plain hosts are not Thread devices; none of the Thread-specific
            # configuration below applies to them.
            if node.is_host:
                continue

            self.nodes[i].set_networkkey(binascii.hexlify(config.DEFAULT_NETWORK_KEY).decode())
            self.nodes[i].set_panid(params['panid'])
            self.nodes[i].set_mode(params['mode'])

            if 'partition_id' in params:
                self.nodes[i].set_preferred_partition_id(params['partition_id'])
            if 'channel' in params:
                self.nodes[i].set_channel(params['channel'])
            if 'networkkey' in params:
                self.nodes[i].set_networkkey(params['networkkey'])
            if 'network_name' in params:
                self.nodes[i].set_network_name(params['network_name'])

            if params['is_ftd']:
                self.nodes[i].set_router_selection_jitter(params['router_selection_jitter'])

            if 'router_upgrade_threshold' in params:
                self.nodes[i].set_router_upgrade_threshold(params['router_upgrade_threshold'])
            if 'router_downgrade_threshold' in params:
                self.nodes[i].set_router_downgrade_threshold(params['router_downgrade_threshold'])
            if 'router_eligible' in params:
                self.nodes[i].set_router_eligible(params['router_eligible'])
            if 'prefer_router_id' in params:
                self.nodes[i].prefer_router_id(params['prefer_router_id'])

            if 'timeout' in params:
                self.nodes[i].set_timeout(params['timeout'])

            if 'active_dataset' in params:
                if 'network_key' not in params['active_dataset']:
                    params['active_dataset']['network_key'] = binascii.hexlify(config.DEFAULT_NETWORK_KEY).decode()
                self.nodes[i].set_active_dataset(params['active_dataset']['timestamp'],
                                                 panid=params['active_dataset'].get('panid'),
                                                 channel=params['active_dataset'].get('channel'),
                                                 channel_mask=params['active_dataset'].get('channel_mask'),
                                                 network_key=params['active_dataset'].get('network_key'),
                                                 security_policy=params['active_dataset'].get('security_policy'))

            if 'pending_dataset' in params:
                self.nodes[i].set_pending_dataset(params['pending_dataset']['pendingtimestamp'],
                                                  params['pending_dataset']['activetimestamp'],
                                                  panid=params['pending_dataset'].get('panid'),
                                                  channel=params['pending_dataset'].get('channel'),
                                                  delay=params['pending_dataset'].get('delay'))

            if 'key_switch_guardtime' in params:
                self.nodes[i].set_key_switch_guardtime(params['key_switch_guardtime'])
            if 'key_sequence_counter' in params:
                self.nodes[i].set_key_sequence_counter(params['key_sequence_counter'])

            if 'network_id_timeout' in params:
                self.nodes[i].set_network_id_timeout(params['network_id_timeout'])

            if 'context_reuse_delay' in params:
                self.nodes[i].set_context_reuse_delay(params['context_reuse_delay'])

            if 'max_children' in params:
                self.nodes[i].set_max_children(params['max_children'])

            if 'bbr_registration_jitter' in params:
                self.nodes[i].set_bbr_registration_jitter(params['bbr_registration_jitter'])

        # we have to add allowlist after nodes are all created
        for i, params in initial_topology.items():
            allowlist = params['allowlist']
            if not allowlist:
                continue

            for j in allowlist:
                rssi = None
                # An allowlist entry may be either a node id or a
                # (node id, rssi) tuple specifying the link quality.
                if isinstance(j, tuple):
                    j, rssi = j
                self.nodes[i].add_allowlist(self.nodes[j].get_addr64(), rssi=rssi)
            self.nodes[i].enable_allowlist()

        self._inspector = debug.Inspector(self)
        self._collect_test_info_after_setup()

    def inspect(self):
        """Drop into the interactive topology inspector (debugging aid)."""
        self._inspector.inspect()

    def tearDown(self):
        """Destroy nodes and simulator.
        """
        # NOTE(review): raising here skips node/simulator cleanup below; this
        # mirrors the original behavior of failing fast on non-Linux hosts.
        if self._do_packet_verification and os.uname().sysname != "Linux":
            raise NotImplementedError(
                f'{self.test_name}: Packet Verification not available on {os.uname().sysname} (Linux only).')

        if self._do_packet_verification:
            # Let in-flight frames drain into the capture before stopping.
            self.simulator.go(3)

        if self._has_backbone_traffic():
            # Stop Backbone sniffer before stopping nodes so that we don't capture Codecov Uploading traffic
            self._stop_backbone_sniffer()

        for node in list(self.nodes.values()):
            node.stop()
            node.destroy()

        self.simulator.stop()

        if self._has_backbone_traffic():
            self._remove_backbone_network()
            pcap_filename = self._merge_thread_backbone_pcaps()
        else:
            pcap_filename = self._get_thread_pcap_filename()

        if self._do_packet_verification:
            self._test_info['pcap'] = pcap_filename

            test_info_path = self._output_test_info()
            self._verify_packets(test_info_path)

    def flush_all(self):
        """Flush away all captured messages of all nodes.
        """
        for i in list(self.nodes.keys()):
            self.simulator.get_messages_sent_by(i)

    def flush_nodes(self, nodes):
        """Flush away all captured messages of specified nodes.

        Args:
            nodes (list): nodes whose messages to flush.

        """
        for i in nodes:
            if i in list(self.nodes.keys()):
                self.simulator.get_messages_sent_by(i)

    def _clean_up_tmp(self):
        """
        Clean up node files in tmp directory
        """
        os.system(f"rm -f tmp/{PORT_OFFSET}_*.flash tmp/{PORT_OFFSET}_*.data tmp/{PORT_OFFSET}_*.swap")

    def _verify_packets(self, test_info_path: str):
        """Run the test case's `verify()` against the captured packets."""
        pv = PacketVerifier(test_info_path, self.CASE_WIRESHARK_PREFS)
        pv.add_common_vars()
        self.verify(pv)
        print("Packet verification passed: %s" % test_info_path, file=sys.stderr)

    @property
    def test_name(self):
        """Name of the running test, taken from the TEST_NAME environment variable."""
        return os.getenv('TEST_NAME', 'current')

    def collect_ipaddrs(self):
        """Record every node's IPv6 addresses (and ML-EIDs) for packet verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info

        for i, node in self.nodes.items():
            ipaddrs = node.get_addrs()
            test_info['ipaddrs'][i] = ipaddrs
            if not node.is_host:
                mleid = node.get_mleid()
                test_info['mleids'][i] = mleid

    def collect_rloc16s(self):
        """Record every Thread node's RLOC16 for packet verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['rloc16s'] = {}

        for i, node in self.nodes.items():
            if not node.is_host:
                test_info['rloc16s'][i] = '0x%04x' % node.get_addr16()

    def collect_rlocs(self):
        """Record every Thread node's RLOC address for packet verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['rlocs'] = {}

        for i, node in self.nodes.items():
            if node.is_host:
                continue

            test_info['rlocs'][i] = node.get_rloc()

    def collect_leader_aloc(self, node):
        """Record the leader ALOC of the given node for packet verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['leader_aloc'] = self.nodes[node].get_addr_leader_aloc()

    def collect_extra_vars(self, **vars):
        """Record arbitrary extra variables for use in `verify()`."""
        if not self._do_packet_verification:
            return

        for k in vars.keys():
            assert isinstance(k, str), k

        test_vars = self._test_info.setdefault("extra_vars", {})
        test_vars.update(vars)

    def _collect_test_info_after_setup(self):
        """
        Collect test info after setUp
        """
        if not self._do_packet_verification:
            return

        test_info = self._test_info = {
            'script': os.path.abspath(sys.argv[0]),
            'testcase': self.test_name,
            'start_time': time.ctime(self._start_time),
            'pcap': '',
            'extaddrs': {},
            'ethaddrs': {},
            'ipaddrs': {},
            'mleids': {},
            'topology': self._initial_topology,
            'backbone': {
                'interface': config.BACKBONE_DOCKER_NETWORK_NAME,
                'prefix': config.BACKBONE_PREFIX,
            },
            'domain_prefix': config.DOMAIN_PREFIX,
            'env': {
                'PORT_OFFSET': config.PORT_OFFSET,
            },
        }

        for i, node in self.nodes.items():
            if not node.is_host:
                extaddr = node.get_addr64()
                test_info['extaddrs'][i] = ExtAddr(extaddr).format_octets()

            if node.is_host or node.is_otbr:
                ethaddr = node.get_ether_mac()
                test_info['ethaddrs'][i] = EthAddr(ethaddr).format_octets()

    def _output_test_info(self):
        """
        Output test info to json file after tearDown
        """
        filename = f'{self.test_name}.json'
        with open(filename, 'wt') as ofd:
            ofd.write(json.dumps(self._test_info, indent=1, sort_keys=True))

        return filename

    def _get_thread_pcap_filename(self):
        """Absolute path of the Thread-side pcap capture for this test."""
        current_pcap = self.test_name + '.pcap'
        return os.path.abspath(current_pcap)

    def assure_run_ok(self, cmd, shell=False):
        """Run `cmd` (list or string) and raise if it exits non-zero."""
        if not shell and isinstance(cmd, str):
            cmd = cmd.split()
        proc = subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr, shell=shell)
        print(">>> %s => %d" % (cmd, proc.returncode), file=sys.stderr)
        proc.check_returncode()

    def _parse_params(self, params: Optional[dict]) -> dict:
        """Merge node `params` with the role-appropriate defaults.

        Returns a new dict; the caller's dict is not mutated (though nested
        values are shared).
        """
        params = params or {}

        if params.get('is_bbr') or params.get('is_otbr'):
            # BBRs must use thread version 1.2
            assert params.get('version', '1.2') == '1.2', params
            params['version'] = '1.2'
            params.setdefault('bbr_registration_jitter', config.DEFAULT_BBR_REGISTRATION_JITTER)
        elif params.get('is_host'):
            # Hosts must not specify thread version
            assert params.get('version', '') == '', params
            params['version'] = ''

        is_ftd = (not params.get('is_mtd') and not params.get('is_host'))

        effective_params = DEFAULT_PARAMS.copy()

        if is_ftd:
            effective_params.update(FTD_DEFAULT_PARAMS)

        effective_params.update(params)

        return effective_params

    def _has_backbone_traffic(self):
        """Return True if any node in the topology uses the backbone network."""
        for param in self.TOPOLOGY.values():
            if param and (param.get('is_otbr') or param.get('is_host')):
                return True

        return False

    def _prepare_backbone_network(self):
        """Create the docker bridge network used as the backbone (idempotent)."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(
            f'docker network create --driver bridge --ipv6 --subnet {config.BACKBONE_PREFIX} -o "com.docker.network.bridge.name"="{network_name}" {network_name} || true',
            shell=True)

    def _remove_backbone_network(self):
        """Remove the docker backbone network created in `_prepare_backbone_network`."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(f'docker network rm {network_name}', shell=True)

    def _start_backbone_sniffer(self):
        """Start a dumpcap process capturing backbone traffic to a pcap file."""
        # don't know why but I have to create the empty bbr.pcap first, otherwise tshark won't work
        # self.assure_run_ok("truncate --size 0 bbr.pcap && chmod 664 bbr.pcap", shell=True)
        pcap_file = self._get_backbone_pcap_filename()
        try:
            os.remove(pcap_file)
        except FileNotFoundError:
            pass

        dumpcap = pvutils.which_dumpcap()
        self._dumpcap_proc = subprocess.Popen([dumpcap, '-i', config.BACKBONE_DOCKER_NETWORK_NAME, '-w', pcap_file],
                                              stdout=sys.stdout,
                                              stderr=sys.stderr)
        # Give dumpcap a moment to start, then verify it didn't exit immediately.
        time.sleep(0.2)
        assert self._dumpcap_proc.poll() is None, 'dumpcap terminated unexpectedly'
        logging.info('Backbone sniffer launched successfully: pid=%s', self._dumpcap_proc.pid)

    def _get_backbone_pcap_filename(self):
        """Absolute path of the backbone-side pcap capture for this test."""
        backbone_pcap = self.test_name + '_backbone.pcap'
        return os.path.abspath(backbone_pcap)

    def _get_merged_pcap_filename(self):
        """Absolute path of the merged Thread + backbone pcap for this test."""
        backbone_pcap = self.test_name + '_merged.pcap'
        return os.path.abspath(backbone_pcap)

    def _stop_backbone_sniffer(self):
        """Terminate the dumpcap process and wait for it to flush its capture."""
        self._dumpcap_proc.send_signal(signal.SIGTERM)
        # Popen.__exit__ closes the pipes and waits for the process to exit,
        # ensuring the pcap file is fully written before we merge it.
        self._dumpcap_proc.__exit__(None, None, None)
        logging.info('Backbone sniffer terminated successfully: pid=%s', self._dumpcap_proc.pid)

    def _merge_thread_backbone_pcaps(self):
        """Merge the Thread and backbone pcaps into one file; return its path."""
        thread_pcap = self._get_thread_pcap_filename()
        backbone_pcap = self._get_backbone_pcap_filename()
        merged_pcap = self._get_merged_pcap_filename()

        mergecap = pvutils.which_mergecap()
        self.assure_run_ok(f'{mergecap} -w {merged_pcap} {thread_pcap} {backbone_pcap}', shell=True)
        return merged_pcap

    def wait_until(self, cond: Callable[[], bool], timeout: int, go_interval: int = 1):
        """Advance the simulator until `cond()` is true.

        Args:
            cond: zero-argument predicate checked after each simulator step.
            timeout: maximum simulated seconds to wait before failing.
            go_interval: simulated seconds to advance per step.

        Raises:
            RuntimeError: if `cond()` is still false after `timeout` seconds.
        """
        remaining = timeout
        while True:
            self.simulator.go(go_interval)

            if cond():
                break

            remaining -= go_interval
            if remaining <= 0:
                # Report the caller's original timeout, not the (now <= 0)
                # decremented counter.
                raise RuntimeError(f'wait failed after {timeout} seconds')

    def wait_node_state(self, nodeid: int, state: str, timeout: int):
        """Wait until node `nodeid` reaches Thread `state` (e.g. 'router')."""
        self.wait_until(lambda: self.nodes[nodeid].get_state() == state, timeout)

    def wait_route_established(self, node1: int, node2: int, timeout=10):
        """Wait until `node1` can ping `node2`'s RLOC; raise on timeout."""
        node2_addr = self.nodes[node2].get_ip6_address(config.ADDRESS_TYPE.RLOC)

        while timeout > 0:

            if self.nodes[node1].ping(node2_addr):
                break

            self.simulator.go(1)
            timeout -= 1

        else:
            # The while-else runs only when the loop exhausts without `break`.
            raise Exception("Route between node %d and %d is not established" % (node1, node2))
535