1#!/usr/bin/env python3
2#
3#  Copyright (c) 2018, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import binascii
31import bisect
32import os
33import socket
34import struct
35import traceback
36import time
37
38import io
39import config
40import mesh_cop
41import message
42import pcap
43import wpan
44
45
def dbg_print(*args):
    """Debug-trace hook used throughout the simulator.

    Tracing is switched off: the guard below is the constant ``False``, so a
    call does nothing and returns ``None``. Set it to ``True`` locally to have
    the argument tuple printed.
    """
    debug_enabled = False
    if debug_enabled:
        print(args)
49
50
class BaseSimulator(object):
    """Common bookkeeping shared by the concrete simulators.

    Keeps the registered nodes (keyed by ``node.nodeid``) and, per node, the
    list of parsed commissioning messages read from that node's certification
    commissioning log. Subclasses implement the time-control and
    message-capture methods that raise ``NotImplementedError`` here.
    """

    def __init__(self):
        self._nodes = {}
        # nodeid -> list of parsed commissioning messages
        self.commissioning_messages = {}
        self._payload_parse_factory = mesh_cop.MeshCopCommandFactory(mesh_cop.create_default_mesh_cop_tlv_factories())
        self._mesh_cop_msg_set = mesh_cop.create_mesh_cop_message_type_set()

    def __del__(self):
        self._nodes = None

    def add_node(self, node):
        """Register ``node`` under ``node.nodeid`` with an empty commissioning message list."""
        self._nodes[node.nodeid] = node
        self.commissioning_messages[node.nodeid] = []

    def set_lowpan_context(self, cid, prefix):
        """Configure a 6LoWPAN context; must be implemented by subclasses."""
        raise NotImplementedError

    def get_messages_sent_by(self, nodeid):
        """Return the messages captured for ``nodeid``; must be implemented by subclasses."""
        raise NotImplementedError

    def go(self, duration, nodeid=None):
        """Advance the simulation by ``duration``; must be implemented by subclasses."""
        raise NotImplementedError

    def stop(self):
        """Stop the simulator; must be implemented by subclasses."""
        raise NotImplementedError

    def read_cert_messages_in_commissioning_log(self, nodeids):
        """Parse commissioning-log entries of the given nodes and store them.

        Each entry yielded by ``node.read_cert_messages_in_commissioning_log()``
        is unpacked as ``(direction, type, payload)``. ``type`` (bytes) is
        decoded as UTF-8 and passed, together with ``payload`` wrapped in a
        ``BytesIO``, to the MeshCoP command factory. Each parsed message is
        appended to ``self.commissioning_messages[nodeid]``.

        Args:
            nodeids: iterable of node ids. Ids that were never registered via
                add_node(), or whose node is falsy, are skipped.
        """
        for nodeid in nodeids:
            # .get(): the falsy check below is meant to skip unknown nodes, but
            # plain indexing raised KeyError before the check could run.
            node = self._nodes.get(nodeid)
            if not node:
                continue
            for _direction, msg_type, payload in node.read_cert_messages_in_commissioning_log():
                msg = self._payload_parse_factory.parse(msg_type.decode("utf-8"), io.BytesIO(payload))
                self.commissioning_messages[nodeid].append(msg)
90
91
class RealTime(BaseSimulator):
    """Wall-clock simulator whose traffic is captured by a Thread sniffer.

    The sniffer comes from ``config.create_default_thread_sniffer`` and is
    started on construction. Advancing time simply sleeps for real.
    """

    def __init__(self, use_message_factory=True):
        super().__init__()
        self._sniffer = config.create_default_thread_sniffer(use_message_factory=use_message_factory)
        self._sniffer.start()

    def set_lowpan_context(self, cid, prefix):
        """Forward the 6LoWPAN context ``cid``/``prefix`` to the sniffer."""
        self._sniffer.set_lowpan_context(cid, prefix)

    def get_messages_sent_by(self, nodeid):
        """Return a MessagesSet of sniffed plus commissioning messages for ``nodeid``.

        The node's commissioning message list is reset to empty afterwards.
        """
        result = message.MessagesSet(
            self._sniffer.get_messages_sent_by(nodeid).messages,
            self.commissioning_messages[nodeid],
        )
        self.commissioning_messages[nodeid] = []
        return result

    def now(self):
        """Return the current wall-clock time in seconds (``time.time()``)."""
        return time.time()

    def go(self, duration, nodeid=None):
        """Sleep for ``duration`` seconds; ``nodeid`` is accepted but unused."""
        time.sleep(duration)

    def stop(self):
        """Drop the sniffer reference so that ``is_running`` becomes False."""
        if not self.is_running:
            return
        # FIXME: self._sniffer.stop() is intentionally not called -- it seemed
        # to block forever.
        self._sniffer = None

    @property
    def is_running(self):
        """True while a sniffer reference is held."""
        return self._sniffer is not None
122
123
class VirtualTime(BaseSimulator):
    """Virtual-time simulator that exchanges events with devices over localhost UDP.

    Every datagram to or from a device starts with an 11-byte header packed as
    ``struct`` format ``'=QBH'`` (uint64 delay/elapsed time in microseconds,
    uint8 event type, uint16 payload length), followed by the payload.
    Pending events are stored as tuples
    ``(time, sequence, addr, type, datalen[, data])`` in ``event_queue``, kept
    sorted with ``bisect.insort``. They are delivered one at a time by
    process_next_event(). Virtual time (``current_time``) is in microseconds.
    """

    # Event type codes carried in the datagram header.
    OT_SIM_EVENT_ALARM_FIRED = 0
    OT_SIM_EVENT_RADIO_RECEIVED = 1
    OT_SIM_EVENT_UART_WRITE = 2
    OT_SIM_EVENT_RADIO_SPINEL_WRITE = 3
    OT_SIM_EVENT_POSTCMD = 4

    # Indices into an event tuple: (time, sequence, addr, type, datalen[, data]).
    EVENT_TIME = 0
    EVENT_SEQUENCE = 1
    EVENT_ADDR = 2
    EVENT_TYPE = 3
    EVENT_DATA_LENGTH = 4
    EVENT_DATA = 5

    BASE_PORT = 9000
    MAX_NODES = 33
    MAX_MESSAGE = 1024  # maximum datagram size read by recvfrom()
    END_OF_TIME = float('inf')
    # Shifts this simulator's port range by PORT_OFFSET * (MAX_NODES + 1)
    # (presumably so that parallel runs do not collide).
    PORT_OFFSET = int(os.getenv('PORT_OFFSET', '0'))

    # Seconds to block in recvfrom() while a reply is still expected.
    BLOCK_TIMEOUT = 10

    # True when NODE_TYPE=ncp-sim; go() then waits for the node's POSTCMD.
    NCP_SIM = os.getenv('NODE_TYPE', 'sim') == 'ncp-sim'

    _message_factory = None

    def __init__(self, use_message_factory=True):
        super(VirtualTime, self).__init__()
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 2 * 1024 * 1024)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2 * 1024 * 1024)

        ip = '127.0.0.1'
        self.port = self.BASE_PORT + (self.PORT_OFFSET * (self.MAX_NODES + 1))
        self.sock.bind((ip, self.port))

        # addr -> {'alarm': pending alarm event or None,
        #          'msgs': parsed sniffed messages,
        #          'time': virtual time (us) last delivered to the device}
        self.devices = {}
        self.event_queue = []
        # there could be events scheduled at exactly the same time; the
        # sequence number keeps their relative order stable
        self.event_sequence = 0
        self.current_time = 0
        self.current_event = None
        self.awake_devices = set()
        # ACK routing: seq_no -> devices awaiting it, and device -> seq_no
        self._nodes_by_ack_seq = {}
        self._node_ack_seq = {}

        self._pcap = pcap.PcapCodec(os.getenv('TEST_NAME', 'current'))
        # the addr for spinel-cli sending OT_SIM_EVENT_POSTCMD
        self._spinel_cli_addr = (ip, self.BASE_PORT + self.port)
        self.current_nodeid = None
        self._pause_time = 0

        if use_message_factory:
            self._message_factory = config.create_default_thread_message_factory()
        else:
            self._message_factory = None

    def __del__(self):
        # getattr(): __init__ may have failed before self.sock was assigned.
        if getattr(self, 'sock', None):
            self.stop()

    def stop(self):
        """Close the UDP socket (idempotent)."""
        if self.sock:
            self.sock.close()
            self.sock = None

    @property
    def is_running(self):
        """True while the UDP socket is open."""
        return self.sock is not None

    def _add_message(self, nodeid, message_obj):
        """Parse a sniffed radio frame and append the results to the node's message list.

        Errors are deliberately not propagated: DropPacketException and any
        other exception are printed to the console and the frame is ignored.
        """
        addr = ('127.0.0.1', self.port + nodeid)

        try:
            if self._message_factory is not None:
                messages = self._message_factory.create(io.BytesIO(message_obj))
                self.devices[addr]['msgs'] += messages

        except message.DropPacketException:
            print('Drop current packet because it cannot be handled in test scripts')
        except Exception as e:
            # Best effort: just print the exception to the console
            print("EXCEPTION: %s" % e)
            traceback.print_exc()

    def set_lowpan_context(self, cid, prefix):
        """Forward a 6LoWPAN context to the message factory, if one is in use."""
        if self._message_factory is not None:
            self._message_factory.set_lowpan_context(cid, prefix)

    def get_messages_sent_by(self, nodeid):
        """ Get sniffed messages.

        Note! This method flushes the message queue so calling this
        method again will return only the newly logged messages.

        Args:
            nodeid (int): node id

        Returns:
            MessagesSet: a set with received messages.
        """
        addr = ('127.0.0.1', self.port + nodeid)

        messages = self.devices[addr]['msgs']
        self.devices[addr]['msgs'] = []

        ret = message.MessagesSet(messages, self.commissioning_messages[nodeid])
        self.commissioning_messages[nodeid] = []
        return ret

    def _is_radio(self, addr):
        """True if ``addr``'s port is below 2 * BASE_PORT (a radio-side address).

        NOTE(review): holds for modest PORT_OFFSET values; a very large offset
        would push radio ports past this bound -- confirm if that is ever used.
        """
        return addr[1] < self.BASE_PORT * 2

    def _to_core_addr(self, addr):
        """Map a radio address to its paired core address (port + BASE_PORT)."""
        assert self._is_radio(addr)
        return (addr[0], addr[1] + self.BASE_PORT)

    def _to_radio_addr(self, addr):
        """Map a core address back to its paired radio address (port - BASE_PORT)."""
        assert not self._is_radio(addr)
        return (addr[0], addr[1] - self.BASE_PORT)

    def _core_addr_from(self, nodeid):
        """Return the address a command for ``nodeid`` wakes (offset by BASE_PORT for POSIX nodes)."""
        if self._nodes[nodeid].is_posix:
            return ('127.0.0.1', self.BASE_PORT + self.port + nodeid)
        else:
            return ('127.0.0.1', self.port + nodeid)

    def _next_event_time(self):
        """Time (us) of the earliest queued event, or END_OF_TIME if none."""
        if not self.event_queue:
            return self.END_OF_TIME
        return self.event_queue[0][self.EVENT_TIME]

    def receive_events(self):
        """Receive events until all devices are asleep.

        While an event is in flight, some device is awake, or a node set by
        go() has not yet sent its POSTCMD, this blocks up to BLOCK_TIMEOUT
        seconds per datagram. Otherwise it drains only the datagrams that are
        already queued and returns once none are left.

        Raises:
            socket.error: if a blocking receive fails or times out; the
                simulator state is printed first to aid debugging.
        """
        while True:
            if (self.current_event or len(self.awake_devices) or
                (self._next_event_time() > self._pause_time and self.current_nodeid)):
                self.sock.settimeout(self.BLOCK_TIMEOUT)
                try:
                    msg, addr = self.sock.recvfrom(self.MAX_MESSAGE)
                except socket.error:
                    # print debug information on failure
                    print('Current nodeid:')
                    print(self.current_nodeid)
                    print('Current awake:')
                    print(self.awake_devices)
                    print('Current time:')
                    print(self.current_time)
                    print('Current event:')
                    print(self.current_event)
                    print('Events:')
                    for event in self.event_queue:
                        print(event)
                    raise
            else:
                self.sock.settimeout(0)
                try:
                    msg, addr = self.sock.recvfrom(self.MAX_MESSAGE)
                except socket.error:
                    # Nothing queued and nobody is expected to reply.
                    break

            if addr != self._spinel_cli_addr and addr not in self.devices:
                self.devices[addr] = {}
                self.devices[addr]['alarm'] = None
                self.devices[addr]['msgs'] = []
                self.devices[addr]['time'] = self.current_time
                self.awake_devices.discard(addr)

            delay, event_type, datalen = struct.unpack('=QBH', msg[:11])
            data = msg[11:]

            event_time = self.current_time + delay

            if data:
                dbg_print(
                    "New event: ",
                    event_time,
                    addr,
                    event_type,
                    datalen,
                    binascii.hexlify(data),
                )
            else:
                dbg_print("New event: ", event_time, addr, event_type, datalen)

            if event_type == self.OT_SIM_EVENT_ALARM_FIRED:
                # A device keeps at most one pending alarm: replace any previous one.
                if self.devices[addr]['alarm']:
                    self.event_queue.remove(self.devices[addr]['alarm'])

                event = (event_time, self.event_sequence, addr, event_type, datalen)
                self.event_sequence += 1
                bisect.insort(self.event_queue, event)
                self.devices[addr]['alarm'] = event

                # Reporting its next alarm means the device went back to sleep.
                self.awake_devices.discard(addr)

                if (self.current_event and self.current_event[self.EVENT_ADDR] == addr):
                    self.current_event = None

            elif event_type == self.OT_SIM_EVENT_RADIO_RECEIVED:
                assert self._is_radio(addr)
                frame_info = wpan.dissect(data)

                # ACKs go only to the devices awaiting that sequence number;
                # everything else (or an ACK nobody awaits) goes to all devices.
                recv_devices = None
                if frame_info.frame_type == wpan.FrameType.ACK:
                    recv_devices = self._nodes_by_ack_seq.get(frame_info.seq_no)

                recv_devices = recv_devices or self.devices.keys()

                for device in recv_devices:
                    if device != addr and self._is_radio(device):
                        event = (
                            event_time,
                            self.event_sequence,
                            device,
                            event_type,
                            datalen,
                            data,
                        )
                        self.event_sequence += 1
                        bisect.insort(self.event_queue, event)

                self._pcap.append(data, (event_time // 1000000, event_time % 1000000))
                self._add_message(addr[1] - self.port, data)

                # add radio transmit done events to event queue
                event = (
                    event_time,
                    self.event_sequence,
                    addr,
                    event_type,
                    datalen,
                    data,
                )
                self.event_sequence += 1
                bisect.insort(self.event_queue, event)

                if frame_info.frame_type != wpan.FrameType.ACK and not frame_info.is_broadcast:
                    self._on_ack_seq_change(addr, frame_info.seq_no)

                self.awake_devices.add(addr)

            elif event_type == self.OT_SIM_EVENT_RADIO_SPINEL_WRITE:
                # Core -> radio: forwarded to the paired radio address as UART_WRITE.
                assert not self._is_radio(addr)
                radio_addr = self._to_radio_addr(addr)
                if radio_addr not in self.devices:
                    self.awake_devices.add(radio_addr)

                event = (
                    event_time,
                    self.event_sequence,
                    radio_addr,
                    self.OT_SIM_EVENT_UART_WRITE,
                    datalen,
                    data,
                )
                self.event_sequence += 1
                bisect.insort(self.event_queue, event)

                self.awake_devices.add(addr)

            elif event_type == self.OT_SIM_EVENT_UART_WRITE:
                # Radio -> core: forwarded to the paired core address as RADIO_SPINEL_WRITE.
                assert self._is_radio(addr)
                core_addr = self._to_core_addr(addr)
                if core_addr not in self.devices:
                    self.awake_devices.add(core_addr)

                event = (
                    event_time,
                    self.event_sequence,
                    core_addr,
                    self.OT_SIM_EVENT_RADIO_SPINEL_WRITE,
                    datalen,
                    data,
                )
                self.event_sequence += 1
                bisect.insort(self.event_queue, event)

                self.awake_devices.add(addr)

            elif event_type == self.OT_SIM_EVENT_POSTCMD:
                # Node reports it finished the command issued before go().
                assert self.current_time == self._pause_time
                nodeid = struct.unpack('=B', data)[0]
                if self.current_nodeid == nodeid:
                    self.current_nodeid = None

    def _on_ack_seq_change(self, device: tuple, seq_no: int):
        """Record that ``device`` last sent a unicast frame with sequence number ``seq_no``.

        Moves ``device`` out of its previous sequence-number bucket (deleting
        the bucket once empty) and into the bucket for ``seq_no``. The buckets
        are used by receive_events() to route ACK frames.
        """
        old_seq = self._node_ack_seq.pop(device, None)
        if old_seq is not None:
            waiting = self._nodes_by_ack_seq[old_seq]
            waiting.remove(device)
            if not waiting:
                # An empty bucket and a missing one are treated alike by the ACK
                # lookup, so drop it instead of keeping it around.
                del self._nodes_by_ack_seq[old_seq]

        self._node_ack_seq[device] = seq_no
        self._nodes_by_ack_seq.setdefault(seq_no, set()).add(device)

    def _send_message(self, message, addr):
        """Send ``message`` to ``addr``, printing the traceback and retrying on socket errors."""
        while True:
            try:
                sent = self.sock.sendto(message, addr)
            except socket.error:
                traceback.print_exc()
                time.sleep(0)
            else:
                break
        assert sent == len(message)

    def process_next_event(self):
        """Pop the earliest queued event and deliver it to its target device.

        Advances ``current_time`` to the event's time. The header sent to the
        device carries the time elapsed since that device last received an
        event. Must only be called when no event is in flight and the queue
        is non-empty (asserted).
        """
        assert self.current_event is None
        assert self._next_event_time() < self.END_OF_TIME

        event = self.event_queue.pop(0)

        if len(event) == 5:
            event_time, _sequence, addr, event_type, datalen = event
            data = b''
            dbg_print("Pop event: ", event_time, addr, event_type, datalen)
        else:
            event_time, _sequence, addr, event_type, datalen, data = event
            dbg_print(
                "Pop event: ",
                event_time,
                addr,
                event_type,
                datalen,
                binascii.hexlify(data),
            )

        self.current_event = event

        assert event_time >= self.current_time
        self.current_time = event_time

        elapsed = event_time - self.devices[addr]['time']
        self.devices[addr]['time'] = event_time

        packet = struct.pack('=QBH', elapsed, event_type, datalen)

        if event_type == self.OT_SIM_EVENT_ALARM_FIRED:
            self.devices[addr]['alarm'] = None
            self._send_message(packet, addr)
        elif event_type in (self.OT_SIM_EVENT_RADIO_RECEIVED,
                            self.OT_SIM_EVENT_RADIO_SPINEL_WRITE,
                            self.OT_SIM_EVENT_UART_WRITE):
            self._send_message(packet + data, addr)

    def sync_devices(self):
        """Bring every known device's clock up to the pause time.

        Each device that is behind receives an ALARM_FIRED header carrying the
        elapsed time, and its resulting datagrams are processed before moving
        on to the next device.
        """
        self.current_time = self._pause_time
        # Iterate over a snapshot: receive_events() may register new devices,
        # and mutating self.devices while iterating it raises RuntimeError.
        # Devices added that way start at current_time, so skipping them is safe.
        for addr in list(self.devices):
            elapsed = self.current_time - self.devices[addr]['time']
            if elapsed == 0:
                continue
            dbg_print('syncing', addr, elapsed)
            self.devices[addr]['time'] = self.current_time
            packet = struct.pack('=QBH', elapsed, self.OT_SIM_EVENT_ALARM_FIRED, 0)
            self._send_message(packet, addr)
            self.awake_devices.add(addr)
            self.receive_events()
        self.awake_devices.clear()

    def now(self):
        """Return the current virtual time in seconds."""
        return self.current_time / 1000000

    def go(self, duration, nodeid=None):
        """Run the simulation for ``duration`` seconds of virtual time.

        Args:
            duration (float): seconds to advance, rounded to whole microseconds.
            nodeid (int): if given, that node's core address is marked awake
                before events are received. In NCP simulation, go() also keeps
                waiting until the node sends POSTCMD.
        """
        assert self.current_time == self._pause_time
        # round(), not int(): e.g. 4.1 * 1000000 == 4099999.9999999995, which
        # int() would truncate to 4099999 us.
        duration = int(round(duration * 1000000))
        dbg_print('running for %d us' % duration)
        self._pause_time += duration
        if nodeid:
            if self.NCP_SIM:
                self.current_nodeid = nodeid
            self.awake_devices.add(self._core_addr_from(nodeid))
        self.receive_events()
        while self._next_event_time() <= self._pause_time:
            self.process_next_event()
            self.receive_events()
        if duration > 0:
            self.sync_devices()
        dbg_print('current time %d us' % self.current_time)
517
518
if __name__ == '__main__':
    # Standalone mode: repeatedly run zero-length steps. The pause time never
    # advances; each step just services events scheduled at or before it.
    sim = VirtualTime()
    while True:
        sim.go(0)
523