1"""
2Tests for RSA keys
3"""
4
5# SPDX-License-Identifier: Apache-2.0
6
7import io
8import os
9import sys
10import tempfile
11import unittest
12
13from cryptography.exceptions import InvalidSignature
14from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
15from cryptography.hazmat.primitives.hashes import SHA256
16
# Put the directory two levels above this file on sys.path so that 'imgtool' can be imported.
18sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__),
19                                                '../..')))
20
21from imgtool.keys import load, RSA, RSAUsageError
22from imgtool.keys.rsa import RSA_KEY_SIZES
23
24
class KeyGeneration(unittest.TestCase):
    """Tests for generating, exporting, reloading, emitting and signing RSA keys.

    Tests that iterate over ``RSA_KEY_SIZES`` run each size inside
    ``self.subTest`` so a failure names the offending key size and the
    remaining sizes are still exercised.
    """

    def setUp(self):
        # Per-test scratch directory for exported key files.
        self.test_dir = tempfile.TemporaryDirectory()

    def tname(self, base):
        """Return the path of file *base* inside this test's scratch directory."""
        return os.path.join(self.test_dir.name, base)

    def tearDown(self):
        self.test_dir.cleanup()

    @staticmethod
    def _capture(emit, buffer_cls=io.StringIO):
        """Call emitter *emit* with a fresh in-memory buffer and return its contents.

        :param emit: bound emitter method that writes to the file object it
            is given (e.g. ``key.emit_c_public``).
        :param buffer_cls: ``io.StringIO`` for text emitters, ``io.BytesIO``
            for the raw (bytes) emitters.
        :returns: everything the emitter wrote (``str`` or ``bytes``).
        """
        buf = buffer_cls()
        emit(buf)
        return buf.getvalue()

    def test_keygen(self):
        # Try generating a RSA key with non-supported size
        with self.assertRaises(RSAUsageError):
            RSA.generate(key_size=1024)

        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                name1 = self.tname("keygen.pem")
                k = RSA.generate(key_size=key_size)
                k.export_private(name1, b'secret')

                # Loading the password-protected key without a password
                # is expected to yield None.
                self.assertIsNone(load(name1))

                k2 = load(name1, b'secret')

                pubname = self.tname('keygen-pub.pem')
                k2.export_public(pubname)
                pk2 = load(pubname)

                # Both reloaded keys must describe the same public key as
                # the freshly generated one; compare via the C emitter,
                # which works on private and public-only keys alike.
                expected = self._capture(k.emit_c_public)
                self.assertEqual(self._capture(k2.emit_c_public), expected)
                self.assertEqual(self._capture(pk2.emit_c_public), expected)

                # We should be able to export the public key from the loaded
                # public key, but not the private key.
                pk2.export_public(self.tname('keygen-pub2.pem'))
                self.assertRaises(RSAUsageError, pk2.export_private,
                                  self.tname('keygen-priv2.pem'))

    def test_emit(self):
        """Basic sanity check on the code emitters."""
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)

                pem = self._capture(k.emit_public_pem)
                self.assertIn("BEGIN PUBLIC KEY", pem)
                self.assertIn("END PUBLIC KEY", pem)

                ccode = self._capture(k.emit_c_public)
                self.assertIn("rsa_pub_key", ccode)
                self.assertIn("rsa_pub_key_len", ccode)

                hashccode = self._capture(k.emit_c_public_hash)
                self.assertIn("rsa_pub_key_hash", hashccode)
                self.assertIn("rsa_pub_key_hash_len", hashccode)

                rustcode = self._capture(k.emit_rust_public)
                self.assertIn("RSA_PUB_KEY", rustcode)

                # The raw emitters write bytes, so capture into a BytesIO.
                pubraw = self._capture(k.emit_raw_public, io.BytesIO)
                self.assertGreater(len(pubraw), 0)

                hashraw = self._capture(k.emit_raw_public_hash, io.BytesIO)
                self.assertGreater(len(hashraw), 0)

    def test_emit_pub(self):
        """Basic sanity check on the code emitters, from public key."""
        pubname = self.tname("public.pem")
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                k.export_public(pubname)

                k2 = load(pubname)

                ccode = self._capture(k2.emit_c_public)
                self.assertIn("rsa_pub_key", ccode)
                self.assertIn("rsa_pub_key_len", ccode)

                rustcode = self._capture(k2.emit_rust_public)
                self.assertIn("RSA_PUB_KEY", rustcode)

    def test_sig(self):
        """Signatures verify as RSA-PSS/SHA-256 (32-byte salt) and reject tampering."""
        buf = b'This is the message'
        for key_size in RSA_KEY_SIZES:
            with self.subTest(key_size=key_size):
                k = RSA.generate(key_size=key_size)
                sig = k.sign(buf)

                # The code doesn't have any verification, so verify this
                # manually with the underlying cryptography public key,
                # using the padding parameters the signer is expected to use.
                public = k.key.public_key()
                pss = PSS(mgf=MGF1(SHA256()), salt_length=32)
                public.verify(
                    signature=sig,
                    data=buf,
                    padding=pss,
                    algorithm=SHA256())

                # Modify the message to make sure the signature fails.
                self.assertRaises(InvalidSignature,
                                  public.verify,
                                  signature=sig,
                                  data=b'This is thE message',
                                  padding=pss,
                                  algorithm=SHA256())
133
134
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()
137