1#!/usr/bin/env python3
2#
3#  Copyright (c) 2019, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29import logging
30import os
31import sys
32from typing import Callable, Union
33
34from pktverify.addrs import EthAddr, ExtAddr, Ipv6Addr
35from pktverify.bytes import Bytes
36from pktverify.null_field import nullField
37
38
def make_filter_func(func: Union[str, Callable], **vars) -> Callable:
    """
    Normalize a packet filter into a callable.

    A callable ``func`` is handed back unchanged; passing ``vars`` together
    with a callable fails with ``AssertionError``.

    A string ``func`` is an expression template: each ``{name}`` field is
    replaced with ``repr()`` of the matching entry in ``vars`` (via
    ``str.format_map``), the result is stripped, echoed to stderr and compiled
    in ``eval`` mode. The returned callable takes a packet ``p`` and evaluates
    that expression with these local names bound: ``p``, the packet layer
    attributes listed in ``layer_names`` below, ``Bytes``, ``ExtAddr``,
    ``Ipv6Addr``, ``EthAddr`` and ``null`` (``nullField``).

    :param func: The filter expression string, or an existing callable.
    :param vars: Values substituted into the string template.
    :return: The filter callable.
    """
    if not isinstance(func, str):
        assert not vars, 'can not provide vars for non-str filter: %r %r' % (func, vars)
        assert callable(func)
        return func

    substitutions = {key: repr(value) for key, value in vars.items()}
    expr = func.format_map(substitutions).strip()
    print("\t%s" % expr, file=sys.stderr)
    # The parentheses let a filter span several lines; the expression text is
    # also used as the pseudo "filename" reported in tracebacks.
    code = compile('(\n' + expr + '\n)', expr, "eval")

    # Packet attributes exposed to the expression, in the order they are read.
    layer_names = (
        'coap',
        'wpan',
        'mle',
        'ipv6',
        'lowpan',
        'eth',
        'icmpv6',
        'udp',
        'thread_bl',
        'thread_meshcop',
        'thread_nm',
        'thread_nwd',
        'thread_address',
        'thread_bcn',
        'dns',
    )

    def func(p):
        namespace = {name: getattr(p, name) for name in layer_names}
        namespace['p'] = p
        namespace['Bytes'] = Bytes
        namespace['ExtAddr'] = ExtAddr
        namespace['Ipv6Addr'] = Ipv6Addr
        namespace['EthAddr'] = EthAddr
        namespace['null'] = nullField
        return eval(code, None, namespace)

    return func
83
84
def _setup_wireshark_disabled_protos():
    """
    Make sure a fixed set of Wireshark dissectors is listed as disabled.

    The list is kept in ``<home>/.config/wireshark/disabled_protos``, one
    protocol name per line; the directory is created if it does not exist.
    Existing entries are preserved (blank lines dropped).

    The file is rewritten, sorted, only when at least one required protocol
    was missing from it; otherwise it is left untouched.
    """
    home = os.environ.get('HOME')
    if home is None:
        # HOME may be unset (e.g. minimal containers/CI); expanduser falls
        # back to the account database on POSIX instead of raising KeyError.
        home = os.path.expanduser('~')
    wireshark_config_dir = os.path.join(home, '.config', 'wireshark')
    os.makedirs(wireshark_config_dir, exist_ok=True)
    disabled_protos_path = os.path.join(wireshark_config_dir, 'disabled_protos')

    # Dissectors that must be disabled for pktverify's captures.
    required_protos = {'lwm', 'prp', 'stcsig', 'transum', 'zbee_nwk', 'zbee_nwk_gp'}

    # Read the current list; a missing file just means nothing is disabled yet.
    try:
        with open(disabled_protos_path, 'rt') as fd:
            stripped = (line.strip() for line in fd)
            disabled_protos = {name for name in stripped if name}
    except FileNotFoundError:
        disabled_protos = set()

    if required_protos <= disabled_protos:
        # Everything is already disabled; avoid rewriting the user's file.
        return

    # Sort once so the log line and the file contents are deterministic.
    content = sorted(disabled_protos | required_protos)
    logging.info(f"set disabled_protos = {' '.join(content)}")
    with open(disabled_protos_path, 'wt') as fd:
        fd.write('\n'.join(content))
        fd.write('\n')
110
111
def get_wireshark_dir() -> str:
    """
    Return the directory that holds the Wireshark command-line tools.

    Before returning, this also refreshes the user's Wireshark
    ``disabled_protos`` configuration via ``_setup_wireshark_disabled_protos``.

    :return: The path to wireshark directory.
    """
    _setup_wireshark_disabled_protos()
    return '/tmp/thread-wireshark'
119
120
def which_tshark() -> str:
    """
    Locate the ``tshark`` executable inside the Wireshark directory.

    :return: The path to `tshark` executable.
    """
    wireshark_dir = get_wireshark_dir()
    return os.path.join(wireshark_dir, 'tshark')
126
127
def which_dumpcap() -> str:
    """
    Locate the ``dumpcap`` executable inside the Wireshark directory.

    :return: The path to `dumpcap` executable.
    """
    wireshark_dir = get_wireshark_dir()
    return os.path.join(wireshark_dir, 'dumpcap')
133
134
def which_mergecap() -> str:
    """
    Locate the ``mergecap`` executable inside the Wireshark directory.

    :return: The path to `mergecap` executable.
    """
    wireshark_dir = get_wireshark_dir()
    return os.path.join(wireshark_dir, 'mergecap')
140
141
def colon_hex(hexstr, interval) -> str:
    """ Split ``hexstr`` into ``interval``-sized pieces joined by colons.

    e.g. ``colon_hex('0011aabb', 2) == '00:11:aa:bb'``.

    :param hexstr: The hex string to convert; its length must be a multiple of ``interval``.
    :param interval: Number of characters per colon-separated group.
    :return: The colon separated string.
    """
    length = len(hexstr)
    assert length % interval == 0
    pieces = (hexstr[start:start + interval] for start in range(0, length, interval))
    return ':'.join(pieces)
151
152
def is_sublist(lst1: list, lst2: list) -> bool:
    """ Test whether lst1 is a contiguous slice of lst2.

    Every start position in ``lst2`` at which ``lst1`` could fit is checked,
    comparing slices with ``==``. An empty ``lst1`` is a slice of any list.
    Runs in O(len(lst1) * len(lst2)) time.

    :param lst1: The list to judge if it is a sublist of lst2.
    :param lst2: The list to judge if contains lst1.
    :return: Whether lst1 is a slice of lst2.
    """
    width = len(lst1)
    # Valid start offsets are 0 .. len(lst2) - width; the range is empty
    # (-> False) when lst1 is longer than lst2.
    return any(lst2[start:start + width] == lst1 for start in range(len(lst2) - width + 1))
161