1"""
2  Copyright (c) 2024, The OpenThread Authors.
3  All rights reserved.
4
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  1. Redistributions of source code must retain the above copyright
8     notice, this list of conditions and the following disclaimer.
9  2. Redistributions in binary form must reproduce the above copyright
10     notice, this list of conditions and the following disclaimer in the
11     documentation and/or other materials provided with the distribution.
12  3. Neither the name of the copyright holder nor the
13     names of its contributors may be used to endorse or promote products
14     derived from this software without specific prior written permission.
15
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  POSSIBILITY OF SUCH DAMAGE.
27"""
28
29import _ssl
30import asyncio
31import ssl
32import sys
33import logging
34
35from tlv.tlv import TLV
36from tlv.tcat_tlv import TcatTLVType
37from time import time
38import utils
39
40logger = logging.getLogger(__name__)
41
42
class BleStreamSecure:
    """TLS client session layered on top of an asynchronous byte stream.

    The TLS engine is an ``ssl.SSLObject`` driven through two in-memory BIOs:
    ciphertext received from ``stream`` is written into ``incoming``, and
    ciphertext produced by the engine is drained from ``outgoing`` and passed
    to ``stream``. ``stream`` must provide awaitable ``send(data)`` and
    ``recv(buffersize)`` methods.
    """

    # Pauses (seconds) between polls of the underlying stream.
    _HANDSHAKE_POLL_INTERVAL = 0.02
    _RECV_POLL_INTERVAL = 0.1
    # Maximum number of bytes requested per read from the link / TLS engine.
    _READ_CHUNK = 4096

    def __init__(self, stream):
        """Create an unconnected TLS session that will run over ``stream``."""
        self.stream = stream
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()
        # Created by do_handshake(); None until then.
        self.ssl_object = None
        # Result of utils.load_cert_pem() for the local certificate, if loaded.
        self.cert = ''

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the local certificate chain and/or the trusted CA file.

        Args:
            certfile: Path to the local certificate chain; skipped when empty.
            keyfile: Path to the private key; when empty, only ``certfile``
                is passed to ``load_cert_chain``.
            cafile: Path to the CA certificates used for verification;
                skipped when empty.
        """
        if certfile:
            # keyfile=None is load_cert_chain's default, i.e. the same call as
            # omitting the argument.
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def _flush_outgoing(self):
        """Send all ciphertext currently queued in the outgoing BIO, if any."""
        data = self.outgoing.read()
        if data:
            await self.stream.send(data)

    async def _pump_incoming(self):
        """Do one receive on the link and feed any bytes to the incoming BIO."""
        data = await self.stream.recv(self._READ_CHUNK)
        if data:
            self.incoming.write(data)

    async def _wait_for_link_data(self, buffersize, timeout):
        """Poll the link until it returns data or ``timeout`` seconds elapse.

        Returns whatever the last ``stream.recv`` call returned, which is
        falsy when the wait timed out.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < deadline:
            await asyncio.sleep(self._RECV_POLL_INTERVAL)
            data = await self.stream.recv(buffersize)
        return data

    async def do_handshake(self, timeout=30.0):
        """Run the TLS handshake over the stream.

        Args:
            timeout: Maximum number of seconds to keep retrying the handshake.

        Returns:
            True when the handshake completed, False when it timed out.
            Other ``ssl`` errors raised by the handshake propagate to the
            caller.
        """
        # Progress dots are only printed when debug logging is not enabled.
        show_progress = logger.getEffectiveLevel() > logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        # Use the event loop's monotonic clock (as recv() does) so wall-clock
        # adjustments cannot shorten or extend the timeout.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            if show_progress:
                print('.', end='')
                sys.stdout.flush()
            try:
                self.ssl_object.do_handshake()
            except ssl.SSLWantWriteError:
                # SSLWantWrite means ssl wants to send data over the link,
                # but might need a receive first
                await self._pump_incoming()
                await self._flush_outgoing()
            except ssl.SSLWantReadError:
                # SSLWantRead means ssl wants to receive data from the link,
                # but might need to send first
                await self._flush_outgoing()
                await self._pump_incoming()
            else:
                # The engine may still hold a final handshake flight in the
                # outgoing BIO; deliver it now instead of leaving it queued
                # until the first application write.
                await self._flush_outgoing()
                return True
            await asyncio.sleep(self._HANDSHAKE_POLL_INTERVAL)

        print('TLS Connection timed out.')
        return False

    async def send(self, bytes):
        """Encrypt ``bytes`` and send the resulting ciphertext over the stream."""
        self.ssl_object.write(bytes)
        # Drain the whole BIO: a capped read would leave the tail of a large
        # write queued until some later send.
        await self.stream.send(self.outgoing.read())

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt data from the stream.

        Args:
            buffersize: Maximum number of bytes requested from the stream per
                receive call.
            timeout: Seconds to wait for data from the link, both for the
                first bytes and for each continuation of a partial message.

        Returns:
            The decrypted bytes, or ``b''`` when the link stayed silent for
            ``timeout`` seconds.
        """
        data = await self._wait_for_link_data(buffersize, timeout)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(self._READ_CHUNK)
            # if recv called before entire message was received from the link
            except ssl.SSLWantReadError:
                more = await self._wait_for_link_data(buffersize, timeout)
                if not more:
                    # Bounded wait: a link that went silent mid-message must
                    # not block the caller forever.
                    logger.warning('Timed out waiting for the remainder of a TLS message.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the response, waiting up to 5 seconds."""
        await self.send(bytes)
        return await self.recv(buffersize=4096, timeout=5)

    async def close(self):
        """Send a Disconnect command TLV if a TLS session exists.

        Does nothing when no handshake was started (``ssl_object`` is None)
        or when the TLS object has no session.
        """
        if self.ssl_object is None or self.ssl_object.session is None:
            return
        logger.debug('sending Disconnect command TLV')
        data = TLV(TcatTLVType.DISCONNECT.value, bytes()).to_bytes()
        await self.send(data)

    def log_cert_identities(self):
        """Log the peer's unverified certificate chain and the local cert.

        Best effort: any failure is reported as a warning and not raised.
        """
        # using the internal object of the ssl library is necessary to see the cert data in
        # case of handshake failure - see https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            cc = self.ssl_object._sslobj.get_unverified_chain()
            if cc is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(cc)} certificates received.')
            for cert in cc:
                logger.info(f'  cert info:\n{cert.get_info()}')
                peer_cert_der_hex = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f'  base64: (paste in https://lapo.it/asn1js/ to decode)\n{peer_cert_der_hex}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception:
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            # Keep the underlying cause available for troubleshooting.
            logger.debug('cert info failure details', exc_info=True)
165