"""
  Copyright (c) 2024, The OpenThread Authors.
  All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:
  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.
  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.
  3. Neither the name of the copyright holder nor the
     names of its contributors may be used to endorse or promote products
     derived from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  POSSIBILITY OF SUCH DAMAGE.
"""

import _ssl
import asyncio
import ssl
import sys
import logging

from tlv.tlv import TLV
from tlv.tcat_tlv import TcatTLVType
import utils

logger = logging.getLogger(__name__)


class BleStreamSecure:
    """TLS client session layered on top of an asynchronous byte stream.

    The TLS engine never touches the link directly: ciphertext is exchanged
    with it through two ``ssl.MemoryBIO`` buffers, and this class shuttles
    bytes between those buffers and ``stream``.

    ``stream`` must provide awaitable ``send(data)`` and ``recv(buffersize)``
    methods; ``recv`` is expected to return a falsy value when nothing is
    available.
    """

    def __init__(self, stream):
        """Create the TLS context and the BIO pair; no I/O is performed."""
        self.stream = stream
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        # Ciphertext received from the link, to be consumed by the TLS engine.
        self.incoming = ssl.MemoryBIO()
        # Ciphertext produced by the TLS engine, to be written to the link.
        self.outgoing = ssl.MemoryBIO()
        # Created by do_handshake(); None until then.
        self.ssl_object = None
        # Value returned by utils.load_cert_pem() for the local certificate.
        self.cert = ''

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the local certificate chain and/or the trusted CA file.

        Args:
            certfile: Path of the local certificate; skipped when empty.
            keyfile: Path of the matching private key. When empty, only
                ``certfile`` is passed to ``load_cert_chain``.
            cafile: Path of the CA bundle used for verification; skipped
                when empty.
        """
        if certfile:
            # An empty keyfile is mapped to None so that load_cert_chain is
            # called exactly as if the argument had been omitted.
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def _flush_outgoing(self):
        """Send everything currently queued in the outgoing BIO, if any."""
        data = self.outgoing.read()
        if data:
            await self.stream.send(data)

    async def _feed_incoming(self, buffersize=4096):
        """Do one receive on the link and hand any bytes to the TLS engine."""
        data = await self.stream.recv(buffersize)
        if data:
            self.incoming.write(data)

    async def do_handshake(self):
        """Run the TLS client handshake over the stream until it completes.

        A ``.`` is printed per attempt as a progress indicator unless the
        logger is at DEBUG level. Exceptions other than
        ``SSLWantReadError``/``SSLWantWriteError`` raised by the TLS engine
        propagate to the caller.
        """
        is_debug = logger.getEffectiveLevel() <= logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        while True:
            try:
                if not is_debug:
                    print('.', end='')
                    sys.stdout.flush()
                self.ssl_object.do_handshake()
                break

            # SSLWantWrite means ssl wants to send data over the link,
            # but might need a receive first
            except ssl.SSLWantWriteError:
                await self._feed_incoming()
                await self._flush_outgoing()

            # SSLWantRead means ssl wants to receive data from the link,
            # but might need to send first
            except ssl.SSLWantReadError:
                await self._flush_outgoing()
                await self._feed_incoming()

            # Only reached after one of the handlers above (success breaks out).
            await asyncio.sleep(0.02)

        # The final handshake step may leave bytes queued for the peer;
        # send them now instead of waiting for the first application write.
        await self._flush_outgoing()

    async def send(self, bytes):
        """Encrypt ``bytes`` and write the resulting ciphertext to the stream."""
        self.ssl_object.write(bytes)
        # Drain the whole BIO: a capped read would leave the tail of a large
        # record queued until the next send, stalling the peer on a partial
        # record.
        encode = self.outgoing.read()
        await self.stream.send(encode)

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt one chunk of application data.

        Args:
            buffersize: Maximum size passed to each ``stream.recv`` call.
            timeout: Seconds to wait for data from the link. It bounds the
                wait for the first bytes and, separately, each wait for a
                further fragment of an incomplete message.

        Returns:
            The decrypted bytes, or ``b''`` if the link stayed silent for
            ``timeout`` seconds.
        """
        loop = asyncio.get_running_loop()
        end_time = loop.time() + timeout
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < end_time:
            await asyncio.sleep(0.1)
            data = await self.stream.recv(buffersize)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(4096)
            # if recv called before entire message was received from the link
            except ssl.SSLWantReadError:
                # Restart the deadline on every fragment so a slow but live
                # link is not cut off, while a dead link cannot hang forever.
                idle_deadline = loop.time() + timeout
                more = await self.stream.recv(buffersize)
                while not more and loop.time() < idle_deadline:
                    await asyncio.sleep(0.1)
                    more = await self.stream.recv(buffersize)
                if not more:
                    logger.warning('Timed out waiting for the remainder of a message.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the response, waiting up to 5 seconds."""
        await self.send(bytes)
        res = await self.recv(buffersize=4096, timeout=5)
        return res

    async def close(self):
        """Send a DISCONNECT TLV if a TLS session exists.

        Does nothing when the handshake was never started or no session was
        established.
        """
        if self.ssl_object is not None and self.ssl_object.session is not None:
            data = TLV(TcatTLVType.DISCONNECT.value, bytes()).to_bytes()
            await self.send(data)

    def log_cert_identities(self):
        """Log the peer certificate chain and the local certificate.

        Best effort: any failure is reported as a warning and never raised.
        """
        # using the internal object of the ssl library is necessary to see the cert data in
        # case of handshake failure - see https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            cc = self.ssl_object._sslobj.get_unverified_chain()
            if cc is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(cc)} certificates received.')
            for cert in cc:
                logger.info(f' cert info:\n{cert.get_info()}')
                peer_cert_der_hex = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f' base64: (paste in https://lapo.it/asn1js/ to decode)\n{peer_cert_der_hex}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception:
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            # Keep the actual cause available for troubleshooting.
            logger.debug('Certificate logging failed', exc_info=True)