"""
  Copyright (c) 2024, The OpenThread Authors.
  All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:
  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.
  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.
  3. Neither the name of the copyright holder nor the
     names of its contributors may be used to endorse or promote products
     derived from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  POSSIBILITY OF SUCH DAMAGE.
"""

import _ssl
import asyncio
import ssl
import sys
import logging

from tlv.tlv import TLV
from tlv.tcat_tlv import TcatTLVType
from time import time
import utils

logger = logging.getLogger(__name__)


class BleStreamSecure:
    """Client-side TLS session tunnelled over an async byte stream.

    TLS is driven through an ``ssl.SSLObject`` bound to two in-memory BIOs.
    Ciphertext is moved between those BIOs and ``stream`` by hand:
    ``stream.send()`` carries outgoing records and ``stream.recv()`` supplies
    incoming ones. ``stream`` is presumably the plain BLE transport; all that
    is relied on here is that it exposes awaitable ``send(data)`` and
    ``recv(n)``, where ``recv`` may return empty data when nothing is
    pending.
    """

    def __init__(self, stream):
        """Store the transport and prepare an (unconnected) TLS client context."""
        # Underlying transport carrying raw TLS records.
        self.stream = stream
        # Client-side context (verifies the server); see load_cert() for credentials.
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        # Ciphertext received from the peer, waiting to be consumed by OpenSSL.
        self.incoming = ssl.MemoryBIO()
        # Ciphertext produced by OpenSSL, waiting to be sent to the peer.
        self.outgoing = ssl.MemoryBIO()
        # Created by do_handshake(); None until then.
        self.ssl_object = None
        # Our own certificate as returned by utils.load_cert_pem(), used for logging.
        self.cert = ''

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the client certificate/key and, optionally, trusted CA certificates.

        Args:
            certfile: path to the client certificate (chain). Nothing is
                loaded for the client identity if this is empty.
            keyfile: path to the private key. If empty, the key is expected
                to be inside ``certfile``.
            cafile: path to CA certificate(s) used to verify the peer.
        """
        if certfile:
            # An empty keyfile maps to None, i.e. "key is in certfile".
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def do_handshake(self, timeout=30.0):
        """Run the TLS client handshake over the stream.

        A fresh client-side SSLObject is created on the memory BIOs. Then
        ``do_handshake()`` is called repeatedly, and ciphertext is shuttled
        between the BIOs and the stream whenever OpenSSL asks for more I/O.
        Unless DEBUG logging is enabled, a '.' is printed per attempt as a
        progress indicator. SSL errors other than want-read/want-write (for
        example certificate verification failures) propagate to the caller.

        Args:
            timeout: overall time budget in seconds, measured with ``time()``.

        Returns:
            True if the handshake completed, False if it timed out.
        """
        is_debug = logger.getEffectiveLevel() <= logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        start = time()
        while (time() - start) < timeout:
            try:
                if not is_debug:
                    print('.', end='')
                    sys.stdout.flush()
                self.ssl_object.do_handshake()
                break

            # SSLWantWrite means ssl wants to send data over the link,
            # but might need a receive first
            except ssl.SSLWantWriteError:
                output = await self.stream.recv(4096)
                if output:
                    self.incoming.write(output)
                data = self.outgoing.read()
                if data:
                    await self.stream.send(data)
                await asyncio.sleep(0.02)

            # SSLWantRead means ssl wants to receive data from the link,
            # but might need to send first
            except ssl.SSLWantReadError:
                data = self.outgoing.read()
                if data:
                    await self.stream.send(data)
                output = await self.stream.recv(4096)
                if output:
                    self.incoming.write(output)
                await asyncio.sleep(0.02)
        else:
            # The loop ended without `break`: the time budget ran out.
            print('TLS Connection timed out.')
            return False
        return True

    async def send(self, bytes):
        """Encrypt ``bytes`` and send the resulting TLS record(s) over the stream.

        Requires a completed handshake (``self.ssl_object`` must be set).
        """
        self.ssl_object.write(bytes)
        # Drain the whole outgoing BIO. A large payload (or several records)
        # can exceed 4096 bytes of ciphertext, and any remainder left in the
        # BIO would otherwise only go out together with the next send().
        encode = self.outgoing.read()
        await self.stream.send(encode)

    async def _recv_raw(self, buffersize, deadline):
        """Poll the stream every 0.1 s until it returns data or ``deadline`` passes.

        ``deadline`` is an absolute time on the running event loop's clock.
        The stream is always read at least once. The last (possibly empty)
        result is returned.
        """
        loop = asyncio.get_running_loop()
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < deadline:
            await asyncio.sleep(0.1)
            data = await self.stream.recv(buffersize)
        return data

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt application data from the peer.

        The stream is polled until ciphertext arrives or ``timeout`` seconds
        elapse, and the ciphertext is then fed to the TLS object. If it does
        not yet contain a complete TLS record, more data is read until
        decryption succeeds. That wait gives up if no further data arrives
        within ``timeout`` seconds of the previous chunk.

        Args:
            buffersize: maximum number of bytes requested from the stream per read.
            timeout: seconds to wait for the first data, and for each further
                chunk of an incomplete TLS record.

        Returns:
            Decrypted plaintext (at most 4096 bytes per call). Returns b'' if
            nothing arrived in time, or if a partial record was not completed
            in time. In the latter case the partial ciphertext stays buffered
            and may complete on a later call.
        """
        loop = asyncio.get_running_loop()
        data = await self._recv_raw(buffersize, loop.time() + timeout)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(4096)
            # recv called before the entire TLS record was received from the link
            except ssl.SSLWantReadError:
                more = await self._recv_raw(buffersize, loop.time() + timeout)
                if not more:
                    # Previously this waited forever; bound it so a stalled peer
                    # cannot hang the client.
                    logger.warning('Incomplete TLS record: no further data received within timeout.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the decrypted reply (b'' if none within 5 s)."""
        await self.send(bytes)
        res = await self.recv(buffersize=4096, timeout=5)
        return res

    async def close(self):
        """Send a TCAT Disconnect TLV if a TLS session was established.

        Does nothing if no handshake was attempted, or if the TLS object has
        no session. Only the Disconnect command is sent; neither the TLS
        object nor the underlying stream is closed here.
        """
        if self.ssl_object is not None and self.ssl_object.session is not None:
            logger.debug('sending Disconnect command TLV')
            data = TLV(TcatTLVType.DISCONNECT.value, bytes()).to_bytes()
            await self.send(data)

    def log_cert_identities(self):
        """Log the peer (TCAT Device) certificate chain and our own certificate.

        Intended for diagnostics, including after a failed handshake. Any
        exception is caught and reported as a warning, with the traceback
        logged at DEBUG level.
        """
        # using the internal object of the ssl library is necessary to see the cert data in
        # case of handshake failure - see https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            cc = self.ssl_object._sslobj.get_unverified_chain()
            if cc is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(cc)} certificates received.')
            for cert in cc:
                logger.info(f' cert info:\n{cert.get_info()}')
                peer_cert_der_b64 = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f' base64: (paste in https://lapo.it/asn1js/ to decode)\n{peer_cert_der_b64}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception:
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            logger.debug('Certificate info error details:', exc_info=True)