"""
Tests for RSA keys
"""

# SPDX-License-Identifier: Apache-2.0

import io
import os
import sys
import tempfile
import unittest

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
from cryptography.hazmat.primitives.hashes import SHA256

# Setup sys path so 'imgtool' is in it.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
                                                '../..')))

# These imports must come after the sys.path tweak above.
from imgtool.keys import load, RSA, RSAUsageError  # noqa: E402
from imgtool.keys.rsa import RSA_KEY_SIZES  # noqa: E402


class KeyGeneration(unittest.TestCase):
    """Exercise RSA key generation, export/load, code emitters and signing.

    Every size-dependent check runs once per entry in ``RSA_KEY_SIZES``
    inside a ``subTest`` so a failure names the offending key size and the
    remaining sizes are still checked.
    """

    def setUp(self):
        # Each test gets a private scratch directory for exported key files.
        self.test_dir = tempfile.TemporaryDirectory()

    def tname(self, base):
        """Return the path of file ``base`` inside this test's scratch dir."""
        return os.path.join(self.test_dir.name, base)

    def tearDown(self):
        self.test_dir.cleanup()

    def test_keygen(self):
        """Generate, export and reload keys of every supported size."""
        # Try generating a RSA key with non-supported size
        with self.assertRaises(RSAUsageError):
            RSA.generate(key_size=1024)

        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                name1 = self.tname("keygen.pem")
                k = RSA.generate(key_size=key_size)
                k.export_private(name1, b'secret')

                # Loading the password-protected key without a password
                # is expected to return None rather than a key object.
                self.assertIsNone(load(name1))

                k2 = load(name1, b'secret')

                pubname = self.tname('keygen-pub.pem')
                k2.export_public(pubname)
                pk2 = load(pubname)

                # We should be able to export the public key from the loaded
                # public key, but not the private key.
                pk2.export_public(self.tname('keygen-pub2.pem'))
                with self.assertRaises(RSAUsageError):
                    pk2.export_private(self.tname('keygen-priv2.pem'))

    def test_emit(self):
        """Basic sanity check on the code emitters."""
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)

                pubpem = io.StringIO()
                k.emit_public_pem(pubpem)
                self.assertIn("BEGIN PUBLIC KEY", pubpem.getvalue())
                self.assertIn("END PUBLIC KEY", pubpem.getvalue())

                ccode = io.StringIO()
                k.emit_c_public(ccode)
                self.assertIn("rsa_pub_key", ccode.getvalue())
                self.assertIn("rsa_pub_key_len", ccode.getvalue())

                hashccode = io.StringIO()
                k.emit_c_public_hash(hashccode)
                self.assertIn("rsa_pub_key_hash", hashccode.getvalue())
                self.assertIn("rsa_pub_key_hash_len", hashccode.getvalue())

                rustcode = io.StringIO()
                k.emit_rust_public(rustcode)
                self.assertIn("RSA_PUB_KEY", rustcode.getvalue())

                # raw data - bytes; assertGreater reports the actual length
                # on failure instead of a bare "False is not true".
                pubraw = io.BytesIO()
                k.emit_raw_public(pubraw)
                self.assertGreater(len(pubraw.getvalue()), 0)

                hashraw = io.BytesIO()
                k.emit_raw_public_hash(hashraw)
                self.assertGreater(len(hashraw.getvalue()), 0)

    def test_emit_pub(self):
        """Basic sanity check on the code emitters, from public key."""
        pubname = self.tname("public.pem")
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                k.export_public(pubname)

                k2 = load(pubname)

                ccode = io.StringIO()
                k2.emit_c_public(ccode)
                self.assertIn("rsa_pub_key", ccode.getvalue())
                self.assertIn("rsa_pub_key_len", ccode.getvalue())

                rustcode = io.StringIO()
                k2.emit_rust_public(rustcode)
                self.assertIn("RSA_PUB_KEY", rustcode.getvalue())

    def test_sig(self):
        """Signatures verify as RSA-PSS/SHA-256 and reject altered data."""
        # The same padding parameters are used for every verification below:
        # MGF1 over SHA-256 with a 32-byte salt.
        padding = PSS(mgf=MGF1(SHA256()), salt_length=32)
        buf = b'This is the message'

        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                sig = k.sign(buf)

                # The code doesn't have any verification, so verify this
                # manually.
                pub = k.key.public_key()
                pub.verify(
                    signature=sig,
                    data=buf,
                    padding=padding,
                    algorithm=SHA256())

                # Modify the message to make sure the signature fails.
                with self.assertRaises(InvalidSignature):
                    pub.verify(
                        signature=sig,
                        data=b'This is thE message',
                        padding=padding,
                        algorithm=SHA256())


if __name__ == '__main__':
    unittest.main()