1#!/usr/bin/env python3
2#
3#  Copyright (c) 2019, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import binascii
31import json
32import logging
33import os
34import signal
35import stat
36import subprocess
37import sys
38import time
39import traceback
40import unittest
41from typing import Optional, Callable, Union, Mapping, Any
42
43import config
44import debug
45from node import Node, OtbrNode, HostNode
46from pktverify import utils as pvutils
47
# Packet verification level requested through the environment (0 disables it).
PACKET_VERIFICATION = int(os.getenv('PACKET_VERIFICATION', 0))

if PACKET_VERIFICATION:
    # These helpers are only imported when packet verification is enabled.
    from pktverify.addrs import ExtAddr, EthAddr
    from pktverify.packet_verifier import PacketVerifier

# Offset read from the environment; used to namespace the per-run node files in tmp/.
PORT_OFFSET = int(os.getenv('PORT_OFFSET', "0"))

# Thread version used for nodes that do not specify one explicitly.
ENV_THREAD_VERSION = os.getenv('THREAD_VERSION', '1.1')

DEFAULT_PARAMS = dict(
    is_mtd=False,
    is_ftd=False,
    is_bbr=False,
    is_otbr=False,
    is_host=False,
    mode='rdn',
    allowlist=None,
    version=ENV_THREAD_VERSION,
    panid=0xface,
)
"""Default configurations when creating nodes."""

# Extra defaults layered on top of DEFAULT_PARAMS for full Thread devices.
FTD_DEFAULT_PARAMS = dict(
    is_ftd=True,
    router_selection_jitter=config.DEFAULT_ROUTER_SELECTION_JITTER,
)

EXTENDED_ADDRESS_BASE = 0x166e0a0000000000
"""Extended address base to keep U/L bit 1. The value is borrowed from Thread Test Harness."""
78
79
class NcpSupportMixin:
    """Mixin that skips a test case when running on NCP nodes without NCP support.

    A test case opts out of NCP runs by setting the class attribute
    `SUPPORT_NCP = False`. NCP runs are detected through the `NODE_TYPE`
    environment variable being `ncp-sim`.
    """

    SUPPORT_NCP = True

    def __init__(self, *args, **kwargs):
        node_type = os.getenv('NODE_TYPE', 'sim')
        unsupported_ncp_run = node_type == 'ncp-sim' and not self.SUPPORT_NCP
        if unsupported_ncp_run:
            # Exit status 77 means "skip this test case" in automake tests.
            sys.exit(77)

        super().__init__(*args, **kwargs)
92
93
class TestCase(NcpSupportMixin, unittest.TestCase):
    """The base class for all thread certification test cases.

    The `TOPOLOGY` member of a sub-class is used to create the test topology: it maps
    a node id to a dict of node parameters (or None), which `_parse_params` merges over
    `DEFAULT_PARAMS` (and `FTD_DEFAULT_PARAMS` for full Thread devices).
    """

    USE_MESSAGE_FACTORY = True
    TOPOLOGY = None
    CASE_WIRESHARK_PREFS = None
    SUPPORT_THREAD_1_1 = True
    PACKET_VERIFICATION = config.PACKET_VERIFICATION_DEFAULT

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

        self._start_time = None
        # Verify packets only if it is enabled via the environment, the case implements
        # `verify()`, and the case's PACKET_VERIFICATION matches the environment's value.
        self._do_packet_verification = PACKET_VERIFICATION and hasattr(self, 'verify') \
                                       and self.PACKET_VERIFICATION == PACKET_VERIFICATION

    def skipTest(self, reason: Any) -> None:
        """Mark the test as skipped (so tearDown skips packet verification), then skip it."""
        self._testSkipped = True
        super().skipTest(reason)

    def setUp(self):
        """Skip unsupported configurations, then build the topology via `_setUp`.

        unittest does not call `tearDown` when `setUp` raises, so on failure everything
        acquired so far is released here before the original exception is re-raised.
        """
        self._testSkipped = False

        if ENV_THREAD_VERSION == '1.1' and not self.SUPPORT_THREAD_1_1:
            self.skipTest('Thread 1.1 not supported.')

        try:
            self._setUp()
        except BaseException:
            traceback.print_exc()
            self._cleanup_after_failed_setup()
            raise

    def _cleanup_after_failed_setup(self):
        """Best-effort release of the sniffer, nodes and simulator created by a partial `_setUp`."""
        # `_setUp` may fail before any of these attributes exist; never mask the original error.
        dumpcap_proc = getattr(self, '_dumpcap_proc', None)
        if dumpcap_proc is not None and dumpcap_proc.poll() is None:
            try:
                self._stop_backbone_sniffer()
            except Exception:
                traceback.print_exc()

        for node in list(getattr(self, 'nodes', {}).values()):
            try:
                node.destroy()
            except Exception:
                traceback.print_exc()

        simulator = getattr(self, 'simulator', None)
        if simulator is not None:
            try:
                simulator.stop()
            except Exception:
                traceback.print_exc()

    def _setUp(self):
        """Create simulator, nodes and apply configurations.

        Removes stale node files, creates the simulator, prepares the Docker backbone
        network and sniffer when the topology contains OTBR or host nodes, then creates
        and configures every node in `TOPOLOGY` and applies the allowlists.
        """
        self._clean_up_tmp()

        self.simulator = config.create_default_simulator(use_message_factory=self.USE_MESSAGE_FACTORY)
        self.nodes = {}

        # NOTE(review): presumably points node tooling at a custom Wireshark build — confirm.
        os.environ['LD_LIBRARY_PATH'] = '/tmp/thread-wireshark'

        if self._has_backbone_traffic():
            self._prepare_backbone_network()
            self._start_backbone_sniffer()

        self._initial_topology = initial_topology = {}

        for i, params in self.TOPOLOGY.items():
            params = self._parse_params(params)
            initial_topology[i] = params

            logging.info("Creating node %d: %r", i, params)

            if params['is_otbr']:
                nodeclass = OtbrNode
            elif params['is_host']:
                nodeclass = HostNode
            else:
                nodeclass = Node

            node = nodeclass(
                i,
                is_mtd=params['is_mtd'],
                simulator=self.simulator,
                name=params.get('name'),
                version=params['version'],
                is_bbr=params['is_bbr'],
            )
            if 'boot_delay' in params:
                self.simulator.go(params['boot_delay'])

            self.nodes[i] = node

            # Host nodes have no Thread configuration to apply.
            if not node.is_host:
                self._configure_node_params(node, params)

        # Allowlists reference other nodes' extended addresses, so they can only be
        # applied after every node has been created.
        self._apply_allowlists(initial_topology)

        self._inspector = debug.Inspector(self)
        self._collect_test_info_after_setup()

    def _configure_node_params(self, node, params: dict):
        """Apply the Thread settings in the effective `params` to a non-host `node`.

        Optional settings are applied only when their key is present in `params`.
        """
        default_network_key = binascii.hexlify(config.DEFAULT_NETWORK_KEY).decode()

        node.set_networkkey(default_network_key)
        node.set_panid(params['panid'])
        node.set_mode(params['mode'])

        if 'extended_panid' in params:
            node.set_extpanid(params['extended_panid'])
        if 'partition_id' in params:
            node.set_preferred_partition_id(params['partition_id'])
        if 'channel' in params:
            node.set_channel(params['channel'])
        if 'networkkey' in params:
            node.set_networkkey(params['networkkey'])
        if 'network_name' in params:
            node.set_network_name(params['network_name'])

        if params['is_ftd']:
            node.set_router_selection_jitter(params['router_selection_jitter'])

        if 'router_upgrade_threshold' in params:
            node.set_router_upgrade_threshold(params['router_upgrade_threshold'])
        if 'router_downgrade_threshold' in params:
            node.set_router_downgrade_threshold(params['router_downgrade_threshold'])
        if 'router_eligible' in params:
            node.set_router_eligible(params['router_eligible'])
        if 'prefer_router_id' in params:
            node.prefer_router_id(params['prefer_router_id'])

        if 'timeout' in params:
            node.set_timeout(params['timeout'])

        if 'active_dataset' in params:
            dataset = params['active_dataset']
            if 'network_key' not in dataset:
                dataset['network_key'] = default_network_key
            node.set_active_dataset(dataset['timestamp'],
                                    panid=dataset.get('panid'),
                                    channel=dataset.get('channel'),
                                    channel_mask=dataset.get('channel_mask'),
                                    network_key=dataset.get('network_key'),
                                    security_policy=dataset.get('security_policy'))

        if 'pending_dataset' in params:
            dataset = params['pending_dataset']
            node.set_pending_dataset(dataset['pendingtimestamp'],
                                     dataset['activetimestamp'],
                                     panid=dataset.get('panid'),
                                     channel=dataset.get('channel'),
                                     delay=dataset.get('delay'))

        if 'key_switch_guardtime' in params:
            node.set_key_switch_guardtime(params['key_switch_guardtime'])
        if 'key_sequence_counter' in params:
            node.set_key_sequence_counter(params['key_sequence_counter'])

        if 'network_id_timeout' in params:
            node.set_network_id_timeout(params['network_id_timeout'])

        if 'context_reuse_delay' in params:
            node.set_context_reuse_delay(params['context_reuse_delay'])

        if 'max_children' in params:
            node.set_max_children(params['max_children'])

        if 'bbr_registration_jitter' in params:
            node.set_bbr_registration_jitter(params['bbr_registration_jitter'])

        if 'router_id_range' in params:
            node.set_router_id_range(params['router_id_range'][0], params['router_id_range'][1])

    def _apply_allowlists(self, topology: dict):
        """Add each node's allowlist entries and enable its allowlist.

        Each entry is either a peer node id or a `(node_id, rssi)` tuple; nodes whose
        `allowlist` is None are left untouched.
        """
        for i, params in topology.items():
            allowlist = params['allowlist']
            if allowlist is None:
                continue

            for peer in allowlist:
                rssi = None
                if isinstance(peer, tuple):
                    peer, rssi = peer
                self.nodes[i].add_allowlist(self.nodes[peer].get_addr64(), rssi=rssi)
            self.nodes[i].enable_allowlist()

    def inspect(self):
        """Run the interactive debug inspector created during setUp."""
        self._inspector.inspect()

    def tearDown(self):
        """Destroy nodes and simulator, then verify packets if enabled.

        Nodes, the backbone sniffer and the simulator are always released before any
        error is reported.

        Raises:
            NotImplementedError: packet verification was requested on a non-Linux host.
            Exception: the first error raised while destroying a node.
        """
        # Short-circuit keeps os.uname() (POSIX only) from running when not verifying.
        unsupported_platform = self._do_packet_verification and os.uname().sysname != "Linux"

        if self._do_packet_verification and not unsupported_platform:
            self.simulator.go(3)

        if self._has_backbone_traffic():
            # Stop Backbone sniffer before stopping nodes so that we don't capture Codecov Uploading traffic
            self._stop_backbone_sniffer()

        destroy_error = self._stop_and_destroy_nodes()
        self.simulator.stop()

        if unsupported_platform:
            raise NotImplementedError(
                f'{self.test_name}: Packet Verification not available on {os.uname().sysname} (Linux only).')
        if destroy_error is not None:
            raise destroy_error

        if self._do_packet_verification:

            if self._has_backbone_traffic():
                self._remove_backbone_network()
                pcap_filename = self._merge_thread_backbone_pcaps()
            else:
                pcap_filename = self._get_thread_pcap_filename()

            self._test_info['pcap'] = pcap_filename

            test_info_path = self._output_test_info()
            if not self._testSkipped:
                self._verify_packets(test_info_path)

    def _stop_and_destroy_nodes(self) -> Optional[BaseException]:
        """Stop then destroy every node, continuing past failures.

        Failures of `node.stop()` are printed; `node.destroy()` always runs afterwards.

        Returns:
            The first exception raised by `node.destroy()`, or None if all succeeded.
        """
        destroy_error = None
        for node in list(self.nodes.values()):
            try:
                node.stop()
            except Exception:
                traceback.print_exc()
            finally:
                try:
                    node.destroy()
                except Exception as error:
                    traceback.print_exc()
                    if destroy_error is None:
                        destroy_error = error
        return destroy_error

    def flush_all(self):
        """Flush away all captured messages of all nodes.
        """
        for i in self.nodes:
            self.simulator.get_messages_sent_by(i)

    def flush_nodes(self, nodes):
        """Flush away all captured messages of specified nodes.

        Args:
            nodes (list): ids of the nodes whose messages to flush; unknown ids are ignored.

        """
        for i in nodes:
            if i in self.nodes:
                self.simulator.get_messages_sent_by(i)

    def _clean_up_tmp(self):
        """
        Clean up node files (flash/data/swap) for this PORT_OFFSET in the tmp directory
        """
        os.system(f"rm -f tmp/{PORT_OFFSET}_*.flash tmp/{PORT_OFFSET}_*.data tmp/{PORT_OFFSET}_*.swap")

    def _verify_packets(self, test_info_path: str):
        """Run the case's `verify()` against the capture described by `test_info_path`."""
        pv = PacketVerifier(test_info_path, self.CASE_WIRESHARK_PREFS)
        pv.add_common_vars()
        pv.pkts.filter_thread_unallowed_icmpv6().must_not_next()
        self.verify(pv)
        print("Packet verification passed: %s" % test_info_path, file=sys.stderr)

    @property
    def test_name(self):
        """Test name from the TEST_NAME environment variable ('current' if unset)."""
        return os.getenv('TEST_NAME', 'current')

    def collect_ipaddrs(self):
        """Record every node's IPv6 (plus Ethernet-side, if any) addresses and ML-EIDs."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info

        for i, node in self.nodes.items():
            ipaddrs = node.get_addrs()

            if hasattr(node, 'get_ether_addrs'):
                ipaddrs += node.get_ether_addrs()

            test_info['ipaddrs'][i] = ipaddrs
            if not node.is_host:
                mleid = node.get_mleid()
                test_info['mleids'][i] = mleid

    def _collect_per_thread_node(self, key: str, getter: Callable[[Any], Any]):
        """Store `getter(node)` for every non-host node under `self._test_info[key]`."""
        if not self._do_packet_verification:
            return

        values = self._test_info[key] = {}
        for i, node in self.nodes.items():
            if not node.is_host:
                values[i] = getter(node)

    def collect_rloc16s(self):
        """Record each Thread node's RLOC16 as a '0x%04x' string."""
        self._collect_per_thread_node('rloc16s', lambda node: '0x%04x' % node.get_addr16())

    def collect_rlocs(self):
        """Record each Thread node's RLOC address."""
        self._collect_per_thread_node('rlocs', lambda node: node.get_rloc())

    def collect_omrs(self):
        """Record each Thread node's OMR address."""
        self._collect_per_thread_node('omrs', lambda node: node.get_ip6_address(config.ADDRESS_TYPE.OMR))

    def collect_duas(self):
        """Record each Thread node's DUA address."""
        self._collect_per_thread_node('duas', lambda node: node.get_ip6_address(config.ADDRESS_TYPE.DUA))

    def collect_leader_aloc(self, node):
        """Record the leader ALOC as reported by node id `node`."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['leader_aloc'] = self.nodes[node].get_addr_leader_aloc()

    def collect_extra_vars(self, **extra_vars):
        """Merge keyword arguments into the 'extra_vars' section of the test info."""
        if not self._do_packet_verification:
            return

        for k in extra_vars:
            assert isinstance(k, str), k

        test_vars = self._test_info.setdefault("extra_vars", {})
        test_vars.update(extra_vars)

    def _collect_test_info_after_setup(self):
        """
        Collect test info after setUp: script, topology, backbone settings and node addresses
        """
        if not self._do_packet_verification:
            return

        test_info = self._test_info = {
            'script': os.path.abspath(sys.argv[0]),
            'testcase': self.test_name,
            'start_time': time.ctime(self._start_time),
            'pcap': '',
            'extaddrs': {},
            'ethaddrs': {},
            'ipaddrs': {},
            'mleids': {},
            'topology': self._initial_topology,
            'backbone': {
                'interface': config.BACKBONE_DOCKER_NETWORK_NAME,
                'prefix': config.BACKBONE_PREFIX,
            },
            'domain_prefix': config.DOMAIN_PREFIX,
            'env': {
                'PORT_OFFSET': config.PORT_OFFSET,
            },
        }

        for i, node in self.nodes.items():
            if not node.is_host:
                extaddr = node.get_addr64()
                test_info['extaddrs'][i] = ExtAddr(extaddr).format_octets()

            if node.is_host or node.is_otbr:
                ethaddr = node.get_ether_mac()
                test_info['ethaddrs'][i] = EthAddr(ethaddr).format_octets()

    def _output_test_info(self):
        """
        Output test info to '<test_name>.json' after tearDown and return the file name
        """
        filename = f'{self.test_name}.json'
        with open(filename, 'wt') as ofd:
            json.dump(self._test_info, ofd, indent=1, sort_keys=True)

        return filename

    def _get_thread_pcap_filename(self):
        """Absolute path of the Thread capture '<test_name>.pcap'."""
        current_pcap = self.test_name + '.pcap'
        return os.path.abspath(current_pcap)

    def assure_run_ok(self, cmd, shell=False):
        """Run `cmd`, log its exit status to stderr, and raise CalledProcessError on failure.

        A string `cmd` is whitespace-split unless `shell` is True.
        """
        if not shell and isinstance(cmd, str):
            cmd = cmd.split()
        proc = subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr, shell=shell)
        print(">>> %s => %d" % (cmd, proc.returncode), file=sys.stderr)
        proc.check_returncode()

    def _parse_params(self, params: Optional[dict]) -> dict:
        """Return the effective parameters for one topology entry.

        Starts from DEFAULT_PARAMS, adds FTD_DEFAULT_PARAMS for nodes that are neither
        MTDs nor hosts, then overlays `params`. Note that `params` itself is updated in
        place (version normalisation and the BBR registration jitter default).
        """
        params = params or {}

        if params.get('is_bbr') or params.get('is_otbr'):
            # BBRs must not use thread version 1.1
            version = params.get('version', '1.3')
            assert version != '1.1', params
            params['version'] = version
            params.setdefault('bbr_registration_jitter', config.DEFAULT_BBR_REGISTRATION_JITTER)
        elif params.get('is_host'):
            # Hosts must not specify thread version
            assert params.get('version', '') == '', params
            params['version'] = ''

        # use 1.3 node for 1.2 tests
        if params.get('version') == '1.2':
            params['version'] = '1.3'

        is_ftd = (not params.get('is_mtd') and not params.get('is_host'))

        effective_params = DEFAULT_PARAMS.copy()

        if is_ftd:
            effective_params.update(FTD_DEFAULT_PARAMS)

        effective_params.update(params)

        return effective_params

    def _has_backbone_traffic(self):
        """True if any topology entry is an OTBR or a host node."""
        return any(param and (param.get('is_otbr') or param.get('is_host')) for param in self.TOPOLOGY.values())

    def _prepare_backbone_network(self):
        """Create the Docker IPv6 bridge used as backbone (an existing one is tolerated)."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(
            f'docker network create --driver bridge --ipv6 --subnet {config.BACKBONE_PREFIX} -o "com.docker.network.bridge.name"="{network_name}" {network_name} || true',
            shell=True)

    def _remove_backbone_network(self):
        """Remove the Docker backbone network."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(f'docker network rm {network_name}', shell=True)

    def _start_backbone_sniffer(self):
        """Start dumpcap on the backbone interface, writing to the backbone pcap file."""
        pcap_file = self._get_backbone_pcap_filename()
        try:
            os.remove(pcap_file)
        except FileNotFoundError:
            pass

        dumpcap = pvutils.which_dumpcap()
        self._dumpcap_proc = subprocess.Popen([dumpcap, '-i', config.BACKBONE_DOCKER_NETWORK_NAME, '-w', pcap_file],
                                              stdout=sys.stdout,
                                              stderr=sys.stderr)
        time.sleep(0.2)
        assert self._dumpcap_proc.poll() is None, 'dumpcap terminated unexpectedly'
        logging.info('Backbone sniffer launched successfully: pid=%s', self._dumpcap_proc.pid)
        os.chmod(pcap_file, stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)

    def _get_backbone_pcap_filename(self):
        """Absolute path of the backbone capture '<test_name>_backbone.pcap'."""
        backbone_pcap = self.test_name + '_backbone.pcap'
        return os.path.abspath(backbone_pcap)

    def _get_merged_pcap_filename(self):
        """Absolute path of the merged capture '<test_name>_merged.pcap'."""
        merged_pcap = self.test_name + '_merged.pcap'
        return os.path.abspath(merged_pcap)

    def _stop_backbone_sniffer(self):
        """Send SIGTERM to the backbone dumpcap process and wait for it to exit."""
        self._dumpcap_proc.send_signal(signal.SIGTERM)
        self._dumpcap_proc.wait()
        logging.info('Backbone sniffer terminated successfully: pid=%s', self._dumpcap_proc.pid)

    def _merge_thread_backbone_pcaps(self):
        """Merge the Thread and backbone captures with mergecap; return the merged file path."""
        thread_pcap = self._get_thread_pcap_filename()
        backbone_pcap = self._get_backbone_pcap_filename()
        merged_pcap = self._get_merged_pcap_filename()

        mergecap = pvutils.which_mergecap()
        # Pass an argument list so paths containing spaces or shell metacharacters are safe.
        self.assure_run_ok([mergecap, '-w', merged_pcap, thread_pcap, backbone_pcap])
        return merged_pcap

    def wait_until(self, cond: Callable[[], bool], timeout: int, go_interval: int = 1):
        """Advance the simulation in `go_interval` steps until `cond()` is true.

        Raises:
            RuntimeError: `cond()` stayed false for `timeout` simulated seconds.
        """
        remaining = timeout
        while True:
            self.simulator.go(go_interval)

            if cond():
                break

            remaining -= go_interval
            if remaining <= 0:
                raise RuntimeError(f'wait failed after {timeout} seconds')

    def wait_node_state(self, node: Union[int, Node], state: str, timeout: int):
        """Wait until `node` (an id or a Node) reports `state`, failing after `timeout`."""
        node = self.nodes[node] if isinstance(node, int) else node
        self.wait_until(lambda: node.get_state() == state, timeout)

    def wait_route_established(self, node1: int, node2: int, timeout=10):
        """Ping node2's RLOC from node1 once per simulated second until it succeeds.

        Raises:
            Exception: no ping succeeded within `timeout` attempts.
        """
        node2_addr = self.nodes[node2].get_ip6_address(config.ADDRESS_TYPE.RLOC)

        while timeout > 0:

            if self.nodes[node1].ping(node2_addr):
                break

            self.simulator.go(1)
            timeout -= 1

        else:
            raise Exception("Route between node %d and %d is not established" % (node1, node2))

    def assertDictIncludes(self, actual: Mapping[str, str], expected: Mapping[str, str]):
        """ Asserts the `actual` dict includes the `expected` dict.

        Args:
            actual: A dict for checking.
            expected: The expected items that the actual dict should contain.

        Raises:
            AssertionError: a key is missing from `actual` or maps to a different value.
        """
        for k, v in expected.items():
            if k not in actual:
                raise AssertionError(f"key {k} is not found in first dict")
            if v != actual[k]:
                raise AssertionError(f"{repr(actual[k])} != {repr(v)} for key {k}")
608