"""
RSA Key management
"""

# SPDX-License-Identifier: Apache-2.0

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
from cryptography.hazmat.primitives.hashes import SHA256

from .general import KeyClass
from .privatebytes import PrivateBytesMixin


# Sizes that bootutil will recognize
RSA_KEY_SIZES = [2048, 3072]


class RSAUsageError(Exception):
    """Error raised by the RSA key wrappers in this module (unsupported
    operation or key size, or a private-key DER that cannot be minimized)."""


class RSAPublic(KeyClass):
    """Wrapper around an RSA public key.

    Only public-key operations are available; methods that need the
    private key raise RSAUsageError.
    """

    def __init__(self, key):
        """key: an RSA key object from ``cryptography`` (normally the
        public key)."""
        self.key = key

    def key_size(self):
        """Return the RSA modulus size in bits."""
        return self.key.key_size

    def shortname(self):
        """Return the short key-type name, ``"rsa"``."""
        return "rsa"

    def _unsupported(self, name):
        """Raise RSAUsageError: operation ``name`` needs a private key."""
        raise RSAUsageError("Operation {} requires private key".format(name))

    def _get_public(self):
        """Return the underlying public key object."""
        return self.key

    def get_public_bytes(self):
        """Return the public key as DER-encoded PKCS#1 RSAPublicKey bytes."""
        # The key embedded into MCUboot is in PKCS1 format.
        return self._get_public().public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.PKCS1)

    def get_public_pem(self):
        """Return the public key as PEM-encoded SubjectPublicKeyInfo bytes."""
        return self._get_public().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal, format):
        """Not available on a public key; always raises RSAUsageError."""
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
        """Not available on a public key; always raises RSAUsageError."""
        self._unsupported('export_private')

    def export_public(self, path):
        """Write the public key to ``path`` as PEM SubjectPublicKeyInfo."""
        pem = self.get_public_pem()
        with open(path, 'wb') as f:
            f.write(pem)

    def sig_type(self):
        """Return the signature scheme name, e.g. ``PKCS1_PSS_RSA2048_SHA256``."""
        return "PKCS1_PSS_RSA{}_SHA256".format(self.key_size())

    def sig_tlv(self):
        """Return the signature TLV name, ``"RSA<bits>"`` (e.g. ``RSA2048``)."""
        return "RSA{}".format(self.key_size())

    def sig_len(self):
        """Return the signature length in bytes (modulus bits / 8)."""
        # Integer division: a byte count must be an int; ``/`` would give
        # a float such as 256.0 under Python 3.
        return self.key_size() // 8

    def verify(self, signature, payload):
        """Verify an RSA-PSS/SHA-256 ``signature`` over ``payload``.

        Uses MGF1(SHA-256) with a 32-byte salt, the same parameters that
        RSA.sign uses. Returns None on success; ``cryptography`` raises
        InvalidSignature when the signature does not match.
        """
        k = self.key
        if isinstance(self.key, rsa.RSAPrivateKey):
            # A private-key wrapper verifies with its derived public key.
            k = self.key.public_key()
        return k.verify(signature=signature, data=payload,
                        padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
                        algorithm=SHA256())


class RSA(RSAPublic, PrivateBytesMixin):
    """
    Wrapper around an RSA private key, with imgtool support.
    """

    def __init__(self, key):
        """The key should be a private key from cryptography"""
        self.key = key

    @staticmethod
    def generate(key_size=2048):
        """Generate a new RSA private key of ``key_size`` bits (e = 65537).

        Raises RSAUsageError if ``key_size`` is not in RSA_KEY_SIZES.
        """
        if key_size not in RSA_KEY_SIZES:
            raise RSAUsageError("Key size {} is not supported by MCUboot"
                                .format(key_size))
        pk = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend())
        return RSA(pk)

    def _get_public(self):
        """Return the public half of the wrapped private key."""
        return self.key.public_key()

    @staticmethod
    def _der_integer_size(der, off, name):
        """Return the total encoded size (tag + length octets + value) of
        the DER INTEGER that starts at ``der[off]``.

        Accepts both the short length form (< 0x80) and the long forms
        (0x81, 0x82, ...). Raises RSAUsageError naming ``name`` if the
        element is not an INTEGER or does not fit inside ``der``.
        """
        msg = "Error parsing {} while minimizing".format(name)
        if off + 2 > len(der) or der[off] != 0x02:
            raise RSAUsageError(msg)
        first = der[off + 1]
        if first < 0x80:
            header, length = 2, first
        else:
            nbytes = first & 0x7f
            if nbytes == 0 or off + 2 + nbytes > len(der):
                raise RSAUsageError(msg)
            header = 2 + nbytes
            length = int.from_bytes(der[off + 2:off + header], 'big')
        if off + header + length > len(der):
            raise RSAUsageError(msg)
        return header + length

    def _build_minimal_rsa_privkey(self, der):
        '''
        Builds a new DER that only includes N/E/D/P/Q RSA parameters;
        standard DER private bytes provided by OpenSSL also includes
        CRT params (DP/DQ/QP) which can be removed.
        '''
        b = bytearray(der)
        # The outer SEQUENCE must use the two-byte long length form
        # (0x30 0x82 LL LL), because that length is rewritten in place below.
        if len(b) < 4 or b[0] != 0x30 or b[1] != 0x82:
            raise RSAUsageError("Error parsing header while minimizing")
        off = 4
        # Walk version, N, E, D, P, Q in order. Lengths are parsed
        # generically, so e.g. a D short enough to use the 0x81 form is
        # handled correctly.
        for name in ('version', 'N', 'E', 'D', 'P', 'Q'):
            off += self._der_integer_size(b, off, name)
        # Everything after Q (DP, DQ, QP) is dropped; adjust the outer
        # SEQUENCE length for the removed elements.
        content_len = off - 4
        b[2] = content_len >> 8
        b[3] = content_len & 0xff
        return b[:off]

    # Private-key serialization formats accepted by get_private_bytes();
    # presumably consulted by PrivateBytesMixin._get_private_bytes.
    _VALID_FORMATS = {
        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
    }
    _DEFAULT_FORMAT = 'openssl'

    def get_private_bytes(self, minimal, format):
        """Return the serialized private key.

        Serialization is delegated to PrivateBytesMixin._get_private_bytes.
        When ``minimal`` is true, the result (expected to be DER) has its
        CRT parameters stripped by _build_minimal_rsa_privkey.
        """
        _, priv = self._get_private_bytes(minimal, format, RSAUsageError)
        if minimal:
            priv = self._build_minimal_rsa_privkey(priv)
        return priv

    def export_private(self, path, passwd=None):
        """Write the private key to ``path`` as PEM PKCS#8.

        If ``passwd`` (bytes) is given, the key is encrypted with
        cryptography's BestAvailableEncryption; otherwise it is written
        unencrypted.
        """
        if passwd is None:
            enc = serialization.NoEncryption()
        else:
            enc = serialization.BestAvailableEncryption(passwd)
        pem = self.key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=enc)
        with open(path, 'wb') as f:
            f.write(pem)

    def sign(self, payload):
        """Return an RSA-PSS/SHA-256 signature of ``payload``."""
        # The verification code only allows the salt length to be the
        # same as the hash length, 32.
        return self.key.sign(
            data=payload,
            padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
            algorithm=SHA256())