1#!/usr/bin/env python3
2#
3#  Copyright (c) 2019, The OpenThread Authors.
4#  All rights reserved.
5#
6#  Redistribution and use in source and binary forms, with or without
7#  modification, are permitted provided that the following conditions are met:
8#  1. Redistributions of source code must retain the above copyright
9#     notice, this list of conditions and the following disclaimer.
10#  2. Redistributions in binary form must reproduce the above copyright
11#     notice, this list of conditions and the following disclaimer in the
12#     documentation and/or other materials provided with the distribution.
13#  3. Neither the name of the copyright holder nor the
14#     names of its contributors may be used to endorse or promote products
15#     derived from this software without specific prior written permission.
16#
17#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27#  POSSIBILITY OF SUCH DAMAGE.
28#
29
30from enum import IntEnum
31from functools import reduce
32import io
33import struct
34
35from ipv6 import BuildableFromBytes
36from ipv6 import ConvertibleToBytes
37
38
class HandshakeType(IntEnum):
    """Numeric codes carried in the first byte of a handshake message header."""

    # Hello exchange
    HELLO_REQUEST = 0
    CLIENT_HELLO = 1
    SERVER_HELLO = 2
    HELLO_VERIFY_REQUEST = 3

    # Certificate and key exchange
    CERTIFICATE = 11
    SERVER_KEY_EXCHANGE = 12
    CERTIFICATE_REQUEST = 13
    SERVER_HELLO_DONE = 14
    CERTIFICATE_VERIFY = 15
    CLIENT_KEY_EXCHANGE = 16

    # Handshake completion
    FINISHED = 20
51
52
class ContentType(IntEnum):
    """Numeric codes carried in the first byte of a record header.

    The value selects which kind of message the record fragment contains.
    """

    CHANGE_CIPHER_SPEC = 20
    ALERT = 21
    HANDSHAKE = 22
    APPLICATION_DATA = 23
58
59
class AlertLevel(IntEnum):
    """Severity code stored in the first byte of an alert message."""

    WARNING = 1
    FATAL = 2
63
64
class AlertDescription(IntEnum):
    """Reason code stored in the second byte of an alert message."""

    # Connection and record-layer conditions
    CLOSE_NOTIFY = 0
    UNEXPECTED_MESSAGE = 10
    BAD_RECORD_MAC = 20
    DECRYPTION_FAILED_RESERVED = 21
    RECORD_OVERFLOW = 22
    DECOMPRESSION_FAILURE = 30

    # Handshake and certificate conditions
    HANDSHAKE_FAILURE = 40
    NO_CERTIFICATE_RESERVED = 41
    BAD_CERTIFICATE = 42
    UNSUPPORTED_CERTIFICATE = 43
    CERTIFICATE_REVOKED = 44
    CERTIFICATE_EXPIRED = 45
    CERTIFICATE_UNKNOWN = 46
    ILLEGAL_PARAMETER = 47
    UNKNOWN_CA = 48
    ACCESS_DENIED = 49
    DECODE_ERROR = 50
    DECRYPT_ERROR = 51

    # Policy, version and miscellaneous conditions
    EXPORT_RESTRICTION_RESERVED = 60
    PROTOCOL_VERSION = 70
    INSUFFICIENT_SECURITY = 71
    INTERNAL_ERROR = 80
    USER_CANCELED = 90
    NO_RENEGOTIATION = 100
    UNSUPPORTED_EXTENSION = 110
91
92
class Record(ConvertibleToBytes, BuildableFromBytes):
    """One record: a fixed-size header followed by an opaque fragment.

    Wire layout (big-endian): 1-byte content type, 2-byte protocol version,
    2-byte epoch, 6-byte sequence number, 2-byte fragment length, fragment.
    """

    def __init__(self, content_type, version, epoch, sequence_number, length, fragment):
        # Header fields, in wire order.
        self.content_type = content_type
        self.version = version
        self.epoch = epoch
        self.sequence_number = sequence_number
        self.length = length
        # Raw payload; interpreting it is left to the message classes.
        self.fragment = fragment

    def to_bytes(self):
        """Return the serialized header immediately followed by the fragment."""
        pieces = [
            struct.pack(">B", self.content_type),
            self.version.to_bytes(),
            struct.pack(">H", self.epoch),
            self.sequence_number.to_bytes(6, byteorder='big'),
            struct.pack(">H", self.length),
            self.fragment,
        ]
        return b"".join(pieces)

    @classmethod
    def from_bytes(cls, data):
        """Read one record from the stream `data`, advancing past its fragment."""
        (raw_content_type,) = struct.unpack(">B", data.read(1))
        content_type = ContentType(raw_content_type)
        version = ProtocolVersion.from_bytes(data)
        (epoch,) = struct.unpack(">H", data.read(2))
        # The 6-byte sequence number is left-padded to 8 bytes so it can be
        # unpacked as a single unsigned 64-bit integer.
        (sequence_number,) = struct.unpack(">Q", b'\x00\x00' + data.read(6))
        (length,) = struct.unpack(">H", data.read(2))
        fragment = bytes(data.read(length))
        return cls(content_type, version, epoch, sequence_number, length, fragment)

    def __repr__(self):
        template = "Record(content_type={}, version={}, epoch={}, sequence_number={}, length={})"
        return template.format(
            str(self.content_type),
            self.version,
            self.epoch,
            self.sequence_number,
            self.length,
        )
125
126
class Message(ConvertibleToBytes, BuildableFromBytes):
    """Base class for record payloads; remembers the record content type."""

    def __init__(self, content_type):
        self.content_type = content_type

    def to_bytes(self):
        """Serialize the message. Subclasses are expected to override this."""
        raise NotImplementedError()

    @classmethod
    def from_bytes(cls, data):
        """Parse a message from `data`. Subclasses are expected to override this."""
        raise NotImplementedError()
138
139
class HandshakeMessage(Message):
    """A handshake message: a 12-byte header followed by a type-specific body.

    Header layout (big-endian): 1-byte handshake type, 3-byte length, 2-byte
    message sequence, 3-byte fragment offset, 3-byte fragment length.

    `body` is an instance of the class registered for the type in
    `handshake_map`, or raw `bytes` when no class is registered.
    """

    def __init__(
        self,
        handshake_type,
        length,
        message_seq,
        fragment_offset,
        fragment_length,
        body,
    ):
        super(HandshakeMessage, self).__init__(ContentType.HANDSHAKE)
        self.handshake_type = handshake_type
        self.length = length
        self.message_seq = message_seq
        self.fragment_offset = fragment_offset
        self.fragment_length = fragment_length
        self.body = body

    def to_bytes(self):
        """Return the serialized header followed by the serialized body."""
        body = self.body
        # from_bytes() keeps the body of unhandled message types as raw bytes,
        # which have no to_bytes() method; emit them unchanged.
        if isinstance(body, (bytes, bytearray)):
            body_bytes = bytes(body)
        else:
            body_bytes = body.to_bytes()

        # 24-bit fields are packed as 32-bit integers with the top byte dropped.
        return (struct.pack(">B", self.handshake_type) + struct.pack(">I", self.length)[1:] +
                struct.pack(">H", self.message_seq) + struct.pack(">I", self.fragment_offset)[1:] +
                struct.pack(">I", self.fragment_length)[1:] + body_bytes)

    @classmethod
    def from_bytes(cls, data):
        """Parse one handshake message from the stream `data`.

        Raises:
            AssertionError: if the stream holds fewer than `fragment_length`
                body bytes, or the body parser does not consume all of them.
        """
        handshake_type = HandshakeType(struct.unpack(">B", data.read(1))[0])
        # 24-bit fields are left-padded to 32 bits before unpacking.
        length = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
        message_seq = struct.unpack(">H", data.read(2))[0]
        fragment_offset = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
        fragment_length = struct.unpack(">I", b'\x00' + bytes(data.read(3)))[0]
        end_position = data.tell() + fragment_length
        # TODO(wgtdkp): handle fragmentation

        # Give the body parser a stream that ends exactly where this message
        # ends. Body parsers detect optional trailing fields by checking for
        # leftover bytes, so sharing the outer stream would make them swallow
        # the next handshake message packed into the same record.
        body_data = io.BytesIO(bytes(data.read(fragment_length)))

        message_class = handshake_map[handshake_type]
        if message_class:
            body = message_class.from_bytes(body_data)
        else:
            print("{} messages are not handled".format(str(handshake_type)))
            body = body_data.read()

        # The whole fragment must have been available and fully consumed.
        assert data.tell() == end_position
        assert body_data.tell() == len(body_data.getvalue())

        return cls(
            handshake_type,
            length,
            message_seq,
            fragment_offset,
            fragment_length,
            body,
        )

    def __repr__(self):
        return "Handshake(type={}, length={})".format(str(self.handshake_type), self.length)
193
194
class ProtocolVersion(ConvertibleToBytes, BuildableFromBytes):
    """Protocol version encoded as two single-byte fields: major, then minor."""

    def __init__(self, major, minor):
        self.major = major
        self.minor = minor

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return False
        return self.major == other.major and self.minor == other.minor

    def to_bytes(self):
        """Return the two-byte encoding of this version."""
        return struct.pack(">BB", self.major, self.minor)

    @classmethod
    def from_bytes(cls, data):
        """Read two bytes from the stream `data` and build a version from them."""
        return cls(*struct.unpack(">BB", data.read(2)))

    def __repr__(self):
        template = "ProtocolVersion(major={}, minor={})"
        return template.format(self.major, self.minor)
214
215
class Random(ConvertibleToBytes, BuildableFromBytes):
    """Hello `random` field: a 4-byte timestamp followed by 28 opaque bytes."""

    # Number of bytes that follow the 4-byte timestamp.
    random_bytes_length = 28

    def __init__(self, gmt_unix_time, random_bytes):
        self.gmt_unix_time = gmt_unix_time
        self.random_bytes = random_bytes
        assert len(self.random_bytes) == Random.random_bytes_length

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return False
        return (self.gmt_unix_time == other.gmt_unix_time and self.random_bytes == other.random_bytes)

    def to_bytes(self):
        """Return the big-endian timestamp followed by the random bytes."""
        timestamp = struct.pack(">I", self.gmt_unix_time)
        return timestamp + self.random_bytes

    @classmethod
    def from_bytes(cls, data):
        """Read the timestamp and the random bytes from the stream `data`."""
        (gmt_unix_time,) = struct.unpack(">I", data.read(4))
        payload = bytes(data.read(cls.random_bytes_length))
        return cls(gmt_unix_time, payload)
237
238
class VariableVector(ConvertibleToBytes):
    """Length-prefixed sequence of elements of a single class.

    On the wire the vector is a big-endian byte count followed by the
    concatenated encodings of its elements. The width of the count is the
    number of bytes needed to represent the upper bound `subrange[1]`.
    """

    def __init__(self, subrange, ele_cls, elements):
        """
        Args:
            subrange: (floor, ceiling) pair bounding the vector.
            ele_cls: class used to parse each element.
            elements: list of already-built element objects.
        """
        self.subrange = subrange
        self.ele_cls = ele_cls
        self.elements = elements
        # NOTE(review): this bounds the element count, whereas the length
        # prefix read and written below is a byte count; confirm which unit
        # callers intend for `subrange`.
        assert self.subrange[0] <= len(self.elements) <= self.subrange[1]

    def length(self):
        """Return the number of elements (not the encoded size in bytes)."""
        return len(self.elements)

    def __eq__(self, other):
        return (isinstance(self, type(other)) and self.subrange == other.subrange and self.ele_cls == other.ele_cls and
                self.elements == other.elements)

    def to_bytes(self):
        """Return the length prefix followed by every element's encoding.

        An empty vector encodes to a zero-valued length prefix.
        """
        data = b"".join(ele.to_bytes() for ele in self.elements)
        return VariableVector._encode_length(len(data), self.subrange) + data

    @classmethod
    def from_bytes(cls, ele_cls, subrange, data):
        """Parse a vector of `ele_cls` elements from the stream `data`."""
        length = cls._decode_length(subrange, data)
        end_position = data.tell() + length
        elements = []
        # Keep parsing elements until the announced byte count is consumed.
        while data.tell() < end_position:
            elements.append(ele_cls.from_bytes(data))
        return cls(subrange, ele_cls, elements)

    @classmethod
    def _decode_length(cls, subrange, data):
        """Read the big-endian length prefix from `data` and return it as an int."""
        length_in_byte = cls._calc_length_in_byte(subrange[1])
        return int.from_bytes(bytes(data.read(length_in_byte)), byteorder='big')

    @classmethod
    def _encode_length(cls, length, subrange):
        """Return `length` as a big-endian prefix sized for `subrange[1]`.

        Raises:
            OverflowError: if `length` does not fit in the prefix width.
        """
        length_in_byte = cls._calc_length_in_byte(subrange[1])
        return length.to_bytes(length_in_byte, byteorder='big')

    @classmethod
    def _calc_length_in_byte(cls, ceiling):
        """Return how many whole bytes are needed to represent `ceiling`."""
        return (ceiling.bit_length() + 7) // 8
288
289
class Opaque(ConvertibleToBytes, BuildableFromBytes):
    """A single uninterpreted byte, stored as an integer."""

    def __init__(self, byte):
        self.byte = byte

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return False
        return self.byte == other.byte

    def to_bytes(self):
        """Return the one-byte encoding of the stored value."""
        return struct.pack(">B", self.byte)

    @classmethod
    def from_bytes(cls, data):
        """Read one byte from the stream `data`."""
        (value,) = struct.unpack(">B", data.read(1))
        return cls(value)
304
305
class CipherSuite(ConvertibleToBytes, BuildableFromBytes):
    """Cipher suite identifier held as a pair of byte values in `cipher`."""

    def __init__(self, cipher):
        self.cipher = cipher

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return False
        return self.cipher == other.cipher

    def to_bytes(self):
        """Return the two identifier bytes."""
        first, second = self.cipher[0], self.cipher[1]
        return struct.pack(">BB", first, second)

    @classmethod
    def from_bytes(cls, data):
        """Read two bytes from the stream `data`; `cipher` becomes a 2-tuple."""
        identifier = struct.unpack(">BB", data.read(2))
        return cls(identifier)

    def __repr__(self):
        first, second = self.cipher[0], self.cipher[1]
        return "CipherSuite({}, {})".format(first, second)
323
324
class CompressionMethod(ConvertibleToBytes, BuildableFromBytes):
    """Compression method field; only the NULL method (0) is accepted."""

    NULL = 0

    def __init__(self):
        pass

    def __eq__(self, other):
        # Instances carry no state, so the type check alone decides equality.
        return isinstance(self, type(other))

    def to_bytes(self):
        """Return the one-byte encoding of the NULL method."""
        return struct.pack(">B", CompressionMethod.NULL)

    @classmethod
    def from_bytes(cls, data):
        """Read one byte from the stream `data` and require it to be NULL."""
        (method,) = struct.unpack(">B", data.read(1))
        assert method == cls.NULL
        return cls()
343
344
class Extension(ConvertibleToBytes, BuildableFromBytes):
    """Hello extension: a 2-byte type followed by length-prefixed opaque data."""

    def __init__(self, extension_type, extension_data):
        self.extension_type = extension_type
        self.extension_data = extension_data

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return False
        return (self.extension_type == other.extension_type and self.extension_data == other.extension_data)

    def to_bytes(self):
        """Return the big-endian type followed by the encoded extension data."""
        type_field = struct.pack(">H", self.extension_type)
        return type_field + self.extension_data.to_bytes()

    @classmethod
    def from_bytes(cls, data):
        """Read one extension from the stream `data`."""
        (extension_type,) = struct.unpack(">H", data.read(2))
        # The payload is kept as a vector of single opaque bytes.
        payload = VariableVector.from_bytes(Opaque, (0, 2**16 - 1), data)
        return cls(extension_type, payload)
363
364
class ClientHello(HandshakeMessage):
    """Body of a CLIENT_HELLO handshake message.

    Fields, in wire order: client version, random, session id, cookie,
    cipher suites, compression methods and an optional extensions vector.
    `extensions` is None when the message carries no extensions.

    The constructor does not call the parent initializer, so the handshake
    header attributes are not set on instances of this class.
    """

    def __init__(
        self,
        client_version,
        random,
        session_id,
        cookie,
        cipher_suites,
        compression_methods,
        extensions,
    ):
        self.client_version = client_version
        self.random = random
        self.session_id = session_id
        self.cookie = cookie
        self.cipher_suites = cipher_suites
        self.compression_methods = compression_methods
        self.extensions = extensions

    def to_bytes(self):
        """Return the concatenated encodings of all fields that are present."""
        parts = [
            self.client_version.to_bytes(),
            self.random.to_bytes(),
            self.session_id.to_bytes(),
            self.cookie.to_bytes(),
            self.cipher_suites.to_bytes(),
            self.compression_methods.to_bytes(),
        ]
        # from_bytes() leaves `extensions` as None when the block is absent;
        # in that case nothing is emitted for it.
        if self.extensions is not None:
            parts.append(self.extensions.to_bytes())
        return b"".join(parts)

    @classmethod
    def from_bytes(cls, data):
        """Parse a ClientHello body from the stream `data`."""
        client_version = ProtocolVersion.from_bytes(data)
        random = Random.from_bytes(data)
        session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
        cookie = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
        cipher_suites = VariableVector.from_bytes(CipherSuite, (2, 2**16 - 1), data)
        compression_methods = VariableVector.from_bytes(CompressionMethod, (1, 2**8 - 1), data)
        extensions = None
        # NOTE(review): any bytes left in `data` are treated as the extensions
        # block, so this assumes the ClientHello is the last item in the stream.
        if data.tell() < len(data.getvalue()):
            extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
        return cls(
            client_version,
            random,
            session_id,
            cookie,
            cipher_suites,
            compression_methods,
            extensions,
        )
410
411
class HelloVerifyRequest(HandshakeMessage):
    """Body of a HELLO_VERIFY_REQUEST: server version followed by a cookie."""

    def __init__(self, server_version, cookie):
        self.server_version = server_version
        self.cookie = cookie

    def to_bytes(self):
        """Return the encoded version followed by the encoded cookie."""
        version_bytes = self.server_version.to_bytes()
        return version_bytes + self.cookie.to_bytes()

    @classmethod
    def from_bytes(cls, data):
        """Parse the version and the cookie vector from the stream `data`."""
        version = ProtocolVersion.from_bytes(data)
        cookie_vector = VariableVector.from_bytes(Opaque, (0, 2**8 - 1), data)
        return cls(version, cookie_vector)
426
427
class ServerHello(HandshakeMessage):
    """Body of a SERVER_HELLO handshake message.

    Fields, in wire order: server version, random, session id, selected
    cipher suite, selected compression method and an optional extensions
    vector. `extensions` is None when the message carries no extensions.

    The constructor does not call the parent initializer, so the handshake
    header attributes are not set on instances of this class.
    """

    def __init__(
        self,
        server_version,
        random,
        session_id,
        cipher_suite,
        compression_method,
        extensions,
    ):
        self.server_version = server_version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.compression_method = compression_method
        self.extensions = extensions

    def to_bytes(self):
        """Return the concatenated encodings of all fields that are present."""
        parts = [
            self.server_version.to_bytes(),
            self.random.to_bytes(),
            self.session_id.to_bytes(),
            self.cipher_suite.to_bytes(),
            self.compression_method.to_bytes(),
        ]
        # from_bytes() leaves `extensions` as None when the block is absent;
        # in that case nothing is emitted for it.
        if self.extensions is not None:
            parts.append(self.extensions.to_bytes())
        return b"".join(parts)

    @classmethod
    def from_bytes(cls, data):
        """Parse a ServerHello body from the stream `data`."""
        server_version = ProtocolVersion.from_bytes(data)
        random = Random.from_bytes(data)
        session_id = VariableVector.from_bytes(Opaque, (0, 32), data)
        cipher_suite = CipherSuite.from_bytes(data)
        compression_method = CompressionMethod.from_bytes(data)
        extensions = None
        # NOTE(review): any bytes left in `data` are treated as the extensions
        # block, so this assumes the ServerHello is the last item in the stream.
        if data.tell() < len(data.getvalue()):
            extensions = VariableVector.from_bytes(Extension, (0, 2**16 - 1), data)
        return cls(
            server_version,
            random,
            session_id,
            cipher_suite,
            compression_method,
            extensions,
        )
468
469
class ServerHelloDone(HandshakeMessage):
    """Body of a SERVER_HELLO_DONE message, which carries no fields."""

    def __init__(self):
        pass

    def to_bytes(self):
        """Return an empty byte array; there is nothing to encode."""
        return bytearray()

    @classmethod
    def from_bytes(cls, data):
        """Build an instance without reading anything from `data`."""
        return cls()
481
482
class HelloRequest(HandshakeMessage):
    """Placeholder for the HelloRequest body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
487
488
class Certificate(HandshakeMessage):
    """Placeholder for the Certificate body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
493
494
class ServerKeyExchange(HandshakeMessage):
    """Placeholder for the ServerKeyExchange body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
499
500
class CertificateRequest(HandshakeMessage):
    """Placeholder for the CertificateRequest body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
505
506
class CertificateVerify(HandshakeMessage):
    """Placeholder for the CertificateVerify body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
511
512
class ClientKeyExchange(HandshakeMessage):
    """Placeholder for the ClientKeyExchange body; constructing it is not supported."""

    def __init__(self):
        raise NotImplementedError()
517
518
class Finished(HandshakeMessage):
    """Placeholder for the Finished body; constructing it is not supported."""

    def __init__(self, verify_data):
        # `verify_data` is accepted but unused until this class is implemented.
        raise NotImplementedError()
523
524
class AlertMessage(Message):
    """Alert message: a one-byte level followed by a one-byte description.

    `level` and `description` are both None when the two bytes read from the
    wire are not valid enum values (see `from_bytes`).
    """

    def __init__(self, level, description):
        super(AlertMessage, self).__init__(ContentType.ALERT)
        self.level = level
        self.description = description

    def to_bytes(self):
        """Return the two-byte encoding of the level and the description."""
        return struct.pack(">BB", self.level, self.description)

    @classmethod
    def from_bytes(cls, data):
        """Parse an alert from the stream `data`.

        If the level or description byte is not a known value, the rest of
        `data` is consumed and an alert with both fields set to None is
        returned.
        """
        level, description = struct.unpack(">BB", data.read(2))
        try:
            return cls(AlertLevel(level), AlertDescription(description))
        except ValueError:
            # The enum lookups raise ValueError for unknown codes. The alert
            # may be encrypted, in which case it cannot be parsed; skip what
            # is left so the caller's parse loop can move on.
            data.read()
            return cls(None, None)

    def __repr__(self):
        return "Alert(level={}, description={})".format(str(self.level), str(self.description))
547
548
class ChangeCipherSpecMessage(Message):
    """ChangeCipherSpec message: a single byte whose value is always 1."""

    def __init__(self):
        super(ChangeCipherSpecMessage, self).__init__(ContentType.CHANGE_CIPHER_SPEC)

    def to_bytes(self):
        """Return the one-byte encoding of the message."""
        return struct.pack(">B", 1)

    @classmethod
    def from_bytes(cls, data):
        """Consume the single message byte from the stream `data`.

        Raises:
            AssertionError: if the byte is not 1 (when assertions are enabled).
        """
        # Read outside the assert: assert statements are removed under
        # `python -O`, and the byte must be consumed regardless so the stream
        # position advances.
        value = struct.unpack(">B", data.read(1))[0]
        assert value == 1
        return cls()

    def __repr__(self):
        return "ChangeCipherSpec(value=1)"
564
565
class ApplicationDataMessage(Message):
    """Application data payload kept as raw bytes.

    `body` starts as None; nothing in this class assigns it afterwards.
    """

    def __init__(self, raw):
        super(ApplicationDataMessage, self).__init__(ContentType.APPLICATION_DATA)
        self.raw = raw
        self.body = None

    def to_bytes(self):
        """Return the raw payload unchanged."""
        return self.raw

    @classmethod
    def from_bytes(cls, data):
        """Consume everything left in the stream `data` as the payload."""
        # Reading to the end of the stream is safe because a record holds
        # exactly one application data message.
        remaining = len(data.getvalue()) - data.tell()
        payload = bytes(data.read(remaining))
        return cls(payload)

    def __repr__(self):
        if not self.body:
            return "ApplicationData(raw_length={})".format(len(self.raw))
        return "ApplicationData(body={})".format(self.body)
588
589
# Body parser for each handshake type. Every type starts out mapped to None,
# which HandshakeMessage.from_bytes treats as "not handled" and keeps as raw
# bytes; only the types with a working parser are registered below.
# Still unparsed: HelloRequest, Certificate, ServerKeyExchange,
# CertificateRequest, CertificateVerify, ClientKeyExchange and Finished.
handshake_map = dict.fromkeys(HandshakeType)
handshake_map.update({
    HandshakeType.CLIENT_HELLO: ClientHello,
    HandshakeType.SERVER_HELLO: ServerHello,
    HandshakeType.HELLO_VERIFY_REQUEST: HelloVerifyRequest,
    HandshakeType.SERVER_HELLO_DONE: ServerHelloDone,
})

# Message class used to parse the fragment of a record, keyed by the
# record's content type.
content_map = {
    ContentType.CHANGE_CIPHER_SPEC: ChangeCipherSpecMessage,
    ContentType.ALERT: AlertMessage,
    ContentType.HANDSHAKE: HandshakeMessage,
    ContentType.APPLICATION_DATA: ApplicationDataMessage,
}
610
611
class MessageFactory(object):
    """Parses a datagram payload into the list of messages it contains."""

    # Whether the previously seen record was a ChangeCipherSpec. Stored on the
    # class, so the state is shared by every instance.
    last_msg_is_change_cipher_spec = False

    def __init__(self):
        pass

    def parse(self, data, message_info):
        """Parse every record in the stream `data` and return their messages.

        Args:
            data: byte stream positioned at the first record.
            message_info: accepted but not used by this method.

        Returns:
            List of message objects in the order they were parsed.

        Raises:
            ValueError: if a record's version is not 0xfe, 0xfd.
        """
        factory = type(self)
        parsed = []

        # A single UDP datagram may carry several records.
        while data.tell() < len(data.getvalue()):
            record = Record.from_bytes(data)

            if (record.version.major, record.version.minor) != (0xfe, 0xFD):
                raise ValueError("DTLS version error, expect DTLSv1.2")

            follows_change_cipher_spec = factory.last_msg_is_change_cipher_spec
            factory.last_msg_is_change_cipher_spec = (record.content_type == ContentType.CHANGE_CIPHER_SPEC)

            # The record right after a CHANGE_CIPHER_SPEC holds the FINISHED
            # message, which is encrypted, so it is skipped.
            if follows_change_cipher_spec:
                continue

            fragment_data = io.BytesIO(record.fragment)
            fragment_size = len(fragment_data.getvalue())

            # A single record may carry several handshake messages.
            while fragment_data.tell() < fragment_size:
                content_class = content_map[record.content_type]
                assert content_class
                parsed.append(content_class.from_bytes(fragment_data))

        return parsed
646