1#!/usr/bin/env python3
2#
3#  Copyright (c) 2016, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import io
31import ipaddress
32import struct
33
34import coap
35import common
36import dtls
37import ipv6
38import mac802154
39import mle
40
41from enum import IntEnum
42
43
class DropPacketException(Exception):
    """Signals that a captured frame cannot be handled and should be dropped.

    MessageFactory.create raises it, for example, when a frame uses
    key_id_mode 0, which the test scripts cannot decode.
    """
46
47
class MessageType(IntEnum):
    """Classification of a captured message.

    The MAC-level kinds come from the 802.15.4 frame type. The payload-level
    kinds are assigned once an IPv6 packet's upper-layer protocol is decoded.
    """

    # Payload-level kinds (set from the decoded IPv6 upper-layer protocol).
    MLE = 0
    COAP = 1
    ICMP = 2
    # MAC-level kinds (set from the 802.15.4 frame type).
    ACK = 3
    BEACON = 4
    DATA = 5
    COMMAND = 6
    # Payload-level kind: one or more DTLS records carried in a UDP datagram.
    DTLS = 7
57
58
class Message(object):
    """A single captured frame, decoded as far as its layers allow.

    ``type`` starts as None and is refined as layers are attached:

    * Assigning ``mac_header`` classifies the frame as BEACON, ACK, DATA or
      COMMAND.
    * Assigning ``ipv6_packet`` may further refine it to ICMP, MLE, COAP or
      DTLS, depending on the upper-layer payload.
    """

    def __init__(self):
        self._type = None
        self._channel = None
        self._mac_header = None
        self._ipv6_packet = None
        self._coap = None
        self._mle = None
        self._icmp = None
        self._dtls = None

    @staticmethod
    def _find_tlv(tlvs, tlv_class_type):
        """Return the first item in ``tlvs`` that is an instance of ``tlv_class_type``.

        Args:
            tlvs: iterable of decoded TLV objects.
            tlv_class_type: class (or tuple of classes) to match with isinstance.

        Returns:
            The first matching TLV, or None when no item matches.
        """
        for tlv in tlvs:
            if isinstance(tlv, tlv_class_type):
                return tlv
        return None

    def _extract_udp_datagram(self, udp_datagram):
        """Set the message type and payload from a decoded UDP datagram's payload."""
        if isinstance(udp_datagram.payload, mle.MleMessage):
            self._type = MessageType.MLE
            self._mle = udp_datagram.payload

        elif isinstance(udp_datagram.payload, (coap.CoapMessage, coap.CoapMessageProxy)):
            self._type = MessageType.COAP
            self._coap = udp_datagram.payload

        # DTLS message factory returns a list of messages
        elif isinstance(udp_datagram.payload, list):
            self._type = MessageType.DTLS
            self._dtls = udp_datagram.payload

    def _extract_upper_layer_protocol(self, upper_layer_protocol):
        """Set the message type from the IPv6 packet's upper-layer protocol (ICMPv6 or UDP)."""
        if isinstance(upper_layer_protocol, ipv6.ICMPv6):
            self._type = MessageType.ICMP
            self._icmp = upper_layer_protocol

        elif isinstance(upper_layer_protocol, ipv6.UDPDatagram):
            self._extract_udp_datagram(upper_layer_protocol)

    def try_extract_dtls_messages(self):
        """Extract multiple dtls messages that are sent in a single UDP datagram.

        Returns:
            list[Message]: one clone per DTLS record for a DTLS message (each
            clone's ``dtls`` is a single record), otherwise a one-element list
            holding a clone of this message.
        """
        if self.type != MessageType.DTLS:
            return [self.clone()]

        assert isinstance(self.dtls, list)
        ret = []
        for dtls in self.dtls:
            msg = self.clone()
            msg._dtls = dtls
            ret.append(msg)
        return ret

    def clone(self):
        """Return a shallow copy: a new Message sharing the same layer objects."""
        msg = Message()
        msg._type = self.type
        msg._channel = self.channel
        msg._mac_header = self.mac_header
        msg._ipv6_packet = self.ipv6_packet
        msg._coap = self.coap
        msg._mle = self.mle
        msg._icmp = self.icmp
        msg._dtls = self.dtls
        return msg

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, value):
        self._type = value

    @property
    def channel(self):
        return self._channel

    @channel.setter
    def channel(self, value):
        self._channel = value

    @property
    def mac_header(self):
        return self._mac_header

    @mac_header.setter
    def mac_header(self, value):
        """Store the MAC header and classify the message by its frame type.

        Raises:
            ValueError: the frame type is not BEACON, ACK, DATA or COMMAND.
        """
        self._mac_header = value

        if self._mac_header.frame_type == mac802154.MacHeader.FrameType.BEACON:
            self._type = MessageType.BEACON

        elif self._mac_header.frame_type == mac802154.MacHeader.FrameType.ACK:
            self._type = MessageType.ACK

        elif self._mac_header.frame_type == mac802154.MacHeader.FrameType.DATA:
            self._type = MessageType.DATA

        elif self._mac_header.frame_type == mac802154.MacHeader.FrameType.COMMAND:
            self._type = MessageType.COMMAND

        else:
            raise ValueError('Invalid mac frame type %d' % self._mac_header.frame_type)

    @property
    def ipv6_packet(self):
        return self._ipv6_packet

    @ipv6_packet.setter
    def ipv6_packet(self, value):
        # Storing the packet also refines ``type`` from its upper-layer protocol.
        self._ipv6_packet = value
        self._extract_upper_layer_protocol(value.upper_layer_protocol)

    @property
    def coap(self):
        return self._coap

    @property
    def mle(self):
        return self._mle

    @property
    def icmp(self):
        return self._icmp

    @icmp.setter
    def icmp(self, value):
        self._icmp = value

    @property
    def dtls(self):
        return self._dtls

    def get_mle_message_tlv(self, tlv_class_type):
        """Return the first MLE TLV of ``tlv_class_type``, or None if absent.

        Raises:
            ValueError: this is not an MLE message.
        """
        if self.type != MessageType.MLE:
            raise ValueError("Invalid message type. Expected MLE message.")

        return self._find_tlv(self.mle.command.tlvs, tlv_class_type)

    def assertMleMessageIsType(self, command_type):
        """Assert that the MLE command type equals ``command_type``."""
        if self.type != MessageType.MLE:
            raise ValueError("Invalid message type. Expected MLE message.")

        assert self.mle.command.type == command_type

    def assertMleMessageContainsTlv(self, tlv_class_type):
        """To confirm if Mle message contains the TLV type.

        Args:
            tlv_class_type: tlv's type.

        Returns:
            The first TLV in the MLE command that is an instance of
            ``tlv_class_type``.

        Raises:
            ValueError: this is not an MLE message.
            AssertionError: no TLV of that type is present.
        """
        if self.type != MessageType.MLE:
            raise ValueError("Invalid message type. Expected MLE message.")

        tlv = self._find_tlv(self.mle.command.tlvs, tlv_class_type)
        assert tlv is not None
        return tlv

    def assertAssignedRouterQuantity(self, router_quantity):
        """Confirm if Leader contains the Route64 TLV with router_quantity assigned Router IDs.

        Args:
            router_quantity: the quantity of router.
        """
        tlv = self.assertMleMessageContainsTlv(mle.Route64)
        # Count the set bits in the low 64 bits of the Router ID mask.
        count = bin(tlv.router_id_mask & ((1 << 64) - 1)).count('1')
        assert count == router_quantity

    def assertMleMessageDoesNotContainTlv(self, tlv_class_type):
        """Assert that no TLV of ``tlv_class_type`` is present in the MLE command."""
        if self.type != MessageType.MLE:
            raise ValueError("Invalid message type. Expected MLE message.")

        assert self._find_tlv(self.mle.command.tlvs, tlv_class_type) is None

    def assertMleMessageContainsOptionalTlv(self, tlv_class_type):
        """Report (print only, never fail) whether an optional MLE TLV is present."""
        if self.type != MessageType.MLE:
            raise ValueError("Invalid message type. Expected MLE message.")

        if self._find_tlv(self.mle.command.tlvs, tlv_class_type) is not None:
            print("MleMessage contains optional TLV: {}".format(tlv_class_type))
        else:
            print("MleMessage doesn't contain optional TLV: {}".format(tlv_class_type))

    def get_coap_message_tlv(self, tlv_class_type):
        """Return the first CoAP payload TLV of ``tlv_class_type``, or None if absent.

        Raises:
            ValueError: this is not a CoAP message.
        """
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected CoAP message.")

        return self._find_tlv(self.coap.payload, tlv_class_type)

    def assertCoapMessageContainsTlv(self, tlv_class_type):
        """Assert that the CoAP payload contains a TLV of ``tlv_class_type``."""
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected CoAP message.")

        assert self._find_tlv(self.coap.payload, tlv_class_type) is not None

    def assertCoapMessageDoesNotContainTlv(self, tlv_class_type):
        """Assert that the CoAP payload contains no TLV of ``tlv_class_type``."""
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected COAP message.")

        assert self._find_tlv(self.coap.payload, tlv_class_type) is None

    def assertCoapMessageContainsOptionalTlv(self, tlv_class_type):
        """Report (print only, never fail) whether an optional CoAP TLV is present."""
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected CoAP message.")

        # Previously this always printed "doesn't contain", even when the TLV was found.
        if self._find_tlv(self.coap.payload, tlv_class_type) is not None:
            print("CoapMessage contains optional TLV: {}".format(tlv_class_type))
        else:
            print("CoapMessage doesn't contain optional TLV: {}".format(tlv_class_type))

    def assertCoapMessageRequestUriPath(self, uri_path):
        """Assert that the CoAP request URI path equals ``uri_path``."""
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected CoAP message.")

        assert uri_path == self.coap.uri_path

    def assertCoapMessageCode(self, code):
        """Assert that the CoAP code equals ``code``."""
        if self.type != MessageType.COAP:
            raise ValueError("Invalid message type. Expected CoAP message.")

        assert code == self.coap.code

    def assertSentToNode(self, node):
        """Assert the message was addressed to ``node``.

        Either the IPv6 destination matches one of the node's addresses, or
        the MAC destination (short or extended) matches the node.
        """
        sent_to_node = False
        dst_addr = self.ipv6_packet.ipv6_header.destination_address

        for addr in node.get_addrs():
            if dst_addr == ipaddress.ip_address(addr):
                sent_to_node = True

        if self.mac_header.dest_address.type == common.MacAddressType.SHORT:
            mac_address = common.MacAddress.from_rloc16(node.get_addr16())
            if self.mac_header.dest_address == mac_address:
                sent_to_node = True

        elif self.mac_header.dest_address.type == common.MacAddressType.LONG:
            mac_address = common.MacAddress.from_eui64(bytearray(node.get_addr64(), encoding="utf-8"))
            if self.mac_header.dest_address == mac_address:
                sent_to_node = True

        assert sent_to_node

    def assertSentToDestinationAddress(self, ipv6_address):
        assert (self.ipv6_packet.ipv6_header.destination_address == ipaddress.ip_address(ipv6_address))

    def assertSentFromSourceAddress(self, ipv6_address):
        assert (self.ipv6_packet.ipv6_header.source_address == ipaddress.ip_address(ipv6_address))

    def assertSentWithHopLimit(self, hop_limit):
        assert self.ipv6_packet.ipv6_header.hop_limit == hop_limit

    def isMacAddressTypeLong(self):
        """Return True if the MAC destination address is an extended (LONG) address."""
        return self.mac_header.dest_address.type == common.MacAddressType.LONG

    def get_dst_udp_port(self):
        """Return the UDP destination port; asserts the payload is a UDP datagram."""
        assert isinstance(self.ipv6_packet.upper_layer_protocol, ipv6.UDPDatagram)
        return self.ipv6_packet.upper_layer_protocol.header.dst_port

    def is_data_poll(self):
        """Return True if this is a MAC command frame carrying a Data Request."""
        return self._type == MessageType.COMMAND and \
            self._mac_header.command_type == mac802154.MacHeader.CommandIdentifier.DATA_REQUEST

    def __repr__(self):
        if (self.type == MessageType.DTLS and self.dtls.content_type == dtls.ContentType.HANDSHAKE):
            return "Message(type={})".format(str(self.dtls.handshake_type))
        return "Message(type={})".format(MessageType(self.type).name)
360
361
class MessagesSet(object):
    """An ordered, consumable sequence of captured messages.

    Most ``next_*`` helpers pop messages from the front of ``messages`` until a
    match is found. Non-matching messages ahead of the match are discarded.
    ``clone`` returns an independent copy for non-destructive inspection.
    """

    def __init__(self, messages, commissioning_messages=()):
        self._messages = messages
        self._commissioning_messages = commissioning_messages

    @property
    def messages(self):
        return self._messages

    @property
    def commissioning_messages(self):
        return self._commissioning_messages

    def next_data_poll(self):
        """Pop messages until a MAC Data Request is found.

        Returns:
            The matching Message, or None once the set is exhausted.
        """
        while True:
            message = self.next_message_of(MessageType.COMMAND, False)
            if message is None:
                return None
            if message.is_data_poll():
                return message

    def next_coap_message(self, code, uri_path=None, assert_enabled=True):
        """Pop messages until a CoAP message with ``code`` (and ``uri_path``, if given) is found.

        Args:
            code: dotted CoAP code, compared via ``is_equal_dotted``.
            uri_path: required URI path, or None to accept any.
            assert_enabled: assert that a match was found.

        Returns:
            The matching Message, or None when not found and asserts are disabled.
        """
        message = None

        while self.messages:
            m = self.messages.pop(0)

            if m.type != MessageType.COAP:
                continue

            if uri_path is not None and m.coap.uri_path != uri_path:
                continue

            if not m.coap.code.is_equal_dotted(code):
                continue

            message = m
            break

        if assert_enabled:
            assert (message is not None), "Could not find CoapMessage with code: {}".format(code)

        return message

    def last_mle_message(self, command_type, assert_enabled=True):
        """Get the last Mle Message with specified type from existing capture.

        Unlike the ``next_*`` helpers, this does not remove any messages.

        Args:
            command_type: the specified mle type.
            assert_enabled: interrupt or not when get the mle.

        Returns:
            message.Message: the last Mle Message with specified type.
        """
        message = None

        for m in reversed(self.messages):
            if m.type == MessageType.MLE and m.mle.command.type == command_type:
                message = m
                break

        if assert_enabled:
            assert (message is not None), "Could not find MleMessage with type: {}".format(command_type)

        return message

    def next_mle_message(self, command_type, assert_enabled=True, sent_to_node=None):
        """Pop messages until an MLE message of ``command_type`` is found.

        Args:
            command_type: MLE command type to match.
            assert_enabled: assert that a match was found.
            sent_to_node: if given and a match is found, also assert that the
                match was addressed to this node.

        Returns:
            The matching Message, or None when not found and asserts are disabled.
        """
        message = self.next_mle_message_of_one_of_command_types(command_type)

        if assert_enabled:
            assert (message is not None), "Could not find MleMessage of the type: {}".format(command_type)

        # Guard against None: with assert_enabled=False a missing message must not crash here.
        if sent_to_node is not None and message is not None:
            message.assertSentToNode(sent_to_node)

        return message

    def next_mle_message_of_one_of_command_types(self, *command_types):
        """Pop messages until an MLE message whose command type is in ``command_types``.

        Returns:
            The matching Message, or None once the set is exhausted.
        """
        while self.messages:
            m = self.messages.pop(0)

            if m.type != MessageType.MLE:
                continue

            if any(m.mle.command.type == command_type for command_type in command_types):
                return m

        return None

    def next_message(self, assert_enabled=True):
        """Pop and return the next message of any type.

        Returns:
            The next Message, or None if the set is empty and asserts are disabled.
        """
        # Popping from an empty list used to raise IndexError before the assert could run.
        message = self.messages.pop(0) if self.messages else None
        if assert_enabled:
            assert message is not None, "Could not find next Message"
        return message

    def next_message_of(self, message_type, assert_enabled=True):
        """Pop messages until one of ``message_type`` is found.

        Returns:
            The matching Message, or None when not found and asserts are disabled.
        """
        message = None

        while self.messages:
            m = self.messages.pop(0)
            if m.type == message_type:
                message = m
                break

        if assert_enabled:
            assert (message is not None), "Could not find Message of the type: {}".format(message_type)

        return message

    def next_data_message(self):
        return self.next_message_of(MessageType.DATA)

    def next_command_message(self):
        return self.next_message_of(MessageType.COMMAND)

    def next_dtls_message(self, content_type, handshake_type=None):
        """Pop messages until a DTLS record of ``content_type`` is found.

        For HANDSHAKE records, ``handshake_type`` must also match.

        Raises:
            ValueError: no matching record was found.
        """
        while self.messages:
            msg = self.messages.pop(0)
            if msg.type != MessageType.DTLS:
                continue
            if msg.dtls.content_type != content_type:
                continue
            if (content_type == dtls.ContentType.HANDSHAKE and msg.dtls.handshake_type != handshake_type):
                continue
            return msg

        t = (handshake_type if content_type == dtls.ContentType.HANDSHAKE else content_type)
        raise ValueError("Could not find DTLS message of type: {}".format(str(t)))

    def contains_icmp_message(self):
        """Return True if any remaining message is ICMP (non-destructive)."""
        return any(m.type == MessageType.ICMP for m in self.messages)

    def get_icmp_message(self, icmp_type):
        """Return the first remaining ICMP message of ``icmp_type``, or None (non-destructive)."""
        for m in self.messages:
            if m.type == MessageType.ICMP and m.icmp.header.type == icmp_type:
                return m

        return None

    def contains_mle_message(self, command_type):
        """Return True if any remaining MLE message has ``command_type`` (non-destructive)."""
        return any(m.type == MessageType.MLE and m.mle.command.type == command_type for m in self.messages)

    def does_not_contain_coap_message(self):
        """Return True if no remaining message is CoAP (non-destructive)."""
        return all(m.type != MessageType.COAP for m in self.messages)

    def clone(self):
        """Make a copy of current MessageSet.
        """
        return MessagesSet(self.messages[:], self.commissioning_messages[:])

    def __repr__(self):
        return str(self.messages)
555
556
class MessageFactory:
    """Builds Message objects from raw captured frames.

    Each frame is a one-byte channel, followed by an 802.15.4 MAC frame whose
    payload is parsed as 6LoWPAN.
    """

    def __init__(self, lowpan_parser):
        self._lowpan_parser = lowpan_parser

    def _add_device_descriptors(self, message):
        """Record MLE address/MAC pairs in mac802154.DeviceDescriptors.

        A SourceAddress TLV is paired with the MAC source address. An
        Address16 TLV is paired with the MAC destination address.
        """
        for tlv in message.mle.command.tlvs:

            if isinstance(tlv, mle.SourceAddress):
                mac802154.DeviceDescriptors.add(tlv.address, message.mac_header.src_address)

            if isinstance(tlv, mle.Address16):
                mac802154.DeviceDescriptors.add(tlv.address, message.mac_header.dest_address)

    def _parse_mac_frame(self, data):
        """Parse an 802.15.4 MAC frame from the ``data`` stream."""
        mac_frame = mac802154.MacFrame()
        mac_frame.parse(data)
        return mac_frame

    @staticmethod
    def _read_channel(data):
        """Read the one-byte channel prefix from ``data`` and return it as an int.

        ``struct.unpack`` always returns a tuple, so the single value is
        unpacked here. Previously the 1-tuple itself was stored as the channel.
        """
        (channel,) = struct.unpack(">B", data.read(1))
        return channel

    def set_lowpan_context(self, cid, prefix):
        """Forward a 6LoWPAN context (id, prefix) to the underlying parser."""
        self._lowpan_parser.set_lowpan_context(cid, prefix)

    def create(self, data):
        """Decode one captured frame from the ``data`` stream.

        Args:
            data: binary stream positioned at the channel byte.

        Returns:
            list[Message]: usually a single message. A UDP datagram carrying
            several DTLS records yields one message per record.

        Raises:
            DropPacketException: the frame uses key_id_mode 0, which these
                scripts cannot handle.
        """
        try:
            message = Message()
            message.channel = self._read_channel(data)

            # Parse MAC header
            mac_frame = self._parse_mac_frame(data)
            message.mac_header = mac_frame.header

            # Only DATA frames carry a 6LoWPAN payload worth decoding further.
            if message.mac_header.frame_type != mac802154.MacHeader.FrameType.DATA:
                return [message]

            message_info = common.MessageInfo()
            message_info.source_mac_address = message.mac_header.src_address
            message_info.destination_mac_address = message.mac_header.dest_address

            # Create stream with 6LoWPAN datagram
            lowpan_payload = io.BytesIO(mac_frame.payload.data)

            ipv6_packet = self._lowpan_parser.parse(lowpan_payload, message_info)
            if ipv6_packet is None:
                return [message]

            message.ipv6_packet = ipv6_packet

            if message.type == MessageType.MLE:
                self._add_device_descriptors(message)

            return message.try_extract_dtls_messages()

        except mac802154.KeyIdMode0Exception as err:
            print('Received packet with key_id_mode = 0, cannot be handled in test scripts')
            raise DropPacketException from err
612