1"""
2  Copyright (c) 2024, The OpenThread Authors.
3  All rights reserved.
4
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  1. Redistributions of source code must retain the above copyright
8     notice, this list of conditions and the following disclaimer.
9  2. Redistributions in binary form must reproduce the above copyright
10     notice, this list of conditions and the following disclaimer in the
11     documentation and/or other materials provided with the distribution.
12  3. Neither the name of the copyright holder nor the
13     names of its contributors may be used to endorse or promote products
14     derived from this software without specific prior written permission.
15
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  POSSIBILITY OF SUCH DAMAGE.
27"""
28
29import _ssl
30import asyncio
31import ssl
32import sys
33import logging
34
35from tlv.tlv import TLV
36from tlv.tcat_tlv import TcatTLVType
37import utils
38
39logger = logging.getLogger(__name__)
40
41
class BleStreamSecure:
    """TLS client layered on top of an asynchronous byte stream.

    The TLS engine is an ``ssl.SSLObject`` created with ``wrap_bio`` and
    driven through two ``ssl.MemoryBIO`` buffers. This class moves the
    encrypted bytes between those buffers and ``stream``, which must offer
    awaitable ``send(data)`` and ``recv(buffersize)`` methods; ``recv`` is
    expected to return a falsy value when nothing is available yet.
    """

    # Seconds to pause between polls of the underlying stream in recv().
    _POLL_INTERVAL = 0.1
    # Seconds to pause between rounds of the handshake loop.
    _HANDSHAKE_INTERVAL = 0.02

    def __init__(self, stream):
        """Prepare the SSL context and memory BIOs; no I/O is performed.

        Args:
            stream: transport object used to exchange the encrypted bytes.
        """
        self.stream = stream
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        self.incoming = ssl.MemoryBIO()  # encrypted bytes received from the link
        self.outgoing = ssl.MemoryBIO()  # encrypted bytes waiting to be sent
        self.ssl_object = None  # created by do_handshake()
        self.cert = ''  # own certificate as returned by utils.load_cert_pem()

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the local certificate chain and/or the trusted CA file.

        Args:
            certfile: path of the own certificate; skipped when empty.
            keyfile: path of the private key; when empty the key is expected
                to be inside ``certfile``.
            cafile: path of the CA certificates used to verify the peer;
                skipped when empty.
        """
        if certfile:
            # keyfile=None is load_cert_chain's default, i.e. "key is in certfile".
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def _flush_outgoing(self):
        """Send everything currently queued in the outgoing BIO, if anything."""
        data = self.outgoing.read()
        if data:
            await self.stream.send(data)

    async def _pump_incoming(self, buffersize=4096):
        """Do one receive on the stream and feed any bytes to the incoming BIO."""
        data = await self.stream.recv(buffersize)
        if data:
            self.incoming.write(data)

    async def _recv_with_timeout(self, buffersize, timeout):
        """Poll the stream until it yields data or ``timeout`` seconds elapse.

        Returns the received bytes, or the last falsy result on timeout.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < deadline:
            await asyncio.sleep(self._POLL_INTERVAL)
            data = await self.stream.recv(buffersize)
        return data

    async def do_handshake(self):
        """Create the SSL object and run the TLS handshake as a client.

        Loops until ``ssl_object.do_handshake()`` succeeds, exchanging
        handshake bytes with the stream on every ``SSLWantReadError`` /
        ``SSLWantWriteError``. Any other exception propagates to the caller.
        Unless the module logger is at DEBUG level or lower, a '.' is printed
        to stdout for every round as a progress indicator.
        """
        is_debug = logger.getEffectiveLevel() <= logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        while True:
            try:
                if not is_debug:
                    print('.', end='', flush=True)
                self.ssl_object.do_handshake()
                break

            # SSLWantWrite means ssl wants to send data over the link,
            # but might need a receive first
            except ssl.SSLWantWriteError:
                await self._pump_incoming()
                await self._flush_outgoing()

            # SSLWantRead means ssl wants to receive data from the link,
            # but might need to send first
            except ssl.SSLWantReadError:
                await self._flush_outgoing()
                await self._pump_incoming()

            await asyncio.sleep(self._HANDSHAKE_INTERVAL)

        # The final successful do_handshake() call may itself have produced
        # bytes for the peer; send them now instead of leaving them queued
        # until the first application write.
        await self._flush_outgoing()

    async def send(self, bytes):
        """Encrypt ``bytes`` and transmit the resulting TLS data.

        The whole outgoing BIO is drained so that output larger than one
        chunk is not left behind and appended to a later message.
        Requires do_handshake() to have been called first.
        """
        self.ssl_object.write(bytes)
        await self._flush_outgoing()

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt application data from the peer.

        Args:
            buffersize: size passed to each ``stream.recv`` call.
            timeout: maximum number of seconds to wait for bytes from the
                stream, both for the first bytes and, if a TLS record is
                still incomplete, for each further piece of it.

        Returns:
            The decrypted bytes (at most 4096 per call), or ``b''`` if the
            stream stayed silent for ``timeout`` seconds.
        """
        data = await self._recv_with_timeout(buffersize, timeout)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(4096)
            # if recv called before entire message was received from the link
            except ssl.SSLWantReadError:
                more = await self._recv_with_timeout(buffersize, timeout)
                if not more:
                    # Bounded wait: give up rather than spin forever when the
                    # link goes quiet in the middle of a record.
                    logger.warning('Timed out waiting for the remainder of a TLS record.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the response, waiting up to 5 seconds."""
        await self.send(bytes)
        return await self.recv(buffersize=4096, timeout=5)

    async def close(self):
        """Send a DISCONNECT TLV to the peer if a TLS session exists.

        Does nothing when the handshake was never started or no session
        has been established.
        """
        if self.ssl_object is None or self.ssl_object.session is None:
            return
        data = TLV(TcatTLVType.DISCONNECT.value, b'').to_bytes()
        await self.send(data)

    def log_cert_identities(self):
        """Log the peer's (unverified) certificate chain and the own certificate.

        Failures are reported as a warning and never raised.
        """
        # using the internal object of the ssl library is necessary to see the cert data in
        # case of handshake failure - see https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            chain = self.ssl_object._sslobj.get_unverified_chain()
            if chain is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(chain)} certificates received.')
            for cert in chain:
                logger.info(f'  cert info:\n{cert.get_info()}')
                peer_cert_der_b64 = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f'  base64: (paste in https://lapo.it/asn1js/ to decode)\n{peer_cert_der_b64}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception as e:
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            # Keep the underlying reason available for troubleshooting.
            logger.debug('Certificate logging failed: %r', e)
158