1"""
Tests for imgtool RSA keys: generation, PEM export/load, code emitters, and signing.
3"""
4
5# SPDX-License-Identifier: Apache-2.0
6
7import io
8import os
9import sys
10import tempfile
11import unittest
12
13from cryptography.exceptions import InvalidSignature
14from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
15from cryptography.hazmat.primitives.hashes import SHA256
16
# Set up sys.path so that the 'imgtool' package is importable.
18sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
19                                                '../..')))
20
21from imgtool.keys import load, RSA, RSAUsageError
22from imgtool.keys.rsa import RSA_KEY_SIZES
23
24
class KeyGeneration(unittest.TestCase):
    """Generation, PEM round-trip, emitter, and signature tests for RSA keys.

    Tests that iterate over ``RSA_KEY_SIZES`` run each size in its own
    ``subTest`` so a failure reports which key size triggered it.
    """

    # Password used to encrypt exported private keys in these tests.
    PASSWORD = b'secret'

    def setUp(self):
        # Each test gets its own scratch directory for exported PEM files.
        self.test_dir = tempfile.TemporaryDirectory()

    def tname(self, base):
        """Return the path of file *base* inside this test's scratch dir."""
        return os.path.join(self.test_dir.name, base)

    def tearDown(self):
        self.test_dir.cleanup()

    def _check_emitters(self, key):
        """Assert that *key* emits plausible C and Rust public-key source.

        Shared by the private-key and loaded-public-key emitter tests so both
        paths are held to the same checks.
        """
        ccode = io.StringIO()
        key.emit_c_public(ccode)
        self.assertIn("rsa_pub_key", ccode.getvalue())
        self.assertIn("rsa_pub_key_len", ccode.getvalue())

        rustcode = io.StringIO()
        key.emit_rust_public(rustcode)
        self.assertIn("RSA_PUB_KEY", rustcode.getvalue())

    def test_keygen(self):
        # Generating a key with an unsupported size must be rejected.
        with self.assertRaises(RSAUsageError):
            RSA.generate(key_size=1024)

        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                name1 = self.tname("keygen.pem")
                k = RSA.generate(key_size=key_size)
                k.export_private(name1, self.PASSWORD)

                # Loading the encrypted key without a password yields None.
                self.assertIsNone(load(name1))

                k2 = load(name1, self.PASSWORD)
                # The reloaded private key must be the one written out.
                self.assertEqual(k2.get_public_bytes(), k.get_public_bytes())

                pubname = self.tname('keygen-pub.pem')
                k2.export_public(pubname)
                pk2 = load(pubname)
                # The exported public key must round-trip unchanged too.
                self.assertEqual(pk2.get_public_bytes(), k.get_public_bytes())

                # We should be able to export the public key from the loaded
                # public key, but not the private key.
                pk2.export_public(self.tname('keygen-pub2.pem'))
                self.assertRaises(RSAUsageError, pk2.export_private,
                                  self.tname('keygen-priv2.pem'))

    def test_emit(self):
        """Basic sanity check on the code emitters."""
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                self._check_emitters(RSA.generate(key_size=key_size))

    def test_emit_pub(self):
        """Basic sanity check on the code emitters, from public key."""
        pubname = self.tname("public.pem")
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                k.export_public(pubname)
                self._check_emitters(load(pubname))

    def test_sig(self):
        # Padding/hash the signature is verified against; identical for
        # every key size, so built once.
        padding = PSS(mgf=MGF1(SHA256()), salt_length=32)
        buf = b'This is the message'
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                sig = k.sign(buf)
                public_key = k.key.public_key()

                # Verify independently with the cryptography library rather
                # than through imgtool itself.
                public_key.verify(signature=sig, data=buf,
                                  padding=padding, algorithm=SHA256())

                # Modify the message to make sure the signature fails.
                with self.assertRaises(InvalidSignature):
                    public_key.verify(signature=sig,
                                      data=b'This is thE message',
                                      padding=padding,
                                      algorithm=SHA256())
114
115
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
118