1#!/usr/bin/env python3
2#
3#  Copyright (c) 2022, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import argparse
31import ctypes
32import ctypes.util
33import ipaddress
34import json
35import logging
36import os
37import signal
38import socket
39import struct
40import subprocess
41import sys
42from typing import Iterable
43import yaml
44
45from otbr_sim import otbr_docker
46
# IPv6 multicast group (ff02::/16 is link-local scope) that the discovery socket
# joins in main(); queries are received on it.
GROUP = 'ff02::114'
# UDP port the discovery socket is bound to.
PORT = 12345
49
50
def if_nametoindex(ifname: str) -> int:
    """Return the index of the network interface named *ifname*.

    Args:
        ifname: Name of a network interface on this host.

    Returns:
        The positive interface index reported by the operating system.

    Raises:
        RuntimeError: If *ifname* does not name an existing interface.
    """
    try:
        # The standard library already wraps if_nametoindex(3); going through
        # ctypes required locating libc by hand and called the C function
        # without a declared signature.
        return socket.if_nametoindex(ifname)
    except (OSError, ValueError) as e:
        # OSError: no such interface. ValueError: name cannot be passed to the
        # OS (e.g. embedded NUL). Both keep the original RuntimeError contract.
        raise RuntimeError('Invalid interface name') from e
57
58
def get_ipaddr(ifname: str) -> str:
    """Return the first address that ``ip addr list dev <ifname>`` reports as global.

    A line of the command output is considered a match when it contains both
    ``inet`` and ``global``; the address is the second whitespace-separated
    field with its ``/prefix`` suffix removed.

    Args:
        ifname: Name of the network interface to inspect.

    Returns:
        The address from the first matching line, without the prefix length.

    Raises:
        RuntimeError: If no matching line is produced (including when the
            ``ip`` command cannot be run or the interface does not exist).
    """
    try:
        # Pass the interface name as a single argv element instead of splicing
        # it into a shell pipeline, so shell metacharacters in the configured
        # name cannot run other commands. stderr is left attached to ours.
        result = subprocess.run(['ip', 'addr', 'list', 'dev', ifname],
                                stdout=subprocess.PIPE,
                                universal_newlines=True)
        output = result.stdout
    except OSError:
        # `ip` is missing or not executable: same outcome as "no output".
        output = ''

    for line in output.splitlines():
        # Same substring filter the former `grep inet | grep global` applied.
        if 'inet' in line and 'global' in line:
            return line.split()[1].split('/')[0]

    raise RuntimeError(f'No IP address on dev {ifname}')
64
65
def init_socket(ifname: str, group: str, port: int) -> socket.socket:
    """Create a UDP socket bound to *ifname* that has joined the IPv6 multicast *group*.

    Args:
        ifname: Network interface the socket is bound to (``SO_BINDTODEVICE``)
            and on which the group is joined.
        group: IPv6 multicast group address (or a name resolving to one).
        port: UDP port to bind.

    Returns:
        The configured datagram socket. The caller owns it.

    Raises:
        ValueError: If *group* does not resolve to an IPv6 address.
        RuntimeError: If *ifname* is not a valid interface name.
        OSError: If a socket operation fails.
    """
    # Resolve the group address and find out its IP version
    addrinfo = socket.getaddrinfo(group, None)[0]
    family = addrinfo[0]
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if family != socket.AF_INET6:
        raise ValueError(f'Multicast group {group} is not an IPv6 address')

    s = socket.socket(family, socket.SOCK_DGRAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, (ifname + '\0').encode('ascii'))

        # Bind it to the port
        s.bind((group, port))

        # Join the group: the request is the 16-byte group address followed by
        # the interface index as a native unsigned int.
        group_bin = socket.inet_pton(family, addrinfo[4][0])
        interface_index = if_nametoindex(ifname)
        mreq = group_bin + struct.pack('@I', interface_index)
        s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
    except Exception:
        # Do not leak the descriptor when any configuration step fails.
        s.close()
        raise

    return s
85
86
def _advertise(s: socket.socket, dst, info):
    """Log *info*, then send it to *dst* through *s* as UTF-8 encoded JSON."""
    logging.info('Advertise: %r', info)
    payload = json.dumps(info).encode('utf-8')
    s.sendto(payload, dst)
90
91
def advertise_devices(s: socket.socket, dst, ven: str, add: str, nodeids: Iterable[int], tag: str):
    """Send one device record to *dst* for every node ID in *nodeids*.

    Each record carries the vendor *ven*, fixed ``mod``/``ver``/``por`` values
    and an ``add`` field of the form ``<tag>_<nodeid>@<add>``.
    """
    for nodeid in nodeids:
        # Keyword order is the order of the keys in the serialized JSON.
        _advertise(s, dst, dict(
            ven=ven,
            mod='OpenThread',
            ver='4',
            add=f'{tag}_{nodeid}@{add}',
            por=22,
        ))
102
103
def advertise_sniffers(s: socket.socket, dst, add: str, ports: Iterable[int]):
    """Send one sniffer record (address *add* plus a port) to *dst* per entry in *ports*."""
    for port in ports:
        _advertise(s, dst, {'add': add, 'por': port})
111
112
def start_sniffer(addr: str, port: int, ot_path: str, max_nodes_num: int) -> subprocess.Popen:
    """Spawn the sniffer simulation script and return its process handle.

    The script is ``tools/harness-simulation/posix/sniffer_sim/sniffer.py``
    under *ot_path*, run with ``python3``. It receives ``addr:port`` as its
    gRPC server endpoint (an IPv6 *addr* is wrapped in brackets) and
    *max_nodes_num* as ``--max-nodes-num``.

    Raises:
        ValueError: If *addr* is not a valid IPv4 or IPv6 address.
    """
    # IPv6 literals need brackets so the port separator stays unambiguous.
    host = f'[{addr}]' if ipaddress.ip_address(addr).version == 6 else addr
    script = os.path.join(ot_path, 'tools/harness-simulation/posix/sniffer_sim/sniffer.py')

    cmd = ['python3', script]
    cmd += ['--grpc-server', f'{host}:{port}']
    cmd += ['--max-nodes-num', str(max_nodes_num)]

    logging.info('Executing command:  %s', ' '.join(cmd))
    return subprocess.Popen(cmd)
129
130
def main():
    """Start the simulated testbed and answer discovery queries until signalled.

    Steps:
      1. Load the YAML configuration named by ``-c/--config`` and validate the
         device and sniffer counts.
      2. Start the sniffer simulation processes and the OTBR Docker containers.
      3. Join the discovery multicast group and reply to ``b'BBR'`` queries
         with device records and to ``b'Sniffer'`` queries with sniffer records.

    On SIGINT/SIGTERM, or if start-up fails, every sniffer process is
    terminated and every OTBR container is closed.

    Raises:
        ValueError: If a configured device or sniffer count is out of range.
    """
    logging.basicConfig(level=logging.INFO)

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config',
                        dest='config',
                        type=str,
                        required=True,
                        help='the path of the configuration YAML file')
    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        config = yaml.safe_load(f)

    ot_path = config['ot_path']
    ot_build = config['ot_build']
    max_nodes_num = ot_build['max_number']
    # No test case requires more than 2 sniffers
    MAX_SNIFFER_NUM = 2

    # (tag, count) pairs for each configured OT and OTBR build
    ot_devices = [(item['tag'], item['number']) for item in ot_build['ot']]
    otbr_devices = [(item['tag'], item['number']) for item in ot_build['otbr']]
    ot_nodes_num = sum(x[1] for x in ot_devices)
    otbr_nodes_num = sum(x[1] for x in otbr_devices)
    nodes_num = ot_nodes_num + otbr_nodes_num
    sniffer_num = config['sniffer']['number']

    # Check validation of numbers
    if not all(0 <= x[1] <= max_nodes_num for x in ot_devices):
        raise ValueError(f'The number of devices of each OT version should be between 0 and {max_nodes_num}')

    if not all(0 <= x[1] <= max_nodes_num for x in otbr_devices):
        raise ValueError(f'The number of devices of each OTBR version should be between 0 and {max_nodes_num}')

    if not 1 <= nodes_num <= max_nodes_num:
        raise ValueError(f'The number of devices should be between 1 and {max_nodes_num}')

    if not 1 <= sniffer_num <= MAX_SNIFFER_NUM:
        raise ValueError(f'The number of sniffers should be between 1 and {MAX_SNIFFER_NUM}')

    # Get the local IP address on the specified interface
    ifname = config['discovery_ifname']
    addr = get_ipaddr(ifname)

    sniffer_server_port_base = config['sniffer']['server_port_base']
    sniffer_procs = []
    otbr_dockers = []

    # Stop everything started so far and return the exit status to use
    def shutdown() -> int:
        # The status is the largest sniffer return code, never below 0
        ret = 0
        for sniffer_proc in sniffer_procs:
            sniffer_proc.terminate()
            ret = max(ret, sniffer_proc.wait())

        for otbr in otbr_dockers:
            otbr.close()

        return ret

    # Terminate all sniffer simulation server processes and then exit
    def exit_handler(signum, context):
        sys.exit(shutdown())

    # Install the handlers before anything is started, so a signal that
    # arrives during start-up still cleans up the processes already running.
    signal.signal(signal.SIGINT, exit_handler)
    signal.signal(signal.SIGTERM, exit_handler)

    try:
        # Start the sniffer
        for i in range(sniffer_num):
            sniffer_procs.append(start_sniffer(addr, i + sniffer_server_port_base, ot_path, max_nodes_num))

        # OTBR firewall scripts create rules inside the Docker container
        # Run modprobe to load the kernel modules for iptables
        # (the return code is not checked)
        subprocess.run(['sudo', 'modprobe', 'ip6table_filter'])

        # Start the BRs; their node IDs continue after the OT device IDs,
        # matching the numbering used when advertising below
        nodeid = ot_nodes_num
        for item in ot_build['otbr']:
            tag = item['tag']
            ot_rcp_path = os.path.join(ot_path, item['rcp_subpath'], 'examples/apps/ncp/ot-rcp')
            docker_image = item['docker_image']
            for _ in range(item['number']):
                nodeid += 1
                otbr_dockers.append(
                    otbr_docker.OtbrDocker(nodeid=nodeid,
                                           ot_path=ot_path,
                                           ot_rcp_path=ot_rcp_path,
                                           docker_image=docker_image,
                                           docker_name=f'{tag}_{nodeid}'))

        s = init_socket(ifname, GROUP, PORT)
    except Exception:
        # Start-up failed part-way: do not leave sniffers or containers behind
        shutdown()
        raise

    logging.info('Advertising on interface %s group %s ...', ifname, GROUP)

    # Loop, answering any query we receive
    while True:
        data, src = s.recvfrom(64)

        if data == b'BBR':
            logging.info('Received OpenThread simulation query, advertising')

            # Node IDs are assigned consecutively from 1: OT devices first, then BRs
            nodeid = 1
            for ven, devices in [('OpenThread_Sim', ot_devices), ('OpenThread_BR_Sim', otbr_devices)]:
                for tag, number in devices:
                    advertise_devices(s, src, ven=ven, add=addr, nodeids=range(nodeid, nodeid + number), tag=tag)
                    nodeid += number

        elif data == b'Sniffer':
            logging.info('Received sniffer simulation query, advertising')
            advertise_sniffers(s,
                               src,
                               add=addr,
                               ports=range(sniffer_server_port_base, sniffer_server_port_base + sniffer_num))

        else:
            logging.warning('Received %r, but ignored', data)
243
244
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
247