1"""
2ECDSA key management
3"""
4
5# SPDX-License-Identifier: Apache-2.0
6import os.path
7import hashlib
8
9from cryptography.hazmat.backends import default_backend
10from cryptography.hazmat.primitives import serialization
11from cryptography.hazmat.primitives.asymmetric import ec
12from cryptography.hazmat.primitives.hashes import SHA256, SHA384
13
14from .general import KeyClass
15from .privatebytes import PrivateBytesMixin
16
17
class ECDSAUsageError(Exception):
    """Raised when an ECDSA key wrapper is misused or a key cannot be parsed."""
20
21
class ECDSAPublicKey(KeyClass):
    """
    Wrapper around an ECDSA public key.

    The wrapped ``key`` is stored as given.  Public-key serialization works
    whether ``key`` is a public key object or a private key object (the
    public half is derived on demand).  Operations that need private key
    material raise ECDSAUsageError in this class.
    """
    def __init__(self, key):
        self.key = key

    def _unsupported(self, name):
        """Raise ECDSAUsageError naming the operation that needs a private key."""
        raise ECDSAUsageError("Operation {} requires private key".format(name))

    def _get_public(self):
        """Return the public key object used for serialization.

        Only private key objects offer ``public_key()``; a public key
        object is already what we need and is returned unchanged.
        """
        if isinstance(self.key, ec.EllipticCurvePrivateKey):
            return self.key.public_key()
        return self.key

    def get_public_bytes(self):
        """Return the public key as DER-encoded SubjectPublicKeyInfo bytes."""
        # The key is embedded into MCUboot in "SubjectPublicKeyInfo" format
        return self._get_public().public_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_public_pem(self):
        """Return the public key as PEM-encoded SubjectPublicKeyInfo bytes."""
        return self._get_public().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def get_private_bytes(self, minimal, format):
        """Not available on a public key; always raises ECDSAUsageError."""
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
        """Not available on a public key; always raises ECDSAUsageError."""
        self._unsupported('export_private')

    def export_public(self, path):
        """Write the public key to the given file."""
        # Same PEM / SubjectPublicKeyInfo encoding as get_public_pem().
        pem = self.get_public_pem()
        with open(path, 'wb') as f:
            f.write(pem)
59
60
class ECDSAPrivateKey(PrivateBytesMixin):
    """
    Wrapper around an ECDSA private key.
    """
    def __init__(self, key):
        self.key = key

    def _build_minimal_ecdsa_privkey(self, der, format):
        '''
        Builds a new DER that only includes the EC private key, removing the
        public key that is added as an "optional" BITSTRING.

        der is the serialized private key and format is the
        serialization.PrivateFormat it was produced with.  Returns a
        bytearray.  For OpenSSH a warning is printed and the input is
        returned untrimmed.

        Raises ECDSAUsageError when the encoding does not have the expected
        tags at the fixed offsets used below, is too short to contain them,
        or when format is not one this method knows how to trim.

        NOTE(review): the fixed offsets look specific to one key size
        (they appear to match P-256 encodings) - confirm before relying on
        this for other curves; such keys are expected to fail the tag
        checks and raise.
        '''

        if format == serialization.PrivateFormat.OpenSSH:
            print(os.path.basename(__file__) +
                  ': Warning: --minimal is supported only for PKCS8 '
                  'or TraditionalOpenSSL formats')
            return bytearray(der)

        EXCEPTION_TEXT = "Error parsing ecdsa key. Please submit an issue!"
        if format == serialization.PrivateFormat.PKCS8:
            offset_PUB = 68  # where the context specific TLV starts (tag 0xA1)
            # Both the tag and its length byte must exist; report a
            # truncated encoding as a parse error instead of IndexError.
            if len(der) < offset_PUB + 2 or der[offset_PUB] != 0xa1:
                raise ECDSAUsageError(EXCEPTION_TEXT)
            len_PUB = der[offset_PUB + 1] + 2  # + 2 for 0xA1 0x44 bytes
            b = bytearray(der[:offset_PUB])  # remove the TLV with the PUB key
            # Every enclosing length field shrinks by the removed TLV size.
            offset_SEQ = 29
            if b[offset_SEQ] != 0x30:
                raise ECDSAUsageError(EXCEPTION_TEXT)
            b[offset_SEQ + 1] -= len_PUB
            offset_OCT_STR = 27
            if b[offset_OCT_STR] != 0x04:
                raise ECDSAUsageError(EXCEPTION_TEXT)
            b[offset_OCT_STR + 1] -= len_PUB
            if b[0] != 0x30 or b[1] != 0x81:
                raise ECDSAUsageError(EXCEPTION_TEXT)
            # as b[1] has bit7 set, the length is on b[2]
            b[2] -= len_PUB
            if b[2] < 0x80:
                # The length now fits the short form, so the 0x81 prefix
                # byte is dropped.
                del b[1]

        elif format == serialization.PrivateFormat.TraditionalOpenSSL:
            offset_PUB = 51
            if len(der) < offset_PUB + 2 or der[offset_PUB] != 0xA1:
                raise ECDSAUsageError(EXCEPTION_TEXT)
            len_PUB = der[offset_PUB + 1] + 2
            b = bytearray(der[0:offset_PUB])
            b[1] -= len_PUB

        else:
            # Previously this fell through to "return b" with b unbound.
            raise ECDSAUsageError(
                "Unsupported format for minimal ECDSA private key: {}"
                .format(format))

        return b

    _VALID_FORMATS = {
        'pkcs8': serialization.PrivateFormat.PKCS8,
        'openssl': serialization.PrivateFormat.TraditionalOpenSSL
    }
    _DEFAULT_FORMAT = 'pkcs8'

    def get_private_bytes(self, minimal, format):
        """Return the serialized private key, optionally without the
        embedded public key.

        Serialization is delegated to self._get_private_bytes (not defined
        in this class; presumably provided by PrivateBytesMixin), which
        returns the format name actually used together with the key bytes.
        When minimal is true the result is passed through
        _build_minimal_ecdsa_privkey.
        """
        format, priv = self._get_private_bytes(minimal,
                                               format, ECDSAUsageError)
        if minimal:
            priv = self._build_minimal_ecdsa_privkey(
                priv, self._VALID_FORMATS[format])
        return priv

    def export_private(self, path, passwd=None):
        """Write the private key to the given file, protecting it with
        the optional password.

        The key is written as PEM-encoded PKCS8.  With passwd=None the key
        is stored unencrypted; otherwise passwd is handed to
        serialization.BestAvailableEncryption.
        """
        if passwd is None:
            enc = serialization.NoEncryption()
        else:
            enc = serialization.BestAvailableEncryption(passwd)
        pem = self.key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=enc)
        with open(path, 'wb') as f:
            f.write(pem)
139
140
class ECDSA256P1Public(ECDSAPublicKey):
    """
    Wrapper around an ECDSA (p256) public key.

    Signatures are checked with ECDSA over SHA256.
    """
    def __init__(self, key):
        super().__init__(key)
        self.key = key

    def shortname(self):
        """Return the short name of this key type."""
        return "ecdsa"

    def sig_type(self):
        """Return the name of the signature scheme."""
        return "ECDSA256_SHA256"

    def sig_tlv(self):
        """Return the name of the TLV that carries the signature."""
        return "ECDSASIG"

    def sig_len(self):
        """Return the padded signature length in bytes."""
        # MCUboot releases before v1.5.0 expected every ECDSA signature to
        # be padded out to 72 bytes.  DER stores the two values as signed
        # integers, so the encoded size changes with the high bit of each
        # value; the padding simply appended zeros, which is not easy to
        # undo.
        #
        # Neither the signing code nor current MCUboot needs the padding
        # any more, but the full length is still reported here so padding
        # can be applied on request.
        return 72

    def verify(self, signature, payload):
        """Verify signature over payload; padded signatures are accepted."""
        # strip possible paddings added during sign: signature[1] is taken
        # as the length of the encoded body, plus the 2 leading bytes
        encoded_len = signature[1] + 2
        unpadded = signature[:encoded_len]
        if isinstance(self.key, ec.EllipticCurvePrivateKey):
            verifier = self.key.public_key()
        else:
            verifier = self.key
        return verifier.verify(signature=unpadded, data=payload,
                               signature_algorithm=ec.ECDSA(SHA256()))
180
181
class ECDSA256P1(ECDSAPrivateKey, ECDSA256P1Public):
    """
    Wrapper around an ECDSA (p256) private key.
    """
    # ECDSAPrivateKey has to come first in the bases.  The public wrapper
    # chain defines get_private_bytes() and export_private() as stubs that
    # raise ECDSAUsageError; listed first, those stubs would win in the MRO
    # and a private key could never be exported.

    def __init__(self, key):
        """key should be an instance of EllipticCurvePrivateKey"""
        super().__init__(key)
        self.key = key
        # When true, sign() pads the signature out to sig_len() bytes.
        self.pad_sig = False

    @staticmethod
    def generate():
        """Generate a new SECP256R1 private key and return it wrapped."""
        pk = ec.generate_private_key(
                ec.SECP256R1(),
                backend=default_backend())
        return ECDSA256P1(pk)

    def raw_sign(self, payload):
        """Return the actual signature"""
        return self.key.sign(
                data=payload,
                signature_algorithm=ec.ECDSA(SHA256()))

    def sign(self, payload):
        """Sign payload, zero-padding the result when pad_sig is set."""
        sig = self.raw_sign(payload)
        if self.pad_sig:
            # To make fixed length, pad with one or two zeros.
            sig += b'\000' * (self.sig_len() - len(sig))
        return sig
212
213
class ECDSA384P1Public(ECDSAPublicKey):
    """
    Wrapper around an ECDSA (p384) public key.

    Signatures are checked with ECDSA over SHA384.
    """
    def __init__(self, key):
        super().__init__(key)
        self.key = key

    def shortname(self):
        """Return the short name of this key type."""
        return "ecdsap384"

    def sig_type(self):
        """Return the name of the signature scheme."""
        return "ECDSA384_SHA384"

    def sig_tlv(self):
        """Return the name of the TLV that carries the signature."""
        return "ECDSASIG"

    def sig_len(self):
        """Return the padded signature length in bytes."""
        # MCUboot releases before v1.5.0 expected every ECDSA signature to
        # be padded out to a fixed length.  DER stores the two values as
        # signed integers, so the encoded size changes with the high bit of
        # each value; the padding simply appended zeros, which is not easy
        # to undo.
        #
        # Neither the signing code nor current MCUboot needs the padding
        # any more, but the full length is still reported here so padding
        # can be applied on request.
        #
        # NOTE(review): a DER-encoded P-384 signature can apparently reach
        # 104 bytes (two 49-byte INTEGERs plus headers) - confirm whether
        # 103 is intentional.
        return 103

    def verify(self, signature, payload):
        """Verify signature over payload; padded signatures are accepted."""
        # strip possible paddings added during sign: signature[1] is taken
        # as the length of the encoded body, plus the 2 leading bytes
        encoded_len = signature[1] + 2
        unpadded = signature[:encoded_len]
        if isinstance(self.key, ec.EllipticCurvePrivateKey):
            verifier = self.key.public_key()
        else:
            verifier = self.key
        return verifier.verify(signature=unpadded, data=payload,
                               signature_algorithm=ec.ECDSA(SHA384()))
253
254
class ECDSA384P1(ECDSAPrivateKey, ECDSA384P1Public):
    """
    Wrapper around an ECDSA (p384) private key.
    """
    # ECDSAPrivateKey has to come first in the bases.  The public wrapper
    # chain defines get_private_bytes() and export_private() as stubs that
    # raise ECDSAUsageError; listed first, those stubs would win in the MRO
    # and a private key could never be exported.

    def __init__(self, key):
        """key should be an instance of EllipticCurvePrivateKey"""
        super().__init__(key)
        self.key = key
        # When true, sign() pads the signature out to sig_len() bytes.
        self.pad_sig = False

    @staticmethod
    def generate():
        """Generate a new SECP384R1 private key and return it wrapped."""
        pk = ec.generate_private_key(
                ec.SECP384R1(),
                backend=default_backend())
        return ECDSA384P1(pk)

    def raw_sign(self, payload):
        """Return the actual signature"""
        return self.key.sign(
                data=payload,
                signature_algorithm=ec.ECDSA(SHA384()))

    def sign(self, payload):
        """Sign payload, zero-padding the result when pad_sig is set."""
        sig = self.raw_sign(payload)
        if self.pad_sig:
            # To make fixed length, pad with one or two zeros.
            sig += b'\000' * (self.sig_len() - len(sig))
        return sig
287