1# Copyright 2018 Nordic Semiconductor ASA
2# Copyright 2017-2020 Linaro Limited
3# Copyright 2019-2021 Arm Limited
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11#     http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
19"""
20Image signing and management.
21"""
22
23from . import version as versmod
24from .boot_record import create_sw_component_data
25import click
26from enum import Enum
27from intelhex import IntelHex
28import hashlib
29import struct
30import os.path
31from .keys import rsa, ecdsa, x25519
32from cryptography.hazmat.primitives.asymmetric import ec, padding
33from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
34from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
35from cryptography.hazmat.primitives.kdf.hkdf import HKDF
36from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
37from cryptography.hazmat.backends import default_backend
38from cryptography.hazmat.primitives import hashes, hmac
39from cryptography.exceptions import InvalidSignature
40
# Magic number stored in the first 32-bit word of every image header.
IMAGE_MAGIC = 0x96f3b83d
# Size in bytes of the packed image header built by Image.add_header().
IMAGE_HEADER_SIZE = 32
# Lower-case file extensions (without the dot) recognised by load()/save().
BIN_EXT = "bin"
INTEL_HEX_EXT = "hex"
# Sector count used to size the trailer's status area when none is given.
DEFAULT_MAX_SECTORS = 128
# Alignment (in bytes) of the trailer flag fields and of a saved enc TLV.
MAX_ALIGN = 8
# Keys of the `dependencies` dict accepted by Image.create().
DEP_IMAGES_KEY = "images"
DEP_VERSIONS_KEY = "versions"
MAX_SW_TYPE_LENGTH = 12  # Maximum len(sw_type) accepted by Image.create()

# Image header flags.
# Image.add_header() sets ENCRYPTED_AES128/256, RAM_LOAD and ROM_FIXED;
# PIC and NON_BOOTABLE are defined here but not set by this module.
IMAGE_F = {
        'PIC':                   0x0000001,
        'ENCRYPTED_AES128':      0x0000004,
        'ENCRYPTED_AES256':      0x0000008,
        'NON_BOOTABLE':          0x0000010,
        'RAM_LOAD':              0x0000020,
        'ROM_FIXED':             0x0000100,
}

# Maps TLV names (as passed to TLV.add) to their one-byte type codes.
TLV_VALUES = {
        'KEYHASH': 0x01,
        'PUBKEY': 0x02,
        'SHA256': 0x10,
        'RSA2048': 0x20,
        'ECDSA224': 0x21,
        'ECDSA256': 0x22,
        'RSA3072': 0x23,
        'ED25519': 0x24,
        'ENCRSA2048': 0x30,
        'ENCKW': 0x31,
        'ENCEC256': 0x32,
        'ENCX25519': 0x33,
        'DEPENDENCY': 0x40,
        'SEC_CNT': 0x50,
        'BOOT_RECORD': 0x60,
}

# Size of one TLV record header: type (B), pad (B), payload length (H).
TLV_SIZE = 4
# Size of the TLV area info header: magic (H), total area length (H).
TLV_INFO_SIZE = 4
# Info-header magics of the unprotected and the protected TLV areas.
TLV_INFO_MAGIC = 0x6907
TLV_PROT_INFO_MAGIC = 0x6908

# Magic written as the very last bytes of the image trailer (see pad_to).
boot_magic = bytes([
    0x77, 0xc2, 0x95, 0xf3,
    0x60, 0xd2, 0xef, 0x7f,
    0x35, 0x52, 0x50, 0x0f,
    0x2c, 0xb6, 0x79, 0x80, ])

# Maps the `endian` option to the matching `struct` byte-order prefix.
STRUCT_ENDIAN_DICT = {
        'little': '<',
        'big':    '>'
}

# Outcomes reported (as the first tuple element) by Image.verify().
VerifyResult = Enum('VerifyResult',
                    """
                    OK INVALID_MAGIC INVALID_TLV_INFO_MAGIC INVALID_HASH
                    INVALID_SIGNATURE
                    """)
100
101
class TLV():
    """Collect TLV records and serialise them as a single TLV area.

    The serialised area begins with an info header (``magic`` followed by
    the total area length, header included).  The records follow, each
    prefixed by a type byte, a zero pad byte and a 16-bit payload length.
    """

    def __init__(self, endian, magic=TLV_INFO_MAGIC):
        # `endian` is a key of STRUCT_ENDIAN_DICT ('little' or 'big').
        self.endian = endian
        self.magic = magic
        self.buf = bytearray()

    def __len__(self):
        # Serialised size: records accumulated so far plus the info header.
        return len(self.buf) + TLV_INFO_SIZE

    def add(self, kind, payload):
        """
        Append one record.  ``kind`` is either a raw integer type code or
        a name looked up in TLV_VALUES.
        """
        prefix = STRUCT_ENDIAN_DICT[self.endian]
        type_code = kind if isinstance(kind, int) else TLV_VALUES[kind]
        record_header = struct.pack(prefix + 'BBH', type_code, 0,
                                    len(payload))
        self.buf += record_header
        self.buf += payload

    def get(self):
        """Return the serialised area, or empty bytes if nothing was added."""
        if not self.buf:
            return bytes()
        prefix = STRUCT_ENDIAN_DICT[self.endian]
        info_header = struct.pack(prefix + 'HH', self.magic, len(self))
        return info_header + bytes(self.buf)
129
130
class Image():
    """An MCUboot image: loaded from a file, finalised by create() (header,
    TLVs, optional signature/encryption) and written back by save()."""

    def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
                 pad_header=False, pad=False, confirm=False, align=1,
                 slot_size=0, max_sectors=DEFAULT_MAX_SECTORS,
                 overwrite_only=False, endian="little", load_addr=0,
                 rom_fixed=None, erased_val=None, save_enctlv=False,
                 security_counter=None):
        """Store the image layout options.

        ``load_addr`` and ``rom_fixed`` are mutually exclusive.
        ``erased_val`` is a string parsed with ``int(erased_val, 0)``
        (0xff when None).  ``security_counter`` may be None, an explicit
        value, or 'auto' to derive it from ``version``.

        :raises click.UsageError: if both load_addr and rom_fixed are set.
        """
        if load_addr and rom_fixed:
            raise click.UsageError("Can not set rom_fixed and load_addr at the same time")

        self.version = version or versmod.decode_version("0")
        self.header_size = header_size
        self.pad_header = pad_header
        self.pad = pad
        self.confirm = confirm
        self.align = align
        self.slot_size = slot_size
        self.max_sectors = max_sectors
        self.overwrite_only = overwrite_only
        self.endian = endian
        self.base_addr = None
        self.load_addr = 0 if load_addr is None else load_addr
        self.rom_fixed = rom_fixed
        self.erased_val = 0xff if erased_val is None else int(erased_val, 0)
        self.payload = []
        self.enckey = None
        self.save_enctlv = save_enctlv
        self.enctlv_len = 0

        if security_counter == 'auto':
            # Security counter has not been explicitly provided,
            # generate it from the version number
            self.security_counter = ((self.version.major << 24)
                                     + (self.version.minor << 16)
                                     + self.version.revision)
        else:
            self.security_counter = security_counter

    def __repr__(self):
        # Implicit string concatenation keeps source indentation out of the
        # produced text (a backslash inside the literal would embed it).
        return ("<Image version={}, header_size={}, security_counter={}, "
                "base_addr={}, load_addr={}, align={}, slot_size={}, "
                "max_sectors={}, overwrite_only={}, endian={}, format={}, "
                "payloadlen=0x{:x}>").format(
                    self.version,
                    self.header_size,
                    self.security_counter,
                    self.base_addr if self.base_addr is not None else "N/A",
                    self.load_addr,
                    self.align,
                    self.slot_size,
                    self.max_sectors,
                    self.overwrite_only,
                    self.endian,
                    self.__class__.__name__,
                    len(self.payload))

    def load(self, path):
        """Load the raw image from ``path`` into ``self.payload``.

        ``.hex`` files are parsed as Intel HEX (recording the base address);
        anything else is read as raw binary.  When ``pad_header`` is set,
        ``header_size`` erased bytes are prepended to leave room for the
        header that create() writes later.

        :raises click.UsageError: if the file does not exist, or if the
            header area is not zero-filled and no padding was requested.
        """
        ext = os.path.splitext(path)[1][1:].lower()
        try:
            if ext == INTEL_HEX_EXT:
                ih = IntelHex(path)
                self.payload = ih.tobinarray()
                self.base_addr = ih.minaddr()
            else:
                with open(path, 'rb') as f:
                    self.payload = f.read()
        except FileNotFoundError:
            raise click.UsageError("Input file not found")

        # Add the image header if needed.
        if self.pad_header and self.header_size > 0:
            if self.base_addr:
                # Adjust base_addr for new header
                self.base_addr -= self.header_size
            self.payload = bytes([self.erased_val] * self.header_size) + \
                self.payload

        self.check_header()

    def save(self, path, hex_addr=None):
        """Write the image to ``path``.

        A ``.hex`` extension produces Intel HEX placed at ``hex_addr`` (or
        the address recorded by load()); anything else is written as raw
        binary.  When ``pad`` is set, a trailer ending at ``slot_size`` is
        added.

        :raises click.UsageError: for HEX output without a known address,
            or (binary output) if the image does not fit ``slot_size``.
        """
        ext = os.path.splitext(path)[1][1:].lower()
        if ext == INTEL_HEX_EXT:
            # input was in binary format, but HEX needs to know the base addr
            if self.base_addr is None and hex_addr is None:
                raise click.UsageError("No address exists in input file "
                                       "neither was it provided by user")
            h = IntelHex()
            if hex_addr is not None:
                self.base_addr = hex_addr
            h.frombytes(bytes=self.payload, offset=self.base_addr)
            if self.pad:
                trailer_size = self._trailer_size(self.align, self.max_sectors,
                                                  self.overwrite_only,
                                                  self.enckey,
                                                  self.save_enctlv,
                                                  self.enctlv_len)
                trailer_addr = (self.base_addr + self.slot_size) - trailer_size
                # Named `trailer` (not `padding`) so it does not shadow the
                # imported cryptography `padding` module.
                trailer = bytearray([self.erased_val] *
                                    (trailer_size - len(boot_magic)))
                if self.confirm and not self.overwrite_only:
                    trailer[-MAX_ALIGN] = 0x01  # image_ok = 0x01
                trailer += boot_magic
                h.puts(trailer_addr, bytes(trailer))
            h.tofile(path, 'hex')
        else:
            if self.pad:
                self.pad_to(self.slot_size)
            with open(path, 'wb') as f:
                f.write(self.payload)

    def check_header(self):
        """Ensure the reserved header area is zero-filled unless it was
        padded by load().

        :raises click.UsageError: if it contains non-zero bytes.
        """
        if self.header_size > 0 and not self.pad_header:
            if any(v != 0 for v in self.payload[0:self.header_size]):
                raise click.UsageError("Header padding was not requested and "
                                       "image does not start with zeros")

    def check_trailer(self):
        """Ensure payload plus trailer fit in ``slot_size`` (if non-zero).

        :raises click.UsageError: if they do not fit.
        """
        if self.slot_size > 0:
            tsize = self._trailer_size(self.align, self.max_sectors,
                                       self.overwrite_only, self.enckey,
                                       self.save_enctlv, self.enctlv_len)
            spare = self.slot_size - (len(self.payload) + tsize)
            if spare < 0:
                msg = "Image size (0x{:x}) + trailer (0x{:x}) exceeds " \
                      "requested size 0x{:x}".format(
                          len(self.payload), tsize, self.slot_size)
                raise click.UsageError(msg)

    def ecies_hkdf(self, enckey, plainkey):
        """Wrap ``plainkey`` for ``enckey`` using ECIES.

        An ephemeral key pair (P-256 for ECDSA256P1Public keys, X25519
        otherwise) is exchanged with ``enckey``.  HKDF-SHA256 derives 48
        bytes (info b'MCUBoot_ECIES_v1'): the first 16 are an AES-128-CTR
        key (zero nonce) that encrypts ``plainkey``, and the remaining 32
        key an HMAC-SHA256 over the ciphertext.

        :returns: (cipherkey, mac, ephemeral public key bytes).
        """
        if isinstance(enckey, ecdsa.ECDSA256P1Public):
            newpk = ec.generate_private_key(ec.SECP256R1(), default_backend())
            shared = newpk.exchange(ec.ECDH(), enckey._get_public())
        else:
            newpk = X25519PrivateKey.generate()
            shared = newpk.exchange(enckey._get_public())
        derived_key = HKDF(
            algorithm=hashes.SHA256(), length=48, salt=None,
            info=b'MCUBoot_ECIES_v1', backend=default_backend()).derive(shared)
        encryptor = Cipher(algorithms.AES(derived_key[:16]),
                           modes.CTR(bytes([0] * 16)),
                           backend=default_backend()).encryptor()
        cipherkey = encryptor.update(plainkey) + encryptor.finalize()
        mac = hmac.HMAC(derived_key[16:], hashes.SHA256(),
                        backend=default_backend())
        mac.update(cipherkey)
        ciphermac = mac.finalize()
        if isinstance(enckey, ecdsa.ECDSA256P1Public):
            pubk = newpk.public_key().public_bytes(
                encoding=Encoding.X962,
                format=PublicFormat.UncompressedPoint)
        else:
            pubk = newpk.public_key().public_bytes(
                encoding=Encoding.Raw,
                format=PublicFormat.Raw)
        return cipherkey, ciphermac, pubk

    def create(self, key, public_key_format, enckey, dependencies=None,
               sw_type=None, custom_tlvs=None, encrypt_keylen=128):
        """Turn the loaded payload into a finished image.

        Writes the header, builds the protected TLV area (security counter,
        boot record, dependencies, custom TLVs) and the unprotected TLV area
        (SHA256 hash, KEYHASH/PUBKEY, signature, wrapped encryption key),
        optionally encrypting the image body.

        :param key: signing key, or None for an unsigned image.
        :param public_key_format: 'hash' to embed a KEYHASH TLV; any other
            value embeds the full public key (PUBKEY TLV).
        :param enckey: RSA, ECDSA P-256 or X25519 key used to wrap the
            image encryption key, or None to leave the image unencrypted.
        :param dependencies: dict holding parallel lists under
            DEP_IMAGES_KEY (image ids) and DEP_VERSIONS_KEY (versions).
        :param sw_type: software type string for the BOOT_RECORD TLV.
        :param custom_tlvs: mapping of TLV type (int or TLV_VALUES name)
            to payload bytes, stored in the protected area.
        :param encrypt_keylen: 256 selects AES-256; any other value AES-128.
        :raises click.UsageError: on an over-long sw_type, an unsupported
            encryption key type, or an image that does not fit the slot.
        """
        # Reject unsupported encryption keys before touching the payload.
        # Otherwise the body would be encrypted with a random key that is
        # never stored in any TLV, making the image undecryptable.
        if enckey is not None and not isinstance(
                enckey, (rsa.RSAPublic, ecdsa.ECDSA256P1Public,
                         x25519.X25519Public)):
            raise click.UsageError(
                "Unsupported encryption key type: {}".format(
                    type(enckey).__name__))

        self.enckey = enckey

        # Hash of the public key (all zeros for an unsigned image).
        if key is not None:
            pub = key.get_public_bytes()
            pubbytes = hashlib.sha256(pub).digest()
        else:
            pubbytes = bytes(hashlib.sha256().digest_size)

        protected_tlv_size = 0

        if self.security_counter is not None:
            # Size of the security counter TLV: header ('HH') + payload ('I')
            #                                   = 4 + 4 = 8 Bytes
            protected_tlv_size += TLV_SIZE + 4

        if sw_type is not None:
            if len(sw_type) > MAX_SW_TYPE_LENGTH:
                msg = "'{}' is too long ({} characters) for sw_type. Its " \
                      "maximum allowed length is {} characters.".format(
                          sw_type, len(sw_type), MAX_SW_TYPE_LENGTH)
                raise click.UsageError(msg)

            image_version = '{}.{}.{}'.format(self.version.major,
                                              self.version.minor,
                                              self.version.revision)

            # The image hash is computed over the image header, the image
            # itself and the protected TLV area. However, the boot record TLV
            # (which is part of the protected area) should contain this hash
            # before it is even calculated. For this reason the script fills
            # this field with zeros and the bootloader will insert the right
            # value later.
            digest = bytes(hashlib.sha256().digest_size)

            # Create CBOR encoded boot record
            boot_record = create_sw_component_data(sw_type, image_version,
                                                   "SHA256", digest,
                                                   pubbytes)

            protected_tlv_size += TLV_SIZE + len(boot_record)

        if dependencies is not None:
            # Size of a Dependency TLV = Header ('HH') + Payload('IBBHI')
            # = 4 + 12 = 16 Bytes
            dependencies_num = len(dependencies[DEP_IMAGES_KEY])
            protected_tlv_size += (dependencies_num * 16)

        if custom_tlvs is not None:
            for value in custom_tlvs.values():
                protected_tlv_size += TLV_SIZE + len(value)

        if protected_tlv_size != 0:
            # Add the size of the TLV info header
            protected_tlv_size += TLV_INFO_SIZE

        # At this point the image is already on the payload.
        # In encrypted mode, zero-pad it to a multiple of 16 bytes.
        if self.enckey is not None:
            pad_len = len(self.payload) % 16
            if pad_len > 0:
                pad = bytes(16 - pad_len)
                if isinstance(self.payload, bytes):
                    self.payload += pad
                else:
                    self.payload.extend(pad)

        # This adds the header to the payload as well. Any key length other
        # than 256 selects AES-128.
        self.add_header(enckey, protected_tlv_size,
                        256 if encrypt_keylen == 256 else 128)

        prot_tlv = TLV(self.endian, TLV_PROT_INFO_MAGIC)

        # Protected TLVs must be added first, because they are also included
        # in the hash calculation
        protected_tlv_off = None
        if protected_tlv_size != 0:

            e = STRUCT_ENDIAN_DICT[self.endian]

            if self.security_counter is not None:
                payload = struct.pack(e + 'I', self.security_counter)
                prot_tlv.add('SEC_CNT', payload)

            if sw_type is not None:
                prot_tlv.add('BOOT_RECORD', boot_record)

            if dependencies is not None:
                for i in range(dependencies_num):
                    dep_version = dependencies[DEP_VERSIONS_KEY][i]
                    payload = struct.pack(
                                    e + 'B3x' + 'BBHI',
                                    int(dependencies[DEP_IMAGES_KEY][i]),
                                    dep_version.major,
                                    dep_version.minor,
                                    dep_version.revision,
                                    dep_version.build
                                    )
                    prot_tlv.add('DEPENDENCY', payload)

            if custom_tlvs is not None:
                for tag, value in custom_tlvs.items():
                    prot_tlv.add(tag, value)

            protected_tlv_off = len(self.payload)
            self.payload += prot_tlv.get()

        tlv = TLV(self.endian)

        # Note that ecdsa wants to do the hashing itself, which means
        # we get to hash it twice.
        digest = hashlib.sha256(self.payload).digest()

        tlv.add('SHA256', digest)

        if key is not None:
            if public_key_format == 'hash':
                tlv.add('KEYHASH', pubbytes)
            else:
                tlv.add('PUBKEY', pub)

            # `sign` expects the full image payload (sha256 done internally),
            # while `sign_digest` expects only the digest of the payload
            if hasattr(key, 'sign'):
                sig = key.sign(bytes(self.payload))
            else:
                sig = key.sign_digest(digest)
            tlv.add(key.sig_tlv(), sig)

        # At this point the image was hashed + signed, we can remove the
        # protected TLVs from the payload. They are re-added after
        # encryption, so they stay in plain text.
        if protected_tlv_off is not None:
            self.payload = self.payload[:protected_tlv_off]

        if enckey is not None:
            plainkey = os.urandom(32 if encrypt_keylen == 256 else 16)

            if isinstance(enckey, rsa.RSAPublic):
                cipherkey = enckey._get_public().encrypt(
                    plainkey, padding.OAEP(
                        mgf=padding.MGF1(algorithm=hashes.SHA256()),
                        algorithm=hashes.SHA256(),
                        label=None))
                self.enctlv_len = len(cipherkey)
                tlv.add('ENCRSA2048', cipherkey)
            elif isinstance(enckey, (ecdsa.ECDSA256P1Public,
                                     x25519.X25519Public)):
                cipherkey, mac, pubk = self.ecies_hkdf(enckey, plainkey)
                enctlv = pubk + mac + cipherkey
                self.enctlv_len = len(enctlv)
                if isinstance(enckey, ecdsa.ECDSA256P1Public):
                    tlv.add('ENCEC256', enctlv)
                else:
                    tlv.add('ENCX25519', enctlv)

            # Encrypt everything after the header with AES-CTR (zero nonce).
            nonce = bytes([0] * 16)
            cipher = Cipher(algorithms.AES(plainkey), modes.CTR(nonce),
                            backend=default_backend())
            encryptor = cipher.encryptor()
            img = bytes(self.payload[self.header_size:])
            self.payload[self.header_size:] = \
                encryptor.update(img) + encryptor.finalize()

        self.payload += prot_tlv.get()
        self.payload += tlv.get()

        self.check_trailer()

    def add_header(self, enckey, protected_tlv_size, aes_length=128):
        """Install the image header over the first bytes of the payload.

        :param enckey: encryption key, or None; selects the ENCRYPTED flag.
        :param protected_tlv_size: size of the protected TLV area,
            including its info header (0 when there is none).
        :param aes_length: 128 sets ENCRYPTED_AES128; any other value sets
            ENCRYPTED_AES256.
        """

        flags = 0
        if enckey is not None:
            if aes_length == 128:
                flags |= IMAGE_F['ENCRYPTED_AES128']
            else:
                flags |= IMAGE_F['ENCRYPTED_AES256']
        if self.load_addr != 0:
            # Indicates that this image should be loaded into RAM
            # instead of run directly from flash.
            flags |= IMAGE_F['RAM_LOAD']
        if self.rom_fixed:
            flags |= IMAGE_F['ROM_FIXED']

        e = STRUCT_ENDIAN_DICT[self.endian]
        fmt = (e +
               # type ImageHdr struct {
               'I' +     # Magic    uint32
               'I' +     # LoadAddr uint32
               'H' +     # HdrSz    uint16
               'H' +     # PTLVSz   uint16
               'I' +     # ImgSz    uint32
               'I' +     # Flags    uint32
               'BBHI' +  # Vers     ImageVersion
               'I'       # Pad1     uint32
               )  # }
        assert struct.calcsize(fmt) == IMAGE_HEADER_SIZE
        header = struct.pack(fmt,
                IMAGE_MAGIC,
                self.rom_fixed or self.load_addr,
                self.header_size,
                protected_tlv_size,  # TLV Info header + Protected TLVs
                len(self.payload) - self.header_size,  # ImageSz
                flags,
                self.version.major,
                self.version.minor or 0,
                self.version.revision or 0,
                self.version.build or 0,
                0)  # Pad1
        self.payload = bytearray(self.payload)
        self.payload[:len(header)] = header

    def _trailer_size(self, write_size, max_sectors, overwrite_only, enckey,
                      save_enctlv, enctlv_len):
        """Return the trailer size in bytes for the given layout options.

        :raises click.BadParameter: if write_size is not 1, 2, 4 or 8
            (only checked when not overwrite_only).
        """
        # NOTE: should already be checked by the argument parser
        magic_size = 16
        if overwrite_only:
            return MAX_ALIGN * 2 + magic_size
        else:
            if write_size not in {1, 2, 4, 8}:
                raise click.BadParameter("Invalid alignment: {}".format(
                    write_size))
            m = DEFAULT_MAX_SECTORS if max_sectors is None else max_sectors
            trailer = m * 3 * write_size  # status area
            if enckey is not None:
                if save_enctlv:
                    # TLV saved by the bootloader is aligned (integer
                    # round-up to a multiple of MAX_ALIGN).
                    keylen = ((enctlv_len + MAX_ALIGN - 1) // MAX_ALIGN
                              * MAX_ALIGN)
                else:
                    # NOTE(review): 16 matches an AES-128 key; confirm the
                    # bootloader's trailer layout for AES-256 images.
                    keylen = 16
                trailer += keylen * 2  # encryption keys
            trailer += MAX_ALIGN * 4  # image_ok/copy_done/swap_info/swap_size
            trailer += magic_size
            return trailer

    def pad_to(self, size):
        """Pad the payload with erased bytes and a trailer up to ``size``.

        The trailer ends with ``boot_magic``; image_ok is set when
        ``confirm`` is requested outside overwrite-only mode.

        :raises click.UsageError: if payload plus trailer exceed ``size``.
        """
        tsize = self._trailer_size(self.align, self.max_sectors,
                                   self.overwrite_only, self.enckey,
                                   self.save_enctlv, self.enctlv_len)
        fill = size - (len(self.payload) + tsize)
        if fill < 0:
            # A negative fill would silently place the trailer at the wrong
            # offset instead of at the end of the slot.
            msg = "Image size (0x{:x}) + trailer (0x{:x}) exceeds " \
                  "requested size 0x{:x}".format(
                      len(self.payload), tsize, size)
            raise click.UsageError(msg)
        pbytes = bytearray([self.erased_val] * fill)
        pbytes += bytearray([self.erased_val] * (tsize - len(boot_magic)))
        if self.confirm and not self.overwrite_only:
            pbytes[-MAX_ALIGN] = 0x01  # image_ok = 0x01
        pbytes += boot_magic
        self.payload += pbytes

    @staticmethod
    def verify(imgfile, key):
        """Check the hash and (if ``key`` is given) the signature of an image.

        The image byte order is taken from the header magic, so both
        little- and big-endian images are accepted regardless of the host.

        :returns: (VerifyResult, version tuple, digest) on success, or
            (failure VerifyResult, None, None).
        """
        with open(imgfile, "rb") as f:
            b = f.read()

        if len(b) < IMAGE_HEADER_SIZE:
            return VerifyResult.INVALID_MAGIC, None, None

        # Pick the byte order whose interpretation yields the image magic.
        for e in STRUCT_ENDIAN_DICT.values():
            magic, _, header_size, _, img_size = struct.unpack(
                e + 'IIHHI', b[:16])
            if magic == IMAGE_MAGIC:
                break
        else:
            return VerifyResult.INVALID_MAGIC, None, None
        version = struct.unpack(e + 'BBHI', b[20:28])

        def read_tlv_info(off):
            # (magic, total length) of the TLV info header at `off`, or
            # None when the file is too short to hold one.
            info = b[off:off + TLV_INFO_SIZE]
            if len(info) < TLV_INFO_SIZE:
                return None
            return struct.unpack(e + 'HH', info)

        tlv_off = header_size + img_size
        info = read_tlv_info(tlv_off)
        if info is not None and info[0] == TLV_PROT_INFO_MAGIC:
            # Skip the protected area; it is covered by the hash below.
            tlv_off += info[1]
            info = read_tlv_info(tlv_off)

        if info is None or info[0] != TLV_INFO_MAGIC:
            return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None
        tlv_tot = info[1]

        # Hash and signature cover the header, the image body and the
        # protected TLVs, i.e. everything before the unprotected TLV area.
        hashed_len = tlv_off
        digest = hashlib.sha256(b[:hashed_len]).digest()
        sig_type = TLV_VALUES[key.sig_tlv()] if key is not None else None

        tlv_end = tlv_off + tlv_tot
        tlv_off += TLV_INFO_SIZE  # skip tlv info
        while tlv_off < tlv_end:
            tlv = b[tlv_off:tlv_off + TLV_SIZE]
            if len(tlv) < TLV_SIZE:
                break  # truncated TLV area
            tlv_type, _, tlv_len = struct.unpack(e + 'BBH', tlv)
            off = tlv_off + TLV_SIZE
            if tlv_type == TLV_VALUES["SHA256"]:
                if digest != b[off:off + tlv_len]:
                    return VerifyResult.INVALID_HASH, None, None
                if key is None:
                    return VerifyResult.OK, version, digest
            elif sig_type is not None and tlv_type == sig_type:
                tlv_sig = b[off:off + tlv_len]
                try:
                    if hasattr(key, 'verify'):
                        key.verify(tlv_sig, b[:hashed_len])
                    else:
                        key.verify_digest(tlv_sig, digest)
                except InvalidSignature:
                    pass  # continue to next TLV
                else:
                    return VerifyResult.OK, version, digest
            tlv_off += TLV_SIZE + tlv_len
        return VerifyResult.INVALID_SIGNATURE, None, None
606