1"""
2Tests for ECDSA keys
3"""
4
5# SPDX-License-Identifier: Apache-2.0
6
7import io
8import os.path
9import sys
10import tempfile
11import unittest
12
13from cryptography.exceptions import InvalidSignature
14from cryptography.hazmat.primitives.asymmetric import ec
15from cryptography.hazmat.primitives.hashes import SHA256
16
# Put the directory two levels above this file at the front of the module
# search path, so the ``imgtool`` import that follows can be resolved from
# the source tree.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)),
)
18
19from imgtool.keys import load, ECDSA256P1, ECDSAUsageError
20
class EcKeyGeneration(unittest.TestCase):
    """Tests for ECDSA256P1 key generation, PEM round-trips, emitters and signing."""

    def setUp(self):
        # Each test gets a private scratch directory for exported key files;
        # it is removed again in tearDown().
        self.test_dir = tempfile.TemporaryDirectory()

    def tname(self, base):
        """Return the path of file *base* inside this test's scratch directory."""
        return os.path.join(self.test_dir.name, base)

    def tearDown(self):
        self.test_dir.cleanup()

    @staticmethod
    def _read(path):
        """Return the raw bytes stored in the file at *path*."""
        with open(path, 'rb') as f:
            return f.read()

    def test_keygen(self):
        """Round-trip a generated key through password-protected and public PEM files."""
        name1 = self.tname("keygen.pem")
        k = ECDSA256P1.generate()
        k.export_private(name1, b'secret')

        # The private key was exported with a password, so loading it
        # without one is expected to yield None rather than a key.
        self.assertIsNone(load(name1))

        k2 = load(name1, b'secret')
        self.assertIsNotNone(k2)

        pubname = self.tname('keygen-pub.pem')
        k2.export_public(pubname)
        pk2 = load(pubname)
        self.assertIsNotNone(pk2)

        # The reloaded private key must carry the same public half as the
        # freshly generated one; otherwise the round-trip silently changed
        # the key.
        origpub = self.tname('keygen-orig-pub.pem')
        k.export_public(origpub)
        self.assertEqual(self._read(origpub), self._read(pubname))

        # We should be able to export the public key from the loaded
        # public key (and get the same key back), but not the private key.
        pubname2 = self.tname('keygen-pub2.pem')
        pk2.export_public(pubname2)
        self.assertEqual(self._read(pubname), self._read(pubname2))
        self.assertRaises(ECDSAUsageError,
                pk2.export_private, self.tname('keygen-priv2.pem'))

    def test_emit(self):
        """Basic sanity check on the code emitters."""
        k = ECDSA256P1.generate()

        pubpem = io.StringIO()
        k.emit_public_pem(pubpem)
        pem = pubpem.getvalue()
        self.assertIn("BEGIN PUBLIC KEY", pem)
        self.assertIn("END PUBLIC KEY", pem)

        ccode = io.StringIO()
        k.emit_c_public(ccode)
        c_src = ccode.getvalue()
        self.assertIn("ecdsa_pub_key", c_src)
        self.assertIn("ecdsa_pub_key_len", c_src)

        hashccode = io.StringIO()
        k.emit_c_public_hash(hashccode)
        hash_src = hashccode.getvalue()
        self.assertIn("ecdsa_pub_key_hash", hash_src)
        self.assertIn("ecdsa_pub_key_hash_len", hash_src)

        rustcode = io.StringIO()
        k.emit_rust_public(rustcode)
        self.assertIn("ECDSA_PUB_KEY", rustcode.getvalue())

        # raw data - bytes. assertGreater reports the actual length on
        # failure, unlike assertTrue(len(...) > 0).
        pubraw = io.BytesIO()
        k.emit_raw_public(pubraw)
        self.assertGreater(len(pubraw.getvalue()), 0)

        hashraw = io.BytesIO()
        k.emit_raw_public_hash(hashraw)
        self.assertGreater(len(hashraw.getvalue()), 0)

    def test_emit_pub(self):
        """Basic sanity check on the code emitters, from public key."""
        pubname = self.tname("public.pem")
        k = ECDSA256P1.generate()
        k.export_public(pubname)

        k2 = load(pubname)
        self.assertIsNotNone(k2)

        ccode = io.StringIO()
        k2.emit_c_public(ccode)
        c_src = ccode.getvalue()
        self.assertIn("ecdsa_pub_key", c_src)
        self.assertIn("ecdsa_pub_key_len", c_src)

        rustcode = io.StringIO()
        k2.emit_rust_public(rustcode)
        self.assertIn("ECDSA_PUB_KEY", rustcode.getvalue())

    def test_sig(self):
        """Check that raw_sign() output verifies, and fails for a tampered message."""
        k = ECDSA256P1.generate()
        buf = b'This is the message'
        sig = k.raw_sign(buf)

        # The code doesn't have any verification, so verify this
        # manually against the key's public half.
        pub = k.key.public_key()
        algorithm = ec.ECDSA(SHA256())
        pub.verify(signature=sig, data=buf, signature_algorithm=algorithm)

        # Modify the message to make sure the signature fails.
        self.assertRaises(InvalidSignature,
                pub.verify,
                signature=sig,
                data=b'This is thE message',
                signature_algorithm=algorithm)
118
# Allow running this module directly as a script; this runs the
# unittest test cases defined in it.
if __name__ == '__main__':
    unittest.main()
121