1#!/usr/bin/env python3
2#
3#  Copyright (c) 2017-2018, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30import binascii
31
32import ipaddress
33import ipv6
34import network_data
35import network_layer
36import common
37import config
38import mesh_cop
39import mle
40
41from enum import IntEnum
42
43
class CheckType(IntEnum):
    """Expectation a checker applies to an item (usually a TLV) of a message.

    Being an ``IntEnum``, members compare equal to their integer values.
    """
    # The item must be present.
    CONTAIN = 0
    # The item must be absent.
    NOT_CONTAIN = 1
    # The item may or may not be present.
    OPTIONAL = 2
48
49
class NetworkDataCheckType:
    """Integer constants naming kinds of network-data checks.

    NOTE(review): none of the checkers in this file reference these constants; confirm whether
    callers still use them.
    """
    # Presumably: check only the number of Prefix TLVs.
    PREFIX_CNT = 1
    # Presumably: check the content of Prefix TLVs.
    PREFIX_CONTENT = 2
53
54
def check_address_query(command_msg, source_node, destination_address):
    """Verify source_node sent a properly formatted Address Query Request message to the destination_address.

    Args:
        command_msg: The captured CoAP message to check.
        source_node: Node expected to have sent the message; its RLOC must be the IPv6 source.
        destination_address: Expected IPv6 destination, in any form accepted by
            ``ipv6.ip_address``. A ``bytearray`` is converted to ``bytes`` first.
    """
    # Like the other address-management checkers, pin the CoAP URI path so that an unrelated
    # CoAP message carrying a Target EID TLV is not accepted as an Address Query.
    command_msg.assertCoapMessageRequestUriPath('/a/aq')
    command_msg.assertCoapMessageContainsTlv(network_layer.TargetEid)

    source_rloc = source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC)
    assert (ipv6.ip_address(source_rloc) == command_msg.ipv6_packet.ipv6_header.source_address), (
        "Error: The IPv6 source address is not the RLOC of the originator. The source node's rloc is: " +
        str(ipv6.ip_address(source_rloc)) + ", but the source_address in command msg is: " +
        str(command_msg.ipv6_packet.ipv6_header.source_address))

    if isinstance(destination_address, bytearray):
        destination_address = bytes(destination_address)

    assert (ipv6.ip_address(destination_address) == command_msg.ipv6_packet.ipv6_header.destination_address
           ), "Error: The IPv6 destination address is not expected."
71
72
def check_address_notification(command_msg, source_node, destination_node):
    """Check that ``command_msg`` is an Address Notification (/a/an) from source_node to destination_node.

    The CoAP payload must carry Target EID, RLOC16 and ML-EID TLVs, and both IPv6 endpoints
    must be the RLOC addresses of the respective nodes.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/an')
    for required_tlv in (network_layer.TargetEid, network_layer.Rloc16, network_layer.MlEid):
        command_msg.assertCoapMessageContainsTlv(required_tlv)

    expected_source = ipv6.ip_address(source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC))
    header = command_msg.ipv6_packet.ipv6_header
    assert expected_source == header.source_address, "Error: The IPv6 source address is not the RLOC of the originator."

    expected_destination = ipv6.ip_address(destination_node.get_ip6_address(config.ADDRESS_TYPE.RLOC))
    assert expected_destination == header.destination_address, \
        "Error: The IPv6 destination address is not the RLOC of the destination."
88
89
def check_address_error_notification(command_msg, source_node, destination_address):
    """Check that ``command_msg`` is an Address Error Notification (/a/ae) from source_node.

    The CoAP payload must carry Target EID and ML-EID TLVs. The IPv6 source must be
    source_node's RLOC and the IPv6 destination must equal ``destination_address``; a
    ``bytearray`` destination_address is converted to ``bytes`` before parsing.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/ae')
    for required_tlv in (network_layer.TargetEid, network_layer.MlEid):
        command_msg.assertCoapMessageContainsTlv(required_tlv)

    expected_source = ipv6.ip_address(source_node.get_ip6_address(config.ADDRESS_TYPE.RLOC))
    header = command_msg.ipv6_packet.ipv6_header
    assert expected_source == header.source_address, (
        "Error: The IPv6 source address is not the RLOC of the originator. The source node's rloc is: " +
        str(expected_source) + ", but the source_address in command msg is: " + str(header.source_address))

    if isinstance(destination_address, bytearray):
        destination_address = bytes(destination_address)
    expected_destination = ipv6.ip_address(destination_address)
    assert expected_destination == header.destination_address, (
        "Error: The IPv6 destination address is not expected. The destination node's rloc is: " +
        str(expected_destination) + ", but the destination_address in command msg is: " +
        str(header.destination_address))
110
111
def check_address_solicit(command_msg, was_router):
    """Verify a properly formatted Address Solicit (/a/as) CoAP request.

    The CoAP payload must carry MAC Extended Address and Status TLVs.

    Args:
        command_msg: The captured CoAP message to check.
        was_router: If truthy, the CoAP payload must also carry an RLOC16 TLV; otherwise the
            RLOC16 TLV must be absent from the CoAP payload.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/as')
    command_msg.assertCoapMessageContainsTlv(network_layer.MacExtendedAddress)
    command_msg.assertCoapMessageContainsTlv(network_layer.Status)
    if was_router:
        command_msg.assertCoapMessageContainsTlv(network_layer.Rloc16)
    else:
        # Rloc16 is a network-layer (CoAP) TLV, so its absence has to be checked in the CoAP
        # payload; checking the MLE payload would not examine the Address Solicit at all.
        command_msg.assertCoapMessageDoesNotContainTlv(network_layer.Rloc16)
120
121
def check_address_release(command_msg, destination_node):
    """Check that ``command_msg`` is an Address Release (/a/ar) sent to destination_node's RLOC.

    The CoAP payload must carry RLOC16 and MAC Extended Address TLVs.
    """
    command_msg.assertCoapMessageRequestUriPath('/a/ar')
    for required_tlv in (network_layer.Rloc16, network_layer.MacExtendedAddress):
        command_msg.assertCoapMessageContainsTlv(required_tlv)

    expected = ipv6.ip_address(destination_node.get_ip6_address(config.ADDRESS_TYPE.RLOC))
    actual = command_msg.ipv6_packet.ipv6_header.destination_address
    assert expected == actual, "Error: The destination is not RLOC address"
132
133
def check_tlv_request_tlv(command_msg, check_type, tlv_id):
    """Check whether the MLE TLV Request TLV of ``command_msg`` lists ``tlv_id``.

    - CONTAIN: the TLV Request TLV must exist and list tlv_id.
    - NOT_CONTAIN: if a TLV Request TLV exists, it must not list tlv_id.
    - OPTIONAL: nothing is asserted; the outcome is printed.

    Raises:
        ValueError: if check_type is not one of the CheckType values.
    """
    request_tlv = command_msg.get_mle_message_tlv(mle.TlvRequest)

    def _is_requested():
        return any(tlv_id == requested for requested in request_tlv.tlvs)

    if check_type == CheckType.CONTAIN:
        assert (request_tlv is not None), "Error: The msg doesn't contain TLV Request TLV"
        assert _is_requested(), "Error: The msg doesn't contain TLV Request TLV ID: {}".format(tlv_id)
    elif check_type == CheckType.NOT_CONTAIN:
        if request_tlv is not None:
            assert not _is_requested(), "Error: The msg contains TLV Request TLV ID: {}".format(tlv_id)
    elif check_type == CheckType.OPTIONAL:
        if request_tlv is None:
            print("The msg doesn't contain TLV Request TLV")
        elif _is_requested():
            print("TLV Request TLV contains TLV ID: {}".format(tlv_id))
        else:
            print("TLV Request TLV doesn't contain TLV ID: {}".format(tlv_id))
    else:
        raise ValueError("Invalid check type")
161
162
def check_link_request(
    command_msg,
    source_address=CheckType.OPTIONAL,
    leader_data=CheckType.OPTIONAL,
    tlv_request_address16=CheckType.OPTIONAL,
    tlv_request_route64=CheckType.OPTIONAL,
    tlv_request_link_margin=CheckType.OPTIONAL,
):
    """Verify a properly formatted Link Request command message.

    Challenge and Version TLVs are mandatory. The Source Address and Leader Data TLVs, and the
    Address16 / Route64 / Link Margin entries of the TLV Request TLV, are checked according to
    the corresponding CheckType arguments.
    """
    for mandatory_tlv in (mle.Challenge, mle.Version):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)

    for expectation, tlv_class in ((source_address, mle.SourceAddress), (leader_data, mle.LeaderData)):
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    requested_ids = (
        (tlv_request_address16, mle.TlvType.ADDRESS16),
        (tlv_request_route64, mle.TlvType.ROUTE64),
        (tlv_request_link_margin, mle.TlvType.LINK_MARGIN),
    )
    for expectation, tlv_id in requested_ids:
        check_tlv_request_tlv(command_msg, expectation, tlv_id)
182
183
def check_link_accept(
    command_msg,
    destination_node,
    leader_data=CheckType.OPTIONAL,
    link_margin=CheckType.OPTIONAL,
    mle_frame_counter=CheckType.OPTIONAL,
    challenge=CheckType.OPTIONAL,
    address16=CheckType.OPTIONAL,
    route64=CheckType.OPTIONAL,
    tlv_request_link_margin=CheckType.OPTIONAL,
):
    """Verify a properly formatted Link Accept command message.

    Link-layer Frame Counter, Source Address, Response and Version TLVs are mandatory. The other
    TLVs, and the Link Margin entry of the TLV Request TLV, are checked according to their
    CheckType arguments. The message must be addressed to destination_node's link-local address.
    """
    for mandatory_tlv in (mle.LinkLayerFrameCounter, mle.SourceAddress, mle.Response, mle.Version):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)

    optional_tlvs = (
        (leader_data, mle.LeaderData),
        (link_margin, mle.LinkMargin),
        (mle_frame_counter, mle.MleFrameCounter),
        (challenge, mle.Challenge),
        (address16, mle.Address16),
        (route64, mle.Route64),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    check_tlv_request_tlv(command_msg, tlv_request_link_margin, mle.TlvType.LINK_MARGIN)

    link_local = ipv6.ip_address(destination_node.get_ip6_address(config.ADDRESS_TYPE.LINK_LOCAL))
    assert (link_local == command_msg.ipv6_packet.ipv6_header.destination_address
           ), "Error: The destination is unexpected"
214
215
def check_icmp_path(sniffer, path, nodes, icmp_type=ipv6.ICMP_ECHO_REQUEST):
    """Verify icmp message is forwarded along the path.

    For every hop except the last, the ICMP message of ``icmp_type`` sent by that hop must have
    a MAC destination equal to the RLOC16 of the next hop in ``path``.

    Returns:
        True once the last hop is reached; False if ``path`` is empty.
    """
    last_hop = len(path) - 1
    for hop_index, hop in enumerate(path):
        icmp_msg = sniffer.get_messages_sent_by(hop).get_icmp_message(icmp_type)
        if hop_index == last_hop:
            return True
        expected_rloc16 = nodes[path[hop_index + 1]].get_addr16()
        assert (expected_rloc16 == icmp_msg.mac_header.dest_address.rloc), "Error: The path is unexpected."

    return False
234
235
def check_id_set(command_msg, router_id):
    """Return 1 if router_id's bit is set in the Route64 TLV's Router ID Mask, else 0.

    Bits are counted from the most significant end of a 64-bit mask, so router ID 0 is the MSB.
    """
    route64 = command_msg.assertMleMessageContainsTlv(mle.Route64)
    shift = 63 - router_id
    return (route64.router_id_mask >> shift) & 1
241
242
243def get_routing_cost(command_msg, router_id):
244    """Check the command_msg's Route64 tlv to get the routing cost to router.
245    """
246    tlv = command_msg.assertMleMessageContainsTlv(mle.Route64)
247
248    # Get router's mask pos
249    # Turn the number into binary string. Need to consider the preceding 0
250    # omitted during conversion.
251    router_id_mask_str = bin(tlv.router_id_mask).replace('0b', '')
252    prefix_len = 64 - len(router_id_mask_str)
253    routing_entry_pos = 0
254
255    for i in range(0, router_id - prefix_len):
256        if router_id_mask_str[i] == '1':
257            routing_entry_pos += 1
258
259    assert router_id_mask_str[router_id - prefix_len] == '1', \
260        (("Error: The router isn't in the topology. \n",
261          "route64 tlv is: %s. \nrouter_id is: %s. \nrouting_entry_pos is: %s. \nrouter_id_mask_str is: %s.") %
262         (tlv, router_id, routing_entry_pos, router_id_mask_str))
263
264    return tlv.link_quality_and_route_data[routing_entry_pos].route
265
266
def check_mle_optional_tlv(command_msg, type, tlv):
    """Assert the presence or absence of MLE TLV class ``tlv`` according to CheckType ``type``.

    Raises:
        ValueError: if ``type`` is not one of the CheckType values.
    """
    # NOTE: ``type`` shadows the builtin; the name is kept so keyword callers keep working.
    if type == CheckType.CONTAIN:
        command_msg.assertMleMessageContainsTlv(tlv)
        return
    if type == CheckType.NOT_CONTAIN:
        command_msg.assertMleMessageDoesNotContainTlv(tlv)
        return
    if type == CheckType.OPTIONAL:
        command_msg.assertMleMessageContainsOptionalTlv(tlv)
        return
    raise ValueError("Invalid check type")
276
277
def check_mle_advertisement(command_msg):
    """Verify an MLE Advertisement: hop limit 255, sent to the link-local all-nodes address,
    and carrying Source Address, Leader Data and Route64 TLVs.
    """
    command_msg.assertSentWithHopLimit(255)
    command_msg.assertSentToDestinationAddress(config.LINK_LOCAL_ALL_NODES_ADDRESS)
    for required_tlv in (mle.SourceAddress, mle.LeaderData, mle.Route64):
        command_msg.assertMleMessageContainsTlv(required_tlv)
284
285
def check_parent_request(command_msg, is_first_request):
    """Verify a properly formatted Parent Request command message.

    Raises:
        ValueError: if the key ID mode is not 0x02, if the Scan Mask R bit is clear, or if the
            E bit is set on the first request / clear on a subsequent one.
    """
    key_id_mode = command_msg.mle.aux_sec_hdr.key_id_mode
    if key_id_mode != 0x2:
        raise ValueError("The Key Identifier Mode of the Security Control Field SHALL be set to 0x02")

    command_msg.assertSentWithHopLimit(255)
    command_msg.assertSentToDestinationAddress(config.LINK_LOCAL_ALL_ROUTERS_ADDRESS)
    for required_tlv in (mle.Mode, mle.Challenge, mle.Version):
        command_msg.assertMleMessageContainsTlv(required_tlv)

    scan_mask = command_msg.assertMleMessageContainsTlv(mle.ScanMask)
    if not scan_mask.router:
        raise ValueError("Parent request without R bit set")
    # The E bit must be clear on the first request and set on any later one.
    if is_first_request and scan_mask.end_device:
        raise ValueError("First parent request with E bit set")
    if not is_first_request and not scan_mask.end_device:
        raise ValueError("Second parent request without E bit set")
305
306
def check_parent_response(command_msg, mle_frame_counter=CheckType.OPTIONAL):
    """Verify a properly formatted Parent Response command message.

    All TLVs except the MLE Frame Counter are mandatory; that one follows ``mle_frame_counter``.
    """
    mandatory_tlvs = (
        mle.Challenge,
        mle.Connectivity,
        mle.LeaderData,
        mle.LinkLayerFrameCounter,
        mle.LinkMargin,
        mle.Response,
        mle.SourceAddress,
        mle.Version,
    )
    for tlv_class in mandatory_tlvs:
        command_msg.assertMleMessageContainsTlv(tlv_class)

    check_mle_optional_tlv(command_msg, mle_frame_counter, mle.MleFrameCounter)
320
321
def check_child_id_request(
    command_msg,
    tlv_request=CheckType.OPTIONAL,
    mle_frame_counter=CheckType.OPTIONAL,
    address_registration=CheckType.OPTIONAL,
    active_timestamp=CheckType.OPTIONAL,
    pending_timestamp=CheckType.OPTIONAL,
    route64=CheckType.OPTIONAL,
):
    """Verify a properly formatted Child Id Request command message.

    The key ID mode must be 0x02 (ValueError otherwise). Link-layer Frame Counter, Mode,
    Response, Timeout and Version TLVs are mandatory, and the TLV Request TLV must list both
    Address16 and Network Data. The remaining TLVs follow their CheckType arguments.
    """
    if command_msg.mle.aux_sec_hdr.key_id_mode != 0x2:
        raise ValueError("The Key Identifier Mode of the Security Control Field SHALL be set to 0x02")

    for mandatory_tlv in (mle.LinkLayerFrameCounter, mle.Mode, mle.Response, mle.Timeout, mle.Version):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)

    optional_tlvs = (
        (tlv_request, mle.TlvRequest),
        (mle_frame_counter, mle.MleFrameCounter),
        (address_registration, mle.AddressRegistration),
        (active_timestamp, mle.ActiveTimestamp),
        (pending_timestamp, mle.PendingTimestamp),
        (route64, mle.Route64),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    for requested_id in (mle.TlvType.ADDRESS16, mle.TlvType.NETWORK_DATA):
        check_tlv_request_tlv(command_msg, CheckType.CONTAIN, requested_id)
351
352
def check_child_id_response(
    command_msg,
    route64=CheckType.OPTIONAL,
    network_data=CheckType.OPTIONAL,
    address_registration=CheckType.OPTIONAL,
    active_timestamp=CheckType.OPTIONAL,
    pending_timestamp=CheckType.OPTIONAL,
    active_operational_dataset=CheckType.OPTIONAL,
    pending_operational_dataset=CheckType.OPTIONAL,
    network_data_check=None,
):
    """Verify a properly formatted Child Id Response command message.

    Source Address, Leader Data and Address16 TLVs are mandatory; the others follow their
    CheckType arguments. If ``network_data_check`` is given, the Network Data TLV must be
    present and is passed to ``network_data_check.check``.
    """
    for mandatory_tlv in (mle.SourceAddress, mle.LeaderData, mle.Address16):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)

    optional_tlvs = (
        (route64, mle.Route64),
        (network_data, mle.NetworkData),
        (address_registration, mle.AddressRegistration),
        (active_timestamp, mle.ActiveTimestamp),
        (pending_timestamp, mle.PendingTimestamp),
        (active_operational_dataset, mle.ActiveOperationalDataset),
        (pending_operational_dataset, mle.PendingOperationalDataset),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    if network_data_check is None:
        return
    network_data_check.check(command_msg.assertMleMessageContainsTlv(mle.NetworkData))
381
382
def check_prefix(prefix):
    """Assert that Prefix TLV ``prefix`` has both a Border Router and a 6LoWPAN ID sub-TLV."""
    has_border_router = contains_tlv(prefix.sub_tlvs, network_data.BorderRouter)
    assert has_border_router, "Prefix doesn't contain a border router sub-TLV!"
    has_lowpan_id = contains_tlv(prefix.sub_tlvs, network_data.LowpanId)
    assert has_lowpan_id, "Prefix doesn't contain a LowpanId sub-TLV!"
388
389
def check_child_update_request_from_child(
        command_msg,
        source_address=CheckType.OPTIONAL,
        leader_data=CheckType.OPTIONAL,
        challenge=CheckType.OPTIONAL,
        time_out=CheckType.OPTIONAL,
        address_registration=CheckType.OPTIONAL,
        tlv_request_tlv=CheckType.OPTIONAL,
        active_timestamp=CheckType.OPTIONAL,
        CIDs=(),
):
    """Verify a Child Update Request sent by a child.

    The Mode TLV is mandatory and the other TLVs follow their CheckType arguments. When
    ``address_registration`` is CONTAIN and ``CIDs`` is non-empty, every context ID in CIDs
    must appear in a compressed entry of the Address Registration TLV.
    """
    command_msg.assertMleMessageContainsTlv(mle.Mode)

    optional_tlvs = (
        (source_address, mle.SourceAddress),
        (leader_data, mle.LeaderData),
        (challenge, mle.Challenge),
        (time_out, mle.Timeout),
        (address_registration, mle.AddressRegistration),
        (tlv_request_tlv, mle.TlvRequest),
        (active_timestamp, mle.ActiveTimestamp),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    if address_registration == CheckType.CONTAIN and len(CIDs):
        _check_address_registration(command_msg, CIDs)
413
414
def check_coap_optional_tlv(coap_msg, type, tlv):
    """Assert the presence or absence of CoAP TLV class ``tlv`` according to CheckType ``type``.

    Raises:
        ValueError: if ``type`` is not one of the CheckType values.
    """
    # NOTE: ``type`` shadows the builtin; the name is kept so keyword callers keep working.
    if type == CheckType.CONTAIN:
        coap_msg.assertCoapMessageContainsTlv(tlv)
        return
    if type == CheckType.NOT_CONTAIN:
        coap_msg.assertCoapMessageDoesNotContainTlv(tlv)
        return
    if type == CheckType.OPTIONAL:
        coap_msg.assertCoapMessageContainsOptionalTlv(tlv)
        return
    raise ValueError("Invalid check type")
424
425
def check_router_id_cached(node, router_id, cached=True):
    """Assert whether node's EID cache has an entry whose RLOC16 belongs to ``router_id``.

    ``node.get_eidcaches()`` yields 2-tuples; the first element is ignored and the second is an
    RLOC16 hex string whose router ID is ``int(rloc16, 16) >> 10``.

    Args:
        cached: If truthy, at least one matching entry must exist; otherwise none may exist.
    """
    cache_entries = node.get_eidcaches()
    cached_router_ids = (int(rloc16, 16) >> 10 for _, rloc16 in cache_entries)
    has_entry = any(router_id == cached_id for cached_id in cached_router_ids)
    if cached:
        assert has_entry
    else:
        assert not has_entry
434
435
def contains_tlv(sub_tlvs, tlv_type):
    """Return True if any element of ``sub_tlvs`` is an instance of ``tlv_type``."""
    for candidate in sub_tlvs:
        if isinstance(candidate, tlv_type):
            return True
    return False
440
441
def contains_tlvs(sub_tlvs, tlv_types):
    """Return True if, for every type in ``tlv_types``, ``sub_tlvs`` has an instance of it."""
    for wanted_type in tlv_types:
        if not any(isinstance(candidate, wanted_type) for candidate in sub_tlvs):
            return False
    return True
446
447
def check_secure_mle_key_id_mode(command_msg, key_id_mode):
    """Assert ``command_msg`` is a secured MLE message whose aux security header uses ``key_id_mode``."""
    mle_message = command_msg.mle
    assert isinstance(mle_message, mle.MleMessageSecured)
    assert mle_message.aux_sec_hdr.key_id_mode == key_id_mode
453
454
def check_data_response(command_msg, network_data_check=None, active_timestamp=CheckType.OPTIONAL):
    """Verify a properly formatted Data Response command message.

    Requires key ID mode 0x02 plus Source Address and Leader Data TLVs. The Active Timestamp TLV
    follows ``active_timestamp``. If ``network_data_check`` is given, the Network Data TLV must
    be present and is passed to ``network_data_check.check``.
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)
    for mandatory_tlv in (mle.SourceAddress, mle.LeaderData):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)
    check_mle_optional_tlv(command_msg, active_timestamp, mle.ActiveTimestamp)

    if network_data_check is None:
        return
    network_data_check.check(command_msg.assertMleMessageContainsTlv(mle.NetworkData))
465
466
def check_child_update_request_from_parent(
    command_msg,
    leader_data=CheckType.OPTIONAL,
    network_data=CheckType.OPTIONAL,
    challenge=CheckType.OPTIONAL,
    tlv_request=CheckType.OPTIONAL,
    active_timestamp=CheckType.OPTIONAL,
):
    """Verify a properly formatted Child Update Request(from parent) command message.

    Requires key ID mode 0x02 and a Source Address TLV; the other TLVs follow their CheckType
    arguments.
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)
    command_msg.assertMleMessageContainsTlv(mle.SourceAddress)

    optional_tlvs = (
        (leader_data, mle.LeaderData),
        (network_data, mle.NetworkData),
        (challenge, mle.Challenge),
        (tlv_request, mle.TlvRequest),
        (active_timestamp, mle.ActiveTimestamp),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)
485
486
def check_child_update_response(
        command_msg,
        timeout=CheckType.OPTIONAL,
        address_registration=CheckType.OPTIONAL,
        address16=CheckType.OPTIONAL,
        leader_data=CheckType.OPTIONAL,
        network_data=CheckType.OPTIONAL,
        response=CheckType.OPTIONAL,
        link_layer_frame_counter=CheckType.OPTIONAL,
        mle_frame_counter=CheckType.OPTIONAL,
        CIDs=(),
):
    """Verify a properly formatted Child Update Response from parent.

    Requires key ID mode 0x02 plus Source Address and Mode TLVs; the other TLVs follow their
    CheckType arguments. When ``address_registration`` is CONTAIN and ``CIDs`` is non-empty,
    every context ID in CIDs must appear in a compressed Address Registration entry.
    """
    check_secure_mle_key_id_mode(command_msg, 0x02)

    for mandatory_tlv in (mle.SourceAddress, mle.Mode):
        command_msg.assertMleMessageContainsTlv(mandatory_tlv)

    optional_tlvs = (
        (timeout, mle.Timeout),
        (address_registration, mle.AddressRegistration),
        (address16, mle.Address16),
        (leader_data, mle.LeaderData),
        (network_data, mle.NetworkData),
        (response, mle.Response),
        (link_layer_frame_counter, mle.LinkLayerFrameCounter),
        (mle_frame_counter, mle.MleFrameCounter),
    )
    for expectation, tlv_class in optional_tlvs:
        check_mle_optional_tlv(command_msg, expectation, tlv_class)

    if address_registration == CheckType.CONTAIN and len(CIDs):
        _check_address_registration(command_msg, CIDs)
516
517
def _check_address_registration(command_msg, CIDs=()):
    """Assert that each context ID in ``CIDs`` appears in some compressed Address Registration entry."""
    registered = command_msg.assertMleMessageContainsTlv(mle.AddressRegistration).addresses
    for cid in CIDs:
        present = any(isinstance(entry, mle.AddressCompressed) and cid == entry.cid for entry in registered)
        assert present, "AddressRegistration TLV doesn't have CID {} ".format(cid)
528
529
def get_sub_tlv(tlvs, tlv_type):
    """Return the first element of ``tlvs`` that is an instance of ``tlv_type``, or None."""
    return next((candidate for candidate in tlvs if isinstance(candidate, tlv_type)), None)
534
535
def check_address_registration_tlv(
    command_msg,
    full_address,
):
    """Return True if the Address Registration TLV has an uncompressed entry equal to ``full_address``.

    ``full_address`` is parsed with ``ipaddress.ip_address`` before the TLV is looked up, so an
    unparsable address raises ValueError.
    """
    wanted = ipaddress.ip_address(full_address)
    entries = command_msg.assertMleMessageContainsTlv(mle.AddressRegistration).addresses
    return any(
        isinstance(entry, mle.AddressFull) and ipaddress.ip_address(entry.ipv6_address) == wanted
        for entry in entries)
552
553
def check_compressed_address_registration_tlv(command_msg, cid, iid, cid_present_once=False):
    '''Check whether or not a compressed IPv6 address is in the AddressRegistration TLV.
    note: only compare the iid part of the address.

        Args:
            command_msg (MleMessage) : The Mle message to check.
            cid (int): The context id of the domain prefix.
            iid (string): The Interface Identifier. It is compared with ``item.iid.hex()``,
                          so it must be a lowercase hex string.
            cid_present_once(boolean): True if the cid entry should appear exactly once in the
                                       AR TLV; False if it should appear more than once.
    '''
    addresses = command_msg.assertMleMessageContainsTlv(mle.AddressRegistration).addresses

    # Collect *every* compressed entry using this cid. Stopping at the entry whose IID matches
    # would undercount whenever the same cid also appears in a later entry, making the
    # cid_present_once check depend on entry order.
    entries_with_cid = [item for item in addresses if isinstance(item, mle.AddressCompressed) and cid == item.cid]
    cid_cnt = len(entries_with_cid)

    found = any(iid == item.iid.hex() for item in entries_with_cid)
    assert found, 'Error: Expected (cid, iid):({},{}) Not Found'.format(cid, iid)

    assert cid_present_once == (cid_cnt == 1), 'Error: Expected cid present {} but present {}'.format(
        'once' if cid_present_once else 'more than once', cid_cnt)
581
582
def assert_contains_tlv(tlvs, check_type, tlv_type):
    """Assert a tlv list contains specific tlv and return the first qualified.

    Args:
        tlvs: Iterable of TLV objects.
        check_type: A CheckType value (an int equal to a member is treated the same way).
        tlv_type: TLV class (or tuple of classes) passed to ``isinstance``.

    Returns:
        The first TLV of ``tlv_type`` for CONTAIN; None for NOT_CONTAIN and OPTIONAL.

    Raises:
        AssertionError: if CONTAIN finds no match or NOT_CONTAIN finds one.
        ValueError: if check_type is not one of the CheckType values.
    """
    matches = [tlv for tlv in tlvs if isinstance(tlv, tlv_type)]
    type_name = getattr(tlv_type, '__name__', tlv_type)
    # Compare with == (not ``is``), like every other checker in this module, so an int equal
    # to a CheckType member is accepted here too instead of raising ValueError.
    if check_type == CheckType.CONTAIN:
        assert matches, 'Expected a {} TLV but none was found'.format(type_name)
        return matches[0]
    if check_type == CheckType.NOT_CONTAIN:
        assert not matches, 'Unexpected {} TLV found: {}'.format(type_name, matches)
        return None
    if check_type == CheckType.OPTIONAL:
        return None
    raise ValueError("Invalid check type: {}".format(check_type))
597
598
def check_discovery_request(command_msg, thread_version: str = None):
    """Verify a properly formatted Thread Discovery Request command message.

    The MLE message must be unsecured and its Thread Discovery TLV must hold a Discovery Request
    sub-TLV. When ``thread_version`` is '1.1' or '1.2', the request's version must match the
    corresponding config constant; any other truthy value fails an assertion.
    """
    assert not isinstance(command_msg.mle, mle.MleMessageSecured)
    discovery_tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
    discovery_request = assert_contains_tlv(discovery_tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryRequest)

    assert not thread_version or thread_version in ['1.1', '1.2']
    if thread_version == '1.1':
        assert discovery_request.version == config.THREAD_VERSION_1_1
    elif thread_version == '1.2':
        assert discovery_request.version == config.THREAD_VERSION_1_2
610
611
def check_discovery_response(command_msg,
                             request_src_addr,
                             steering_data=CheckType.OPTIONAL,
                             thread_version: str = None):
    """Verify a properly formatted Thread Discovery Response command message.

    The message must be unsecured MLE, sent from a long (extended) MAC address to
    ``request_src_addr``. The Thread Discovery TLV must hold a Discovery Response plus Extended
    PAN ID and Network Name sub-TLVs. Steering Data and Joiner UDP Port both follow
    ``steering_data``. Commissioner UDP Port is required when the response's native flag is set.
    """
    assert not isinstance(command_msg.mle, mle.MleMessageSecured)
    mac_header = command_msg.mac_header
    assert (mac_header.src_address.type == common.MacAddressType.LONG)
    assert mac_header.dest_address == request_src_addr

    discovery_tlvs = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery).tlvs
    discovery_response = assert_contains_tlv(discovery_tlvs, CheckType.CONTAIN, mesh_cop.DiscoveryResponse)
    assert not thread_version or thread_version in ['1.1', '1.2']
    if thread_version == '1.1':
        assert discovery_response.version == config.THREAD_VERSION_1_1
    elif thread_version == '1.2':
        assert discovery_response.version == config.THREAD_VERSION_1_2

    for mandatory_tlv in (mesh_cop.ExtendedPanid, mesh_cop.NetworkName):
        assert_contains_tlv(discovery_tlvs, CheckType.CONTAIN, mandatory_tlv)
    for steering_related_tlv in (mesh_cop.SteeringData, mesh_cop.JoinerUdpPort):
        assert_contains_tlv(discovery_tlvs, steering_data, steering_related_tlv)

    commissioner_port_check = CheckType.CONTAIN if discovery_response.native_flag else CheckType.OPTIONAL
    assert_contains_tlv(discovery_tlvs, commissioner_port_check, mesh_cop.CommissionerUdpPort)
636
637
def get_joiner_udp_port_in_discovery_response(command_msg):
    """Return the ``udp_port`` of the Joiner UDP Port TLV in a Discovery Response (asserts it exists)."""
    discovery = command_msg.assertMleMessageContainsTlv(mle.ThreadDiscovery)
    joiner_port_tlv = assert_contains_tlv(discovery.tlvs, CheckType.CONTAIN, mesh_cop.JoinerUdpPort)
    return joiner_port_tlv.udp_port
644
645
def check_joiner_commissioning_messages(commissioning_messages, url=''):
    """Verify COAP messages sent by joiner while commissioning process.

    The list is printed, must hold at least four messages, index 0 must be a JOIN_FIN.req and
    index 3 a JOIN_ENT.rsp. With a non-empty ``url`` the JOIN_FIN.req must carry a Provisioning
    URL TLV equal to it; otherwise that TLV must be absent.
    """
    print(commissioning_messages)
    assert len(commissioning_messages) >= 4

    join_fin_req = commissioning_messages[0]
    assert join_fin_req.type == mesh_cop.MeshCopMessageType.JOIN_FIN_REQ
    if not url:
        assert_contains_tlv(join_fin_req.tlvs, CheckType.NOT_CONTAIN, mesh_cop.ProvisioningUrl)
    else:
        url_tlv = assert_contains_tlv(join_fin_req.tlvs, CheckType.CONTAIN, mesh_cop.ProvisioningUrl)
        assert url == url_tlv.url

    assert commissioning_messages[3].type == mesh_cop.MeshCopMessageType.JOIN_ENT_RSP
661
662
def check_commissioner_commissioning_messages(commissioning_messages, state=mesh_cop.MeshCopState.ACCEPT):
    """Verify COAP messages sent by commissioner while commissioning process.

    Index 1 must be a JOIN_FIN.rsp whose State TLV equals ``state``.
    """
    assert len(commissioning_messages) >= 2
    join_fin_rsp = commissioning_messages[1]
    assert join_fin_rsp.type == mesh_cop.MeshCopMessageType.JOIN_FIN_RSP
    state_tlv = assert_contains_tlv(join_fin_rsp.tlvs, CheckType.CONTAIN, mesh_cop.State)
    assert state_tlv.state == state
671
672
def check_joiner_router_commissioning_messages(commissioning_messages):
    """Verify COAP messages sent by joiner router while commissioning process.

    The JOIN_ENT.ntf is expected at index 2 when at least four messages were captured, and at
    index 0 otherwise. Returns None.
    """
    ntf_index = 2 if len(commissioning_messages) >= 4 else 0
    join_ent_ntf = commissioning_messages[ntf_index]
    assert join_ent_ntf.type == mesh_cop.MeshCopMessageType.JOIN_ENT_NTF
682
683
def check_payload_same(tp1, tp2):
    """Verify two payloads are the same.

    A payload is a tuple of tlvs. Both must have the same length, and for each TLV in ``tp2``
    the first TLV of that type in ``tp1`` (found with ``get_sub_tlv``) must compare equal to it.
    """
    assert len(tp1) == len(tp2)
    for expected_tlv in tp2:
        tlv_class = type(expected_tlv)
        peer_tlv = get_sub_tlv(tp1, tlv_class)
        assert (peer_tlv is not None and
                peer_tlv == expected_tlv), 'peer_tlv:{}, tlv:{} type:{}'.format(peer_tlv, expected_tlv, tlv_class)
693
694
def check_coap_message(msg, payloads, dest_addrs=None):
    """Assert msg's CoAP payload equals ``payloads`` (see ``check_payload_same``).

    When ``dest_addrs`` is given, msg's IPv6 destination must equal one of its entries.
    """
    if dest_addrs is not None:
        actual_destination = msg.ipv6_packet.ipv6_header.destination_address
        matched = any(actual_destination == candidate for candidate in dest_addrs)
        assert matched, 'Destination address incorrect'
    check_payload_same(msg.coap.payload, payloads)
704
705
class SinglePrefixCheck:
    """Matcher for a single Prefix TLV in network data.

    ``check`` returns True when a Prefix TLV satisfies every criterion that was supplied;
    criteria left as None are not checked.
    """

    def __init__(self, prefix=None, border_router_16=None):
        # prefix: expected prefix as a hex string, given as str or bytes (case-insensitive); it
        #         is compared against binascii.hexlify() of the TLV's ``prefix`` field.
        # border_router_16: expected ``border_router_16`` of the Border Router sub-TLV.
        self._prefix = prefix
        self._border_router_16 = border_router_16

    def __repr__(self):
        # Used by PrefixesCheck's "Some prefix is absent" message.
        return '{}(prefix={!r}, border_router_16={!r})'.format(type(self).__name__, self._prefix,
                                                                self._border_router_16)

    def check(self, prefix_tlv):
        """Return True if ``prefix_tlv`` matches this check, False otherwise.

        A TLV whose prefix differs is rejected before its sub-TLVs are inspected, so a search
        over several prefixes is not aborted by an unrelated prefix (for example one with no
        Border Router sub-TLV). For a TLV whose prefix matches (or when no prefix is configured),
        the Border Router and LowpanId sub-TLVs are required and AssertionError is raised if
        either is missing.
        """
        if self._prefix is not None:
            # hexlify() yields lowercase ASCII bytes; normalise the expectation to match so a
            # str or upper-case hex prefix does not silently never match.
            expected = self._prefix.encode() if isinstance(self._prefix, str) else self._prefix
            if expected.lower() != binascii.hexlify(prefix_tlv.prefix):
                return False

        border_router_tlv = assert_contains_tlv(prefix_tlv.sub_tlvs, CheckType.CONTAIN, network_data.BorderRouter)
        assert_contains_tlv(prefix_tlv.sub_tlvs, CheckType.CONTAIN, network_data.LowpanId)

        if self._border_router_16 is not None:
            return self._border_router_16 == border_router_tlv.border_router_16
        return True
721
722
class PrefixesCheck:
    """Checker for the list of Prefix TLVs in network data.

    With a positive ``prefix_cnt`` only the count is checked: at least that many Prefix TLVs
    must exist. Otherwise each matcher in ``prefix_check_list`` must accept at least one TLV.
    """

    def __init__(self, prefix_cnt=0, prefix_check_list=()):
        # Minimum number of Prefix TLVs; 0 (or less) switches to matcher mode.
        self._prefix_cnt = prefix_cnt
        # Objects exposing ``check(prefix_tlv) -> bool`` (e.g. SinglePrefixCheck).
        self._prefix_check_list = prefix_check_list

    def check(self, prefix_tlvs):
        """Assert ``prefix_tlvs`` satisfies the configured count or matchers."""
        if self._prefix_cnt > 0:
            assert (len(prefix_tlvs) >= self._prefix_cnt), 'prefix count is less than expected'
            return

        for matcher in self._prefix_check_list:
            matched = any(matcher.check(prefix_tlv) for prefix_tlv in prefix_tlvs)
            assert matched, 'Some prefix is absent: {}'.format(matcher)
741
742
class CommissioningDataCheck:
    """Checker for a Commissioning Data TLV in network data."""

    def __init__(self, stable=None, sub_tlv_type_list=()):
        # Expected value of the TLV's ``stable`` attribute, or None to skip that check.
        self._stable = stable
        # Sub-TLV classes that must all be present.
        self._sub_tlv_type_list = sub_tlv_type_list

    def check(self, commissioning_data_tlv):
        """Assert the stable flag (when configured) and presence of every required sub-TLV type."""
        if self._stable is not None:
            stable_matches = self._stable == commissioning_data_tlv.stable
            assert stable_matches, 'Commissioning Data stable flag is not correct'
        has_all_sub_tlvs = contains_tlvs(commissioning_data_tlv.sub_tlvs, self._sub_tlv_type_list)
        assert has_all_sub_tlvs, 'Some sub tlvs are missing in Commissioning Data'
754
755
class NetworkDataCheck:
    """Composite checker for a Network Data TLV."""

    def __init__(self, prefixes_check=None, commissioning_data_check=None):
        # Receives the list of Prefix TLVs (e.g. a PrefixesCheck), or None to skip.
        self._prefixes_check = prefixes_check
        # Receives the first Commissioning Data TLV (e.g. a CommissioningDataCheck), or None to skip.
        self._commissioning_data_check = commissioning_data_check

    def check(self, network_data_tlv):
        """Run the configured sub-checks against ``network_data_tlv.tlvs``.

        When a commissioning-data check is configured, the Commissioning Data TLV must exist.
        """
        if self._prefixes_check is not None:
            prefixes = [tlv for tlv in network_data_tlv.tlvs if isinstance(tlv, network_data.Prefix)]
            self._prefixes_check.check(prefixes)

        if self._commissioning_data_check is None:
            return
        commissioning_data = assert_contains_tlv(
            network_data_tlv.tlvs,
            CheckType.CONTAIN,
            network_data.CommissioningData,
        )
        self._commissioning_data_check.check(commissioning_data)
773