1#!/usr/bin/env python3
2#
3#  Copyright (c) 2019, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS'
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import binascii
31import json
32import logging
33import os
34import signal
35import stat
36import subprocess
37import sys
38import time
39import traceback
40import unittest
41from typing import Optional, Callable, Union, Mapping, Any
42
43import config
44import debug
45from node import Node, OtbrNode, HostNode
46from pktverify import utils as pvutils
47
# Non-zero when the PACKET_VERIFICATION environment variable enables packet verification
# (defaults to 0, i.e. disabled). TestCase compares this against its class attribute of the
# same name to decide whether to verify packets.
PACKET_VERIFICATION = int(os.getenv('PACKET_VERIFICATION', 0))

# The pktverify helpers are only imported when packet verification is enabled; code that
# uses ExtAddr/EthAddr/PacketVerifier runs only on the packet-verification path.
if PACKET_VERIFICATION:
    from pktverify.addrs import ExtAddr, EthAddr
    from pktverify.packet_verifier import PacketVerifier

# Read from the environment (default 0). TestCase._clean_up_tmp uses it as the file-name
# prefix of the tmp/ node files it removes.
PORT_OFFSET = int(os.getenv('PORT_OFFSET', "0"))

# Thread version from the THREAD_VERSION environment variable (default '1.1'). Used as the
# default node 'version' and to skip test cases that do not support Thread 1.1.
ENV_THREAD_VERSION = os.getenv('THREAD_VERSION', '1.1')

# Baseline node parameters; TestCase._parse_params overlays each TOPOLOGY entry on a copy.
DEFAULT_PARAMS = {
    'is_mtd': False,
    'is_ftd': False,
    'is_bbr': False,
    'is_otbr': False,
    'is_host': False,
    'mode': 'rdn',
    'allowlist': None,
    'version': ENV_THREAD_VERSION,
}
"""Default configurations when creating nodes."""

# Extra defaults applied (before the TOPOLOGY entry) to nodes that are neither MTDs nor hosts.
FTD_DEFAULT_PARAMS = {
    'is_ftd': True,
    'router_selection_jitter': config.DEFAULT_ROUTER_SELECTION_JITTER,
}

EXTENDED_ADDRESS_BASE = 0x166e0a0000000000
"""Extended address base to keep U/L bit 1. The value is borrowed from Thread Test Harness."""
77
78
class NcpSupportMixin:
    """Mixin that lets a test case opt out of NCP simulation runs.

    A test class that cannot run against NCP nodes sets ``SUPPORT_NCP = False``.
    When such a class is instantiated while the ``NODE_TYPE`` environment variable
    is ``'ncp-sim'``, the process exits with status 77 instead of running the test.
    All constructor arguments are forwarded unchanged to the next class in the MRO.
    """

    SUPPORT_NCP = True

    def __init__(self, *args, **kwargs):
        node_type = os.getenv('NODE_TYPE', 'sim')
        if node_type == 'ncp-sim' and not self.SUPPORT_NCP:
            # Exit status 77 marks this test case as skipped for the automake test driver.
            sys.exit(77)

        super().__init__(*args, **kwargs)
91
92
class TestCase(NcpSupportMixin, unittest.TestCase):
    """The base class for all thread certification test cases.

    Subclasses describe their network in the ``TOPOLOGY`` class attribute: a mapping
    from node id to a dict of node parameters (or ``None`` for all defaults). ``setUp``
    creates one node per entry and applies its parameters; ``tearDown`` stops every node
    and, when packet verification is enabled, runs the subclass's ``verify(pv)`` method
    against the captured traffic.
    """

    USE_MESSAGE_FACTORY = True
    TOPOLOGY = None
    CASE_WIRESHARK_PREFS = None
    SUPPORT_THREAD_1_1 = True
    PACKET_VERIFICATION = config.PACKET_VERIFICATION_DEFAULT

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

        self._start_time = None
        # Verify packets only when enabled by the environment, the subclass defines verify(),
        # and the subclass's PACKET_VERIFICATION value matches the environment's value.
        self._do_packet_verification = PACKET_VERIFICATION and hasattr(self, 'verify') \
                                       and self.PACKET_VERIFICATION == PACKET_VERIFICATION

    def skipTest(self, reason: Any) -> None:
        """Record that the test was skipped (so tearDown skips verification), then skip it."""
        self._testSkipped = True
        super().skipTest(reason)

    def setUp(self):
        """Skip unsupported cases, then build the topology via ``_setUp``.

        If topology setup fails, every node created so far is destroyed before the
        original exception is re-raised, since unittest does not call ``tearDown``
        when ``setUp`` raises.
        """
        self._testSkipped = False

        if ENV_THREAD_VERSION == '1.1' and not self.SUPPORT_THREAD_1_1:
            self.skipTest('Thread 1.1 not supported.')

        try:
            self._setUp()
        except BaseException:
            traceback.print_exc()
            # _setUp may fail before self.nodes is assigned; never let that mask the real error.
            for node in list(getattr(self, 'nodes', {}).values()):
                try:
                    node.destroy()
                except Exception:
                    traceback.print_exc()

            raise

    def _setUp(self):
        """Create simulator, nodes and apply configurations.

        Nodes are created in ``TOPOLOGY`` order. OTBR and host nodes trigger creation of
        the backbone docker network and sniffer. Allowlists are applied only after all
        nodes exist, because each entry needs the peer node's extended address.
        """
        self._clean_up_tmp()

        self.simulator = config.create_default_simulator(use_message_factory=self.USE_MESSAGE_FACTORY)
        self.nodes = {}

        os.environ['LD_LIBRARY_PATH'] = '/tmp/thread-wireshark'

        if self._has_backbone_traffic():
            self._prepare_backbone_network()
            self._start_backbone_sniffer()

        self._initial_topology = initial_topology = {}

        for i, params in self.TOPOLOGY.items():
            params = self._parse_params(params)
            initial_topology[i] = params

            logging.info("Creating node %d: %r", i, params)

            if params['is_otbr']:
                nodeclass = OtbrNode
            elif params['is_host']:
                nodeclass = HostNode
            else:
                nodeclass = Node

            node = nodeclass(
                i,
                is_mtd=params['is_mtd'],
                simulator=self.simulator,
                name=params.get('name'),
                version=params['version'],
                is_bbr=params['is_bbr'],
            )
            if 'boot_delay' in params:
                self.simulator.go(params['boot_delay'])

            self.nodes[i] = node

            # Hosts are not Thread devices; none of the settings below apply to them.
            if node.is_host:
                continue

            node.set_mode(params['mode'])

            if 'partition_id' in params:
                node.set_preferred_partition_id(params['partition_id'])

            if params['is_ftd']:
                node.set_router_selection_jitter(params['router_selection_jitter'])

            if 'router_upgrade_threshold' in params:
                node.set_router_upgrade_threshold(params['router_upgrade_threshold'])
            if 'router_downgrade_threshold' in params:
                node.set_router_downgrade_threshold(params['router_downgrade_threshold'])
            if 'router_eligible' in params:
                node.set_router_eligible(params['router_eligible'])
            if 'prefer_router_id' in params:
                node.prefer_router_id(params['prefer_router_id'])

            if 'timeout' in params:
                node.set_timeout(params['timeout'])

            self._set_up_active_dataset(node, params)

            if 'pending_dataset' in params:
                pending = params['pending_dataset']
                node.set_pending_dataset(pending['pendingtimestamp'],
                                         pending['activetimestamp'],
                                         panid=pending.get('panid'),
                                         channel=pending.get('channel'),
                                         delay=pending.get('delay'))

            if 'key_switch_guardtime' in params:
                node.set_key_switch_guardtime(params['key_switch_guardtime'])
            if 'key_sequence_counter' in params:
                node.set_key_sequence_counter(params['key_sequence_counter'])

            if 'network_id_timeout' in params:
                node.set_network_id_timeout(params['network_id_timeout'])

            if 'context_reuse_delay' in params:
                node.set_context_reuse_delay(params['context_reuse_delay'])

            if 'max_children' in params:
                node.set_max_children(params['max_children'])

            if 'bbr_registration_jitter' in params:
                node.set_bbr_registration_jitter(params['bbr_registration_jitter'])

            if 'router_id_range' in params:
                node.set_router_id_range(params['router_id_range'][0], params['router_id_range'][1])

        # we have to add allowlist after nodes are all created
        for i, params in initial_topology.items():
            allowlist = params['allowlist']
            if allowlist is None:
                continue

            for j in allowlist:
                # An entry is either a node id or a (node id, rssi) tuple.
                rssi = None
                if isinstance(j, tuple):
                    j, rssi = j
                self.nodes[i].add_allowlist(self.nodes[j].get_addr64(), rssi=rssi)
            self.nodes[i].enable_allowlist()

        self._inspector = debug.Inspector(self)
        self._collect_test_info_after_setup()

    def _set_up_active_dataset(self, node, params):
        """Configure ``node``'s active dataset from config defaults plus per-node overrides.

        The shorthand keys 'channel', 'networkkey', 'network_name' and 'panid' override the
        defaults first; an explicit 'active_dataset' dict is applied last and wins.
        """
        dataset = {
            'timestamp': 1,
            'channel': config.CHANNEL,
            'channel_mask': config.CHANNEL_MASK,
            'extended_panid': config.EXTENDED_PANID,
            'mesh_local_prefix': config.MESH_LOCAL_PREFIX.split('/')[0],
            'network_key': binascii.hexlify(config.DEFAULT_NETWORK_KEY).decode(),
            'network_name': config.NETWORK_NAME,
            'panid': config.PANID,
            'pskc': config.PSKC,
            'security_policy': config.SECURITY_POLICY,
        }

        if 'channel' in params:
            dataset['channel'] = params['channel']
        if 'networkkey' in params:
            dataset['network_key'] = params['networkkey']
        if 'network_name' in params:
            dataset['network_name'] = params['network_name']
        if 'panid' in params:
            dataset['panid'] = params['panid']

        if 'active_dataset' in params:
            dataset.update(params['active_dataset'])

        node.set_active_dataset(**dataset)

    def inspect(self):
        """Run the interactive debug inspector created during setup."""
        self._inspector.inspect()

    def tearDown(self):
        """Destroy nodes and simulator, then run packet verification if enabled.

        All cleanup (backbone sniffer, nodes, simulator, backbone docker network) runs
        before the unsupported-platform error is raised, so a failing teardown does not
        leak node processes.

        Raises:
            NotImplementedError: packet verification was requested on a non-Linux host.
        """
        verify = self._do_packet_verification
        verification_supported = verify and os.uname().sysname == "Linux"
        has_backbone = self._has_backbone_traffic()

        if verification_supported:
            # Advance the simulation a little more before the captures are closed.
            self.simulator.go(3)

        if has_backbone:
            # Stop Backbone sniffer before stopping nodes so that we don't capture Codecov Uploading traffic
            self._stop_backbone_sniffer()

        for node in list(self.nodes.values()):
            try:
                node.stop()
            except Exception:
                traceback.print_exc()
            finally:
                node.destroy()

        self.simulator.stop()

        if has_backbone:
            # _setUp always creates the backbone network when there is backbone traffic,
            # so remove it regardless of whether packets are verified.
            self._remove_backbone_network()

        if not verify:
            return

        if not verification_supported:
            raise NotImplementedError(
                f'{self.test_name}: Packet Verification not available on {os.uname().sysname} (Linux only).')

        if has_backbone:
            pcap_filename = self._merge_thread_backbone_pcaps()
        else:
            pcap_filename = self._get_thread_pcap_filename()

        self._test_info['pcap'] = pcap_filename

        test_info_path = self._output_test_info()
        if not self._testSkipped:
            self._verify_packets(test_info_path)

    def flush_all(self):
        """Flush away all captured messages of all nodes.
        """
        for i in list(self.nodes):
            self.simulator.get_messages_sent_by(i)

    def flush_nodes(self, nodes):
        """Flush away all captured messages of specified nodes.

        Args:
            nodes (list): nodes whose messages to flush. Ids not in the topology are ignored.

        """
        for i in nodes:
            if i in self.nodes:
                self.simulator.get_messages_sent_by(i)

    def _clean_up_tmp(self):
        """
        Clean up node files in tmp directory
        """
        # Only this run's files (prefixed with PORT_OFFSET) are removed; rm -f ignores missing files.
        os.system(f"rm -f tmp/{PORT_OFFSET}_*.flash tmp/{PORT_OFFSET}_*.data tmp/{PORT_OFFSET}_*.swap")

    def _verify_packets(self, test_info_path: str):
        """Run the common checks and the subclass's ``verify(pv)`` on the recorded capture."""
        pv = PacketVerifier(test_info_path, self.CASE_WIRESHARK_PREFS)
        pv.add_common_vars()
        pv.pkts.filter_thread_unallowed_icmpv6().must_not_next()
        self.verify(pv)
        print("Packet verification passed: %s" % test_info_path, file=sys.stderr)

    @property
    def test_name(self):
        """Test name from the TEST_NAME environment variable ('current' if unset)."""
        return os.getenv('TEST_NAME', 'current')

    def collect_ipaddrs(self):
        """Record every node's IPv6 addresses (and ML-EID for non-hosts). No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info

        for i, node in self.nodes.items():
            ipaddrs = node.get_addrs()

            if hasattr(node, 'get_ether_addrs'):
                ipaddrs += node.get_ether_addrs()

            test_info['ipaddrs'][i] = ipaddrs
            if not node.is_host:
                mleid = node.get_mleid()
                test_info['mleids'][i] = mleid

    def collect_rloc16s(self):
        """Record each non-host node's RLOC16 as a '0x%04x' string. No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['rloc16s'] = {}

        for i, node in self.nodes.items():
            if not node.is_host:
                test_info['rloc16s'][i] = '0x%04x' % node.get_addr16()

    def collect_rlocs(self):
        """Record each non-host node's RLOC address. No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['rlocs'] = {}

        for i, node in self.nodes.items():
            if node.is_host:
                continue

            test_info['rlocs'][i] = node.get_rloc()

    def collect_omrs(self):
        """Record each non-host node's OMR address. No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['omrs'] = {}

        for i, node in self.nodes.items():
            if node.is_host:
                continue

            test_info['omrs'][i] = node.get_ip6_address(config.ADDRESS_TYPE.OMR)

    def collect_duas(self):
        """Record each non-host node's DUA address. No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['duas'] = {}

        for i, node in self.nodes.items():
            if node.is_host:
                continue

            test_info['duas'][i] = node.get_ip6_address(config.ADDRESS_TYPE.DUA)

    def collect_leader_aloc(self, node):
        """Record the leader ALOC as seen by node id ``node``. No-op without verification."""
        if not self._do_packet_verification:
            return

        test_info = self._test_info
        test_info['leader_aloc'] = self.nodes[node].get_addr_leader_aloc()

    def collect_extra_vars(self, **extra_vars):
        """Merge keyword arguments into the test info's 'extra_vars'. No-op without verification."""
        if not self._do_packet_verification:
            return

        for k in extra_vars:
            assert isinstance(k, str), k

        test_vars = self._test_info.setdefault("extra_vars", {})
        test_vars.update(extra_vars)

    def _collect_test_info_after_setup(self):
        """
        Collect test info after setUp
        """
        if not self._do_packet_verification:
            return

        test_info = self._test_info = {
            'script': os.path.abspath(sys.argv[0]),
            'testcase': self.test_name,
            # _start_time is None here, so time.ctime() reports the current time.
            'start_time': time.ctime(self._start_time),
            'pcap': '',
            'extaddrs': {},
            'ethaddrs': {},
            'ipaddrs': {},
            'mleids': {},
            'topology': self._initial_topology,
            'backbone': {
                'interface': config.BACKBONE_DOCKER_NETWORK_NAME,
                'prefix': config.BACKBONE_PREFIX,
            },
            'domain_prefix': config.DOMAIN_PREFIX,
            'env': {
                'PORT_OFFSET': config.PORT_OFFSET,
            },
        }

        for i, node in self.nodes.items():
            if not node.is_host:
                extaddr = node.get_addr64()
                test_info['extaddrs'][i] = ExtAddr(extaddr).format_octets()

            if node.is_host or node.is_otbr:
                ethaddr = node.get_ether_mac()
                test_info['ethaddrs'][i] = EthAddr(ethaddr).format_octets()

    def _output_test_info(self):
        """
        Output test info to json file after tearDown; returns the file name.
        """
        filename = f'{self.test_name}.json'
        with open(filename, 'wt') as ofd:
            ofd.write(json.dumps(self._test_info, indent=1, sort_keys=True))

        return filename

    def _get_thread_pcap_filename(self):
        """Absolute path of the Thread capture, ``<test_name>.pcap`` in the working directory."""
        current_pcap = self.test_name + '.pcap'
        return os.path.abspath(current_pcap)

    def assure_run_ok(self, cmd, shell=False):
        """Run ``cmd``, echo its exit status to stderr, and raise CalledProcessError on failure.

        Args:
            cmd: An argv list, or a string. Without ``shell`` a string is split on whitespace.
            shell: Run ``cmd`` through the shell.
        """
        if not shell and isinstance(cmd, str):
            cmd = cmd.split()
        proc = subprocess.run(cmd, stdout=sys.stdout, stderr=sys.stderr, shell=shell)
        print(">>> %s => %d" % (cmd, proc.returncode), file=sys.stderr)
        proc.check_returncode()

    def _parse_params(self, params: Optional[dict]) -> dict:
        """Return the effective parameters of one TOPOLOGY entry.

        The result is DEFAULT_PARAMS, overlaid with FTD_DEFAULT_PARAMS for FTDs, overlaid
        with a normalized copy of ``params``. ``params`` itself is not modified, since it
        is typically an entry of the class-level (shared) TOPOLOGY.
        """
        params = dict(params or {})

        if params.get('is_bbr') or params.get('is_otbr'):
            # BBRs must not use thread version 1.1
            version = params.get('version', '1.3')
            assert version != '1.1', params
            params['version'] = version
            params.setdefault('bbr_registration_jitter', config.DEFAULT_BBR_REGISTRATION_JITTER)
        elif params.get('is_host'):
            # Hosts must not specify thread version
            assert params.get('version', '') == '', params
            params['version'] = ''

        # use 1.3 node for 1.2 tests
        if params.get('version') == '1.2':
            params['version'] = '1.3'

        is_ftd = (not params.get('is_mtd') and not params.get('is_host'))

        effective_params = DEFAULT_PARAMS.copy()

        if is_ftd:
            effective_params.update(FTD_DEFAULT_PARAMS)

        effective_params.update(params)

        return effective_params

    def _has_backbone_traffic(self):
        """Return True if any TOPOLOGY entry is an OTBR or a host node."""
        return any(param and (param.get('is_otbr') or param.get('is_host')) for param in self.TOPOLOGY.values())

    def _prepare_backbone_network(self):
        """Create the backbone docker bridge network; an already-existing network is tolerated."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(
            f'docker network create --driver bridge --ipv6 --subnet {config.BACKBONE_PREFIX} -o "com.docker.network.bridge.name"="{network_name}" {network_name} || true',
            shell=True)

    def _remove_backbone_network(self):
        """Remove the backbone docker network; raises CalledProcessError if removal fails."""
        network_name = config.BACKBONE_DOCKER_NETWORK_NAME
        self.assure_run_ok(f'docker network rm {network_name}', shell=True)

    def _start_backbone_sniffer(self):
        """Start dumpcap on the backbone interface, writing to a fresh backbone pcap file."""
        pcap_file = self._get_backbone_pcap_filename()
        try:
            os.remove(pcap_file)
        except FileNotFoundError:
            pass

        dumpcap = pvutils.which_dumpcap()
        self._dumpcap_proc = subprocess.Popen([dumpcap, '-i', config.BACKBONE_DOCKER_NETWORK_NAME, '-w', pcap_file],
                                              stdout=sys.stdout,
                                              stderr=sys.stderr)
        # Give dumpcap a moment to start (and fail fast) before checking it is still alive.
        time.sleep(0.2)
        assert self._dumpcap_proc.poll() is None, 'dumpcap terminated unexpectedly'
        logging.info('Backbone sniffer launched successfully: pid=%s', self._dumpcap_proc.pid)
        os.chmod(pcap_file, stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)

    def _get_backbone_pcap_filename(self):
        """Absolute path of the backbone capture, ``<test_name>_backbone.pcap``."""
        backbone_pcap = self.test_name + '_backbone.pcap'
        return os.path.abspath(backbone_pcap)

    def _get_merged_pcap_filename(self):
        """Absolute path of the merged Thread+backbone capture, ``<test_name>_merged.pcap``."""
        merged_pcap = self.test_name + '_merged.pcap'
        return os.path.abspath(merged_pcap)

    def _stop_backbone_sniffer(self):
        """Send SIGTERM to the backbone dumpcap process and wait until it has exited."""
        self._dumpcap_proc.send_signal(signal.SIGTERM)
        # Reap the process; Popen.__exit__ did only this since no pipes were opened.
        self._dumpcap_proc.wait()
        logging.info('Backbone sniffer terminated successfully: pid=%s', self._dumpcap_proc.pid)

    def _merge_thread_backbone_pcaps(self):
        """Merge the Thread and backbone captures with mergecap; return the merged file path."""
        thread_pcap = self._get_thread_pcap_filename()
        backbone_pcap = self._get_backbone_pcap_filename()
        merged_pcap = self._get_merged_pcap_filename()

        mergecap = pvutils.which_mergecap()
        # argv list (no shell) so paths containing spaces or shell metacharacters work.
        self.assure_run_ok([mergecap, '-w', merged_pcap, thread_pcap, backbone_pcap])
        return merged_pcap

    def wait_until(self, cond: Callable[[], bool], timeout: int, go_interval: int = 1):
        """Advance the simulator in ``go_interval`` steps until ``cond()`` is true.

        ``cond`` is checked after each step. Raises RuntimeError once ``timeout`` seconds
        of simulated time have elapsed without ``cond()`` becoming true.
        """
        remaining = timeout
        while True:
            self.simulator.go(go_interval)

            if cond():
                return

            remaining -= go_interval
            if remaining <= 0:
                raise RuntimeError(f'wait failed after {timeout} seconds')

    def wait_node_state(self, node: Union[int, Node], state: str, timeout: int):
        """Wait until ``node`` (id or Node) reports ``state``; raises RuntimeError on timeout."""
        node = self.nodes[node] if isinstance(node, int) else node
        self.wait_until(lambda: node.get_state() == state, timeout)

    def wait_route_established(self, node1: int, node2: int, timeout=10):
        """Ping node2's RLOC from node1 once per simulated second until a ping succeeds.

        Raises:
            Exception: no ping succeeded within ``timeout`` attempts.
        """
        node2_addr = self.nodes[node2].get_ip6_address(config.ADDRESS_TYPE.RLOC)

        while timeout > 0:

            if self.nodes[node1].ping(node2_addr):
                break

            self.simulator.go(1)
            timeout -= 1

        else:
            raise Exception("Route between node %d and %d is not established" % (node1, node2))

    def assertDictIncludes(self, actual: Mapping[str, str], expected: Mapping[str, str]):
        """ Asserts the `actual` dict includes the `expected` dict.

        Args:
            actual: A dict for checking.
            expected: The expected items that the actual dict should contain.

        Raises:
            AssertionError: a key of `expected` is missing from `actual` or maps to a different value.
        """
        for k, v in expected.items():
            if k not in actual:
                raise AssertionError(f"key {k} is not found in first dict")
            if v != actual[k]:
                raise AssertionError(f"{repr(actual[k])} != {repr(v)} for key {k}")
617