"""
  Copyright (c) 2024, The OpenThread Authors.
  All rights reserved.

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:
  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.
  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.
  3. Neither the name of the copyright holder nor the
     names of its contributors may be used to endorse or promote products
     derived from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  POSSIBILITY OF SUCH DAMAGE.
"""

import _ssl
import asyncio
import ssl
import sys
import logging

from cryptography.x509 import load_der_x509_certificate
from cryptography.hazmat.primitives.serialization import (Encoding, PublicFormat)
from tlv.tlv import TLV
from tlv.tcat_tlv import TcatTLVType
from time import time
import utils

logger = logging.getLogger(__name__)


class BleStreamSecure:
    """TLS client layered on top of an asynchronous byte stream.

    The TLS engine is an ``ssl.SSLObject`` attached to two ``ssl.MemoryBIO``
    buffers: ciphertext received from ``stream`` is written to ``incoming``
    and ciphertext produced by the engine is read from ``outgoing`` and
    passed to ``stream.send``. The wrapped ``stream`` must provide awaitable
    ``send(data)`` and ``recv(size)`` methods; ``recv`` is expected to return
    a falsy value when nothing is available.
    """

    # Upper bound on the ciphertext handed to ``stream.send`` per call in send().
    _SEND_CHUNK_SIZE = 4096
    # Upper bound on the plaintext returned by one recv() call.
    _READ_SIZE = 4096
    # Pause between transport polls while waiting for data in recv().
    _RECV_POLL_INTERVAL = 0.1
    # Pause between handshake iterations.
    _HANDSHAKE_POLL_INTERVAL = 0.02

    def __init__(self, stream):
        """Create the TLS context and memory BIOs around ``stream``."""
        self.stream = stream
        self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        self.incoming = ssl.MemoryBIO()
        self.outgoing = ssl.MemoryBIO()
        self.ssl_object = None  # created by do_handshake()
        self.cert = ''
        self.peer_challenge = None
        self._peer_public_key = None

    def load_cert(self, certfile='', keyfile='', cafile=''):
        """Load the local certificate chain and/or the trusted CA file.

        Args:
            certfile: Path of the local certificate; when empty no local
                certificate is loaded and ``keyfile`` is ignored.
            keyfile: Optional path of the private key belonging to
                ``certfile``.
            cafile: Optional path of CA certificates used to verify the peer.
        """
        if certfile:
            # An empty keyfile is mapped to None, i.e. "key is inside certfile".
            self.ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile or None)
            self.cert = utils.load_cert_pem(certfile)

        if cafile:
            self.ssl_context.load_verify_locations(cafile=cafile)

    async def _send_pending(self):
        """Send everything currently queued in the outgoing BIO, if anything."""
        data = self.outgoing.read()
        if data:
            await self.stream.send(data)

    async def _feed_incoming(self, size):
        """Poll the stream once and pass any received bytes to the TLS engine."""
        data = await self.stream.recv(size)
        if data:
            self.incoming.write(data)

    async def _recv_until(self, buffersize, deadline):
        """Poll the stream until data arrives or ``deadline`` has passed.

        ``deadline`` is expressed in running-event-loop time. The stream is
        always polled at least once. Returns the received data, or the falsy
        value of the last poll on timeout.
        """
        loop = asyncio.get_running_loop()
        data = await self.stream.recv(buffersize)
        while not data and loop.time() < deadline:
            await asyncio.sleep(self._RECV_POLL_INTERVAL)
            data = await self.stream.recv(buffersize)
        return data

    async def do_handshake(self, timeout=30.0):
        """Run the TLS client handshake over the stream.

        Args:
            timeout: Maximum time in seconds to spend on the handshake.

        Returns:
            True when the handshake completed, False when it timed out.
            Other ``ssl.SSLError`` exceptions raised by the TLS engine are
            not caught here.
        """
        is_debug = logger.getEffectiveLevel() <= logging.DEBUG
        self.ssl_object = self.ssl_context.wrap_bio(
            incoming=self.incoming,
            outgoing=self.outgoing,
            server_side=False,
            server_hostname=None,
        )
        # Monotonic clock, so wall-clock adjustments cannot stretch or cut the timeout.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            try:
                if not is_debug:
                    # Progress indicator for interactive use.
                    print('.', end='')
                    sys.stdout.flush()
                self.ssl_object.do_handshake()
                break

            # SSLWantWrite means ssl wants to send data over the link,
            # but might need a receive first
            except ssl.SSLWantWriteError:
                await self._feed_incoming(4096)
                await self._send_pending()
                await asyncio.sleep(self._HANDSHAKE_POLL_INTERVAL)

            # SSLWantRead means ssl wants to receive data from the link,
            # but might need to send first
            except ssl.SSLWantReadError:
                await self._send_pending()
                await self._feed_incoming(4096)
                await asyncio.sleep(self._HANDSHAKE_POLL_INTERVAL)
        else:
            print('TLS Connection timed out.')
            return False
        print('')
        # The last do_handshake() call may itself have queued handshake bytes
        # (when the client sends the final flight); deliver them now instead
        # of leaving them in the BIO until the first application write.
        await self._send_pending()
        cert = self.ssl_object.getpeercert(True)
        cert_obj = load_der_x509_certificate(cert)
        self._peer_public_key = cert_obj.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
        self.log_cert_identities()
        return True

    async def send(self, bytes):
        """Encrypt ``bytes`` and send the resulting ciphertext over the stream.

        The ciphertext is passed to the stream in pieces of at most
        ``_SEND_CHUNK_SIZE`` bytes until the outgoing BIO is empty, so large
        writes are not truncated.
        """
        self.ssl_object.write(bytes)
        encode = self.outgoing.read(self._SEND_CHUNK_SIZE)
        await self.stream.send(encode)
        while self.outgoing.pending:
            await self.stream.send(self.outgoing.read(self._SEND_CHUNK_SIZE))

    async def recv(self, buffersize, timeout=1):
        """Receive and decrypt one chunk of application data.

        Args:
            buffersize: Size passed to each ``stream.recv`` call.
            timeout: Seconds to wait for data from the stream. The same
                window is restarted whenever more ciphertext is needed to
                complete a TLS record.

        Returns:
            The decrypted bytes (at most ``_READ_SIZE``), or ``b''`` when
            nothing arrived, or a record could not be completed, in time.
        """
        loop = asyncio.get_running_loop()
        data = await self._recv_until(buffersize, loop.time() + timeout)
        if not data:
            logger.warning('No response when response expected.')
            return b''

        self.incoming.write(data)
        while True:
            try:
                return self.ssl_object.read(self._READ_SIZE)
            # if recv called before entire message was received from the link
            except ssl.SSLWantReadError:
                more = await self._recv_until(buffersize, loop.time() + timeout)
                if not more:
                    # Bounded wait: a peer that stops mid-record must not
                    # block the caller forever.
                    logger.warning('Incomplete TLS record: no further data received before timeout.')
                    return b''
                self.incoming.write(more)

    async def send_with_resp(self, bytes):
        """Send ``bytes`` and return the response, waiting up to 5 seconds."""
        await self.send(bytes)
        res = await self.recv(buffersize=4096, timeout=5)
        return res

    async def close(self):
        """Send a Disconnect TLV if a TLS session exists and clear peer state.

        Does nothing when no handshake was started or no session was
        established.
        """
        if self.ssl_object is None or self.ssl_object.session is None:
            return
        logger.debug('sending Disconnect command TLV')
        data = TLV(TcatTLVType.DISCONNECT.value, bytes()).to_bytes()
        self.peer_challenge = None
        self._peer_public_key = None
        await self.send(data)

    @property
    def peer_public_key(self):
        """DER-encoded SubjectPublicKeyInfo of the peer certificate, or None."""
        return self._peer_public_key

    @property
    def peer_challenge(self):
        """Challenge value stored for the peer, or None."""
        return self._peer_challenge

    @peer_challenge.setter
    def peer_challenge(self, value):
        self._peer_challenge = value

    def log_cert_identities(self):
        """Log the certificate chain presented by the peer and the local cert.

        This is best effort: any failure is reported as a warning and never
        propagated to the caller.
        """
        # using the internal object of the ssl library is necessary to see the cert data in
        # case of handshake failure - see https://sethmlarson.dev/experimental-python-3.10-apis-and-trust-stores
        # Should work for Python >= 3.10
        try:
            cc = self.ssl_object._sslobj.get_unverified_chain()
            if cc is None:
                logger.info('No TCAT Device cert chain was received (yet).')
                return
            logger.info(f'TCAT Device cert chain: {len(cc)} certificates received.')
            for cert in cc:
                logger.info(f' cert info:\n{cert.get_info()}')
                # presumably a base64 rendering of the DER bytes - see utils.base64_string
                peer_cert_der_b64 = utils.base64_string(cert.public_bytes(_ssl.ENCODING_DER))
                logger.info(f' base64: (paste in https://lapo.it/asn1js/ to decode)\n{peer_cert_der_b64}')
            logger.info(f'TCAT Commissioner cert, PEM:\n{self.cert}')

        except Exception:
            logger.warning('Could not display TCAT client cert info (check Python version is >= 3.10?)')
            # Keep the actual cause available for troubleshooting.
            logger.debug('Certificate logging failed', exc_info=True)