1"""
2  Copyright (c) 2024, The OpenThread Authors.
3  All rights reserved.
4
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  1. Redistributions of source code must retain the above copyright
8     notice, this list of conditions and the following disclaimer.
9  2. Redistributions in binary form must reproduce the above copyright
10     notice, this list of conditions and the following disclaimer in the
11     documentation and/or other materials provided with the distribution.
12  3. Neither the name of the copyright holder nor the
13     names of its contributors may be used to endorse or promote products
14     derived from this software without specific prior written permission.
15
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  POSSIBILITY OF SUCH DAMAGE.
27"""
28
29import _ssl
30import asyncio
31import ssl
32import sys
33import logging
34
35from cryptography.x509 import load_der_x509_certificate
36from cryptography.hazmat.primitives.serialization import (Encoding, PublicFormat)
37from tlv.tlv import TLV
38from tlv.tcat_tlv import TcatTLVType
39from time import time
40import utils
41
# Module-level logger named after this module; used for all diagnostics below.
logger = logging.getLogger(__name__)
43
44
class BleStreamSecure:
    """TLS client session carried over an asynchronous byte stream.

    The TLS engine operates on two in-memory BIOs: ciphertext received from
    the stream is written into ``incoming`` and ciphertext produced by the
    engine is drained from ``outgoing`` and written to the stream.

    The wrapped ``stream`` must provide ``async send(data)`` and
    ``async recv(size)``; a falsy result from ``recv`` is treated as
    "nothing available yet".
    """

    # Maximum number of bytes requested per read from the stream / TLS engine.
    _READ_SIZE = 4096
    # Pause between handshake iterations, in seconds.
    _HANDSHAKE_POLL_INTERVAL = 0.02
    # Pause between polls of the stream while waiting for data, in seconds.
    _RECV_POLL_INTERVAL = 0.1

    def __init__(self, stream):
        """Prepare a TLS client context and its BIO pair around ``stream``."""
        self.stream = stream
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()
        # Created by do_handshake(); None until then.
        self.ssl_object = None
        # Local certificate as returned by utils.load_cert_pem(); only used
        # for logging in this class.
        self.cert = ''
        self.peer_challenge = None
        self._peer_public_key = None

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the local certificate chain and/or trusted CA certificates.

        Args:
            certfile: Path of the local certificate chain; skipped when empty.
            keyfile: Path of the matching private key. When empty, the key is
                expected to be contained in ``certfile``.
            cafile: Path of CA certificates used to verify the peer; skipped
                when empty.
        """
        if certfile:
            # keyfile=None is the documented way to say "key is in certfile".
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def _flush_outgoing(self):
        """Send every byte currently queued in the outgoing BIO to the stream."""
        pending = self.outgoing.read()
        if pending:
            await self.stream.send(pending)

    async def _pull_from_stream(self):
        """Read once from the stream and feed any data into the incoming BIO."""
        chunk = await self.stream.recv(self._READ_SIZE)
        if chunk:
            self.incoming.write(chunk)

    async def _poll_stream(self, buffersize, timeout):
        """Poll the stream until it yields data or ``timeout`` seconds elapse.

        Returns the data read, or a falsy value if nothing arrived in time.
        The stream is always read at least once, even for a zero timeout.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < deadline:
            await asyncio.sleep(self._RECV_POLL_INTERVAL)
            data = await self.stream.recv(buffersize)
        return data

    async def do_handshake(self, timeout=30.0):
        """Run the TLS client handshake over the stream.

        A progress dot is printed per iteration unless the logger is at DEBUG
        level. On success the peer's public key (DER, SubjectPublicKeyInfo) is
        stored and the certificate chain is logged.

        Args:
            timeout: Maximum handshake duration in seconds.

        Returns:
            True when the handshake completed, False when it timed out.

        Raises:
            ssl.SSLError: Handshake failures other than want-read/want-write
                are propagated from the TLS engine.
        """
        show_progress = logger.getEffectiveLevel() > logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        # The loop clock is monotonic, so wall-clock adjustments cannot
        # shorten or extend the timeout.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if show_progress:
                print('.', end='')
                sys.stdout.flush()
            try:
                self.ssl_object.do_handshake()
                break
            except ssl.SSLWantWriteError:
                # The engine wants to send data over the link,
                # but might need a receive first.
                await self._pull_from_stream()
                await self._flush_outgoing()
            except ssl.SSLWantReadError:
                # The engine wants to receive data from the link,
                # but might need to send first.
                await self._flush_outgoing()
                await self._pull_from_stream()
            await asyncio.sleep(self._HANDSHAKE_POLL_INTERVAL)
        else:
            print('TLS Connection timed out.')
            return False
        print('')
        # Some protocol versions leave a final handshake message queued after
        # do_handshake() returns; push it out now instead of relying on a
        # later send().
        await self._flush_outgoing()
        peer_cert_der = self.ssl_object.getpeercert(True)
        peer_cert = load_der_x509_certificate(peer_cert_der)
        self._peer_public_key = peer_cert.public_key().public_bytes(Encoding.DER,
                                                                    PublicFormat.SubjectPublicKeyInfo)
        self.log_cert_identities()
        return True

    async def send(self, bytes):
        """Encrypt ``bytes`` and send the resulting ciphertext over the stream.

        All ciphertext produced by the TLS engine is sent, not only the first
        4096 bytes. The parameter name shadows the builtin but is kept for
        backward compatibility with keyword callers.
        """
        self.ssl_object.write(bytes)
        await self._flush_outgoing()

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt application data from the stream.

        Args:
            buffersize: Maximum number of bytes requested per stream read.
            timeout: Maximum time in seconds to wait for link data. The limit
                applies to the initial wait and again to each wait for the
                remainder of an incomplete TLS record.

        Returns:
            The decrypted bytes (at most 4096 per call), or ``b''`` when no
            data arrived in time.
        """
        data = await self._poll_stream(buffersize, timeout)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(self._READ_SIZE)
            except ssl.SSLWantReadError:
                # recv was called before the entire record was received from
                # the link; wait (bounded) for the rest instead of forever.
                more = await self._poll_stream(buffersize, timeout)
                if not more:
                    logger.warning('Timed out waiting for the remainder of a TLS record.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the response, waiting up to 5 seconds."""
        await self.send(bytes)
        return await self.recv(buffersize=self._READ_SIZE, timeout=5)

    async def close(self):
        """Send a Disconnect command TLV if a TLS session exists.

        Peer challenge and peer public key are cleared before sending. Does
        nothing when no handshake has been started or no session exists.
        """
        if self.ssl_object is None or self.ssl_object.session is None:
            return
        logger.debug('sending Disconnect command TLV')
        data = TLV(TcatTLVType.DISCONNECT.value, b'').to_bytes()
        self.peer_challenge = None
        self._peer_public_key = None
        await self.send(data)

    @property
    def peer_public_key(self):
        """Peer public key as DER SubjectPublicKeyInfo, or None if unknown."""
        return self._peer_public_key

    @property
    def peer_challenge(self):
        """Challenge value stored for the peer; None after close()."""
        return self._peer_challenge

    @peer_challenge.setter
    def peer_challenge(self, value):
        self._peer_challenge = value

    def log_cert_identities(self):
        """Log the peer's certificate chain and the local certificate.

        This is best-effort diagnostics: any failure is reported as a warning
        and never propagated to the caller.
        """
        # Using the internal object of the ssl library is necessary to see the
        # cert data in case of handshake failure - see
        # https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            chain = self.ssl_object._sslobj.get_unverified_chain()
            if chain is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(chain)} certificates received.')
            for cert in chain:
                logger.info(f'  cert info:\n{cert.get_info()}')
                cert_der_b64 = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f'  base64: (paste in https://lapo.it/asn1js/ to decode)\n{cert_der_b64}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception:
            # Deliberately broad: relies on a private API that may be missing.
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            logger.debug('Certificate logging failed', exc_info=True)
188