1# Copyright 2018 Nordic Semiconductor ASA
2# Copyright 2017-2020 Linaro Limited
3# Copyright 2019-2024 Arm Limited
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11#     http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
19"""
20Image signing and management.
21"""
22
23from . import version as versmod
24from .boot_record import create_sw_component_data
25import click
26import copy
27from enum import Enum
28import array
29from intelhex import IntelHex
30import hashlib
31import array
32import os.path
33import struct
34from enum import Enum
35
36import click
37from cryptography.exceptions import InvalidSignature
38from cryptography.hazmat.backends import default_backend
39from cryptography.hazmat.primitives import hashes, hmac
40from cryptography.hazmat.primitives.asymmetric import ec, padding
41from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
42from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
43from cryptography.hazmat.primitives.kdf.hkdf import HKDF
44from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
45from intelhex import IntelHex
46
47from . import version as versmod, keys
48from .boot_record import create_sw_component_data
49from .keys import rsa, ecdsa, x25519
50
51from collections import namedtuple
52
# Magic number opening every image header, and the packed header size.
IMAGE_MAGIC = 0x96f3b83d
IMAGE_HEADER_SIZE = 32

# File extensions (without the leading dot) for raw binary and Intel HEX.
BIN_EXT = "bin"
INTEL_HEX_EXT = "hex"

# Trailer defaults: status-area sector count and the default trailer alignment.
DEFAULT_MAX_SECTORS = 128
DEFAULT_MAX_ALIGN = 8

# Keys of the dependencies mapping handed to Image.create().
DEP_IMAGES_KEY = "images"
DEP_VERSIONS_KEY = "versions"

# Longest sw_type string accepted for the boot record, in bytes.
MAX_SW_TYPE_LENGTH = 12
62
# Image header flag bits, OR-ed together into the header's Flags word.
IMAGE_F = {
    'PIC':                  1 << 0,
    'ENCRYPTED_AES128':     1 << 2,
    'ENCRYPTED_AES256':     1 << 3,
    'NON_BOOTABLE':         1 << 4,
    'RAM_LOAD':             1 << 5,
    'ROM_FIXED':            1 << 8,
    'COMPRESSED_LZMA1':     1 << 9,
    'COMPRESSED_LZMA2':     1 << 10,
    'COMPRESSED_ARM_THUMB': 1 << 11,
}
75
# Standard TLV type codes, keyed by the names accepted by TLV.add().
TLV_VALUES = dict(
    # Key identification
    KEYHASH=0x01,
    PUBKEY=0x02,
    # Image hashes
    SHA256=0x10,
    SHA384=0x11,
    SHA512=0x12,
    # Signatures (SIG_PURE marks a pure signature)
    RSA2048=0x20,
    ECDSASIG=0x22,
    RSA3072=0x23,
    ED25519=0x24,
    SIG_PURE=0x25,
    # Encryption key material
    ENCRSA2048=0x30,
    ENCKW=0x31,
    ENCEC256=0x32,
    ENCX25519=0x33,
    # Protected metadata
    DEPENDENCY=0x40,
    SEC_CNT=0x50,
    BOOT_RECORD=0x60,
    # Compression-related
    DECOMP_SIZE=0x70,
    DECOMP_SHA=0x71,
    DECOMP_SIGNATURE=0x72,
    COMP_DEC_SIZE=0x73,
)
99
# Sizes in bytes of one TLV record header and of the TLV info header.
TLV_SIZE = 4
TLV_INFO_SIZE = 4
# Magic values opening the unprotected and the protected TLV areas.
TLV_INFO_MAGIC = 0x6907
TLV_PROT_INFO_MAGIC = 0x6908

# Inclusive range of type codes accepted for vendor/custom TLVs.
TLV_VENDOR_RES_MIN = 0x00a0
TLV_VENDOR_RES_MAX = 0xfffe

# Maps the user-facing endianness name to a struct byte-order prefix.
STRUCT_ENDIAN_DICT = {'little': '<', 'big': '>'}


class VerifyResult(Enum):
    """Result codes for image verification (presumably from Image.verify)."""
    OK = 1
    INVALID_MAGIC = 2
    INVALID_TLV_INFO_MAGIC = 3
    INVALID_HASH = 4
    INVALID_SIGNATURE = 5
    KEY_MISMATCH = 6
116
117
def align_up(num, align):
    """Round *num* up to the next multiple of *align*.

    The bit-mask arithmetic is only valid when *align* is a positive power
    of two.

    Raises:
        ValueError: if *align* is not a positive power of two.
    """
    # An explicit check instead of `assert`: asserts are stripped under
    # `python -O`, after which a bad alignment silently gives a wrong result.
    if align <= 0 or align & (align - 1):
        raise ValueError(
            "alignment must be a positive power of two, got {}".format(align))
    return (num + (align - 1)) & ~(align - 1)
121
122
class TLV():
    """Accumulates TLV records and serializes them behind a TLV info header."""

    def __init__(self, endian, magic=TLV_INFO_MAGIC):
        # endian is a STRUCT_ENDIAN_DICT key ('little' or 'big'); magic
        # selects the unprotected or protected TLV area.
        self.magic = magic
        self.buf = bytearray()
        self.endian = endian

    def __len__(self):
        """Return the serialized size: info header plus all records."""
        return TLV_INFO_SIZE + len(self.buf)

    def add(self, kind, payload):
        """
        Add a TLV record.

        kind is either a name from TLV_VALUES or an int vendor type in
        [TLV_VENDOR_RES_MIN, TLV_VENDOR_RES_MAX]; payload is the record data.

        Raises click.UsageError for an out-of-range vendor type and KeyError
        for an unknown name.
        """
        e = STRUCT_ENDIAN_DICT[self.endian]
        if isinstance(kind, int):
            if not TLV_VENDOR_RES_MIN <= kind <= TLV_VENDOR_RES_MAX:
                msg = "Invalid custom TLV type value '0x{:04x}', allowed " \
                      "value should be between 0x{:04x} and 0x{:04x}".format(
                        kind, TLV_VENDOR_RES_MIN, TLV_VENDOR_RES_MAX)
                raise click.UsageError(msg)
            tlv_type = kind
        else:
            tlv_type = TLV_VALUES[kind]
        # The type field is 16 bits wide (vendor types go up to 0xfffe), so
        # every record packs it as 'H'. The former 'BBH' layout for standard
        # kinds put the type in the wrong byte for big-endian images; the
        # little-endian bytes are identical either way.
        self.buf += struct.pack(e + 'HH', tlv_type, len(payload))
        self.buf += payload

    def get(self):
        """Return the info header followed by all records, or b'' if empty."""
        if len(self.buf) == 0:
            return bytes()
        e = STRUCT_ENDIAN_DICT[self.endian]
        header = struct.pack(e + 'HH', self.magic, len(self))
        return header + bytes(self.buf)
155
156
# (sha, alg): the user-facing size string and the matching hashlib constructor.
SHAAndAlgT = namedtuple('SHAAndAlgT', 'sha alg')

# Maps each hash TLV type code to its SHAAndAlgT entry.
TLV_SHA_TO_SHA_AND_ALG = {
    TLV_VALUES['SHA256']: SHAAndAlgT(sha='256', alg=hashlib.sha256),
    TLV_VALUES['SHA384']: SHAAndAlgT(sha='384', alg=hashlib.sha384),
    TLV_VALUES['SHA512']: SHAAndAlgT(sha='512', alg=hashlib.sha512),
}
164
165
# Maps each accepted --sha value to (hash constructor, TLV_VALUES name).
# 'auto' resolves to SHA-256 when no key narrows the choice.
USER_SHA_TO_ALG_AND_TLV = {
    'auto': (hashlib.sha256, 'SHA256'),
    '256': (hashlib.sha256, 'SHA256'),
    '384': (hashlib.sha384, 'SHA384'),
    '512': (hashlib.sha512, 'SHA512'),
}
172
173
def is_sha_tlv(tlv):
    """Return True if *tlv* is the type code of one of the hash TLVs."""
    return tlv in TLV_SHA_TO_SHA_AND_ALG
176
177
def tlv_sha_to_sha(tlv):
    """Return the size string ('256', '384' or '512') of hash TLV type *tlv*.

    Raises KeyError when *tlv* is not a hash TLV type.
    """
    entry = TLV_SHA_TO_SHA_AND_ALG[tlv]
    return entry.sha
180
181
# Hash sizes each key type may be combined with, looked up by exact
# type(key) (subclasses are not matched). The first entry is the one
# chosen for --sha auto.
ALLOWED_KEY_SHA = {
    keys.ECDSA384P1: ['384'],
    keys.ECDSA384P1Public: ['384'],
    keys.ECDSA256P1: ['256'],
    keys.RSA: ['256'],
    keys.RSAPublic: ['256'],
    # These two are set to 256 for compatibility; the right one would be 512.
    keys.Ed25519: ['256', '512'],
    keys.X25519: ['256', '512'],
}

# Same as ALLOWED_KEY_SHA, but used when a pure signature is requested.
ALLOWED_PURE_KEY_SHA = {
    keys.Ed25519: ['512'],
}

# Signature TLV types allowed with pure signatures.
# NOTE(review): presumably consulted during verification -- confirm.
ALLOWED_PURE_SIG_TLVS = [TLV_VALUES['ED25519']]
201
def key_and_user_sha_to_alg_and_tlv(key, user_sha, is_pure = False):
    """Match a key and the user-requested sha to a hash algorithm and TLV name.

    Returns a ``(hash constructor, TLV name)`` tuple from
    USER_SHA_TO_ALG_AND_TLV. The function either succeeds or fails
    outright, because an incorrect pair here prevents any further work.

    Args:
        key: key object restricting the allowed hashes, or None to accept
            whatever the user selected.
        user_sha: a USER_SHA_TO_ALG_AND_TLV key ('auto', '256', '384', '512').
        is_pure: use ALLOWED_PURE_KEY_SHA. The user selection is then ignored
            and the first allowed hash is used.

    Raises:
        click.UsageError: if user_sha is unknown (no key), the key type has
            no allowed hashes, or user_sha is not allowed for the key.
    """
    if key is None:
        # Without a key, any known --sha value is accepted.
        try:
            return USER_SHA_TO_ALG_AND_TLV[user_sha]
        except KeyError:
            raise click.UsageError("Unknown --sha {}; expected one of {}"
                                   .format(user_sha,
                                           list(USER_SHA_TO_ALG_AND_TLV))) from None

    # With a key, restrict the hash to what that key type allows.
    allowed_key_sha = ALLOWED_PURE_KEY_SHA if is_pure else ALLOWED_KEY_SHA
    try:
        allowed = allowed_key_sha[type(key)]
    except KeyError:
        raise click.UsageError("Could not find allowed hash algorithms for {}"
                               .format(type(key))) from None

    # Pure enforces auto, and the user selection is ignored.
    if user_sha == 'auto' or is_pure:
        return USER_SHA_TO_ALG_AND_TLV[allowed[0]]

    if user_sha in allowed:
        return USER_SHA_TO_ALG_AND_TLV[user_sha]

    raise click.UsageError("Key {} can not be used with --sha {}; allowed sha are one of {}"
                           .format(key.sig_type(), user_sha, allowed))
233
234
def get_digest(tlv_type, hash_region):
    """Hash *hash_region* with the algorithm of hash TLV *tlv_type*.

    Raises KeyError when *tlv_type* is not a hash TLV type.
    """
    hash_ctor = TLV_SHA_TO_SHA_AND_ALG[tlv_type].alg
    return hash_ctor(hash_region).digest()
240
241
def tlv_matches_key_type(tlv_type, key):
    """Return True if *key* may be used with the hash TLV *tlv_type*.

    Only the success or failure of key_and_user_sha_to_alg_and_tlv matters
    here; its result is discarded.
    """
    try:
        key_and_user_sha_to_alg_and_tlv(key, tlv_sha_to_sha(tlv_type))
    except (click.UsageError, KeyError):
        # UsageError: the key type does not allow this hash (or is unknown).
        # KeyError: tlv_type is not a hash TLV type.
        # Anything else (including KeyboardInterrupt) now propagates instead
        # of being swallowed by a bare `except:`.
        return False
    return True
254
255
256class Image:
257
258    def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
259                 pad_header=False, pad=False, confirm=False, align=1,
260                 slot_size=0, max_sectors=DEFAULT_MAX_SECTORS,
261                 overwrite_only=False, endian="little", load_addr=0,
262                 rom_fixed=None, erased_val=None, save_enctlv=False,
263                 security_counter=None, max_align=None,
264                 non_bootable=False):
265
266        if load_addr and rom_fixed:
267            raise click.UsageError("Can not set rom_fixed and load_addr at the same time")
268
269        self.image_hash = None
270        self.image_size = None
271        self.signature = None
272        self.version = version or versmod.decode_version("0")
273        self.header_size = header_size
274        self.pad_header = pad_header
275        self.pad = pad
276        self.confirm = confirm
277        self.align = align
278        self.slot_size = slot_size
279        self.max_sectors = max_sectors
280        self.overwrite_only = overwrite_only
281        self.endian = endian
282        self.base_addr = None
283        self.load_addr = 0 if load_addr is None else load_addr
284        self.rom_fixed = rom_fixed
285        self.erased_val = 0xff if erased_val is None else int(erased_val, 0)
286        self.payload = []
287        self.infile_data = []
288        self.enckey = None
289        self.save_enctlv = save_enctlv
290        self.enctlv_len = 0
291        self.max_align = max(DEFAULT_MAX_ALIGN, align) if max_align is None else int(max_align)
292        self.non_bootable = non_bootable
293
294        if self.max_align == DEFAULT_MAX_ALIGN:
295            self.boot_magic = bytes([
296                0x77, 0xc2, 0x95, 0xf3,
297                0x60, 0xd2, 0xef, 0x7f,
298                0x35, 0x52, 0x50, 0x0f,
299                0x2c, 0xb6, 0x79, 0x80, ])
300        else:
301            lsb = self.max_align & 0x00ff
302            msb = (self.max_align & 0xff00) >> 8
303            align = bytes([msb, lsb]) if self.endian == "big" else bytes([lsb, msb])
304            self.boot_magic = align + bytes([0x2d, 0xe1,
305                                             0x5d, 0x29, 0x41, 0x0b,
306                                             0x8d, 0x77, 0x67, 0x9c,
307                                             0x11, 0x0f, 0x1f, 0x8a, ])
308
309        if security_counter == 'auto':
310            # Security counter has not been explicitly provided,
311            # generate it from the version number
312            self.security_counter = ((self.version.major << 24)
313                                     + (self.version.minor << 16)
314                                     + self.version.revision)
315        else:
316            self.security_counter = security_counter
317
318    def __repr__(self):
319        return "<Image version={}, header_size={}, security_counter={}, \
320                base_addr={}, load_addr={}, align={}, slot_size={}, \
321                max_sectors={}, overwrite_only={}, endian={} format={}, \
322                payloadlen=0x{:x}>".format(
323                    self.version,
324                    self.header_size,
325                    self.security_counter,
326                    self.base_addr if self.base_addr is not None else "N/A",
327                    self.load_addr,
328                    self.align,
329                    self.slot_size,
330                    self.max_sectors,
331                    self.overwrite_only,
332                    self.endian,
333                    self.__class__.__name__,
334                    len(self.payload))
335
336    def load(self, path):
337        """Load an image from a given file"""
338        ext = os.path.splitext(path)[1][1:].lower()
339        try:
340            if ext == INTEL_HEX_EXT:
341                ih = IntelHex(path)
342                self.infile_data = ih.tobinarray()
343                self.payload = copy.copy(self.infile_data)
344                self.base_addr = ih.minaddr()
345            else:
346                with open(path, 'rb') as f:
347                    self.infile_data = f.read()
348                    self.payload = copy.copy(self.infile_data)
349        except FileNotFoundError:
350            raise click.UsageError("Input file not found")
351        self.image_size = len(self.payload)
352
353        # Add the image header if needed.
354        if self.pad_header and self.header_size > 0:
355            if self.base_addr:
356                # Adjust base_addr for new header
357                self.base_addr -= self.header_size
358            self.payload = bytes([self.erased_val] * self.header_size) + \
359                self.payload
360
361        self.check_header()
362
363    def load_compressed(self, data, compression_header):
364        """Load an image from buffer"""
365        self.payload = compression_header + data
366        self.image_size = len(self.payload)
367
368        # Add the image header if needed.
369        if self.pad_header and self.header_size > 0:
370            if self.base_addr:
371                # Adjust base_addr for new header
372                self.base_addr -= self.header_size
373            self.payload = bytes([self.erased_val] * self.header_size) + \
374                self.payload
375
376        self.check_header()
377
378    def save(self, path, hex_addr=None):
379        """Save an image from a given file"""
380        ext = os.path.splitext(path)[1][1:].lower()
381        if ext == INTEL_HEX_EXT:
382            # input was in binary format, but HEX needs to know the base addr
383            if self.base_addr is None and hex_addr is None:
384                raise click.UsageError("No address exists in input file "
385                                       "neither was it provided by user")
386            h = IntelHex()
387            if hex_addr is not None:
388                self.base_addr = hex_addr
389            h.frombytes(bytes=self.payload, offset=self.base_addr)
390            if self.pad:
391                trailer_size = self._trailer_size(self.align, self.max_sectors,
392                                                  self.overwrite_only,
393                                                  self.enckey,
394                                                  self.save_enctlv,
395                                                  self.enctlv_len)
396                trailer_addr = (self.base_addr + self.slot_size) - trailer_size
397                if self.confirm and not self.overwrite_only:
398                    magic_align_size = align_up(len(self.boot_magic),
399                                                self.max_align)
400                    image_ok_idx = -(magic_align_size + self.max_align)
401                    flag = bytearray([self.erased_val] * self.max_align)
402                    flag[0] = 0x01  # image_ok = 0x01
403                    h.puts(trailer_addr + trailer_size + image_ok_idx,
404                           bytes(flag))
405                h.puts(trailer_addr + (trailer_size - len(self.boot_magic)),
406                       bytes(self.boot_magic))
407            h.tofile(path, 'hex')
408        else:
409            if self.pad:
410                self.pad_to(self.slot_size)
411            with open(path, 'wb') as f:
412                f.write(self.payload)
413
414    def check_header(self):
415        if self.header_size > 0 and not self.pad_header:
416            if any(v != 0 for v in self.payload[0:self.header_size]):
417                raise click.UsageError("Header padding was not requested and "
418                                       "image does not start with zeros")
419
420    def check_trailer(self):
421        if self.slot_size > 0:
422            tsize = self._trailer_size(self.align, self.max_sectors,
423                                       self.overwrite_only, self.enckey,
424                                       self.save_enctlv, self.enctlv_len)
425            padding = self.slot_size - (len(self.payload) + tsize)
426            if padding < 0:
427                msg = "Image size (0x{:x}) + trailer (0x{:x}) exceeds " \
428                      "requested size 0x{:x}".format(
429                          len(self.payload), tsize, self.slot_size)
430                raise click.UsageError(msg)
431
432    def ecies_hkdf(self, enckey, plainkey):
433        if isinstance(enckey, ecdsa.ECDSA256P1Public):
434            newpk = ec.generate_private_key(ec.SECP256R1(), default_backend())
435            shared = newpk.exchange(ec.ECDH(), enckey._get_public())
436        else:
437            newpk = X25519PrivateKey.generate()
438            shared = newpk.exchange(enckey._get_public())
439        derived_key = HKDF(
440            algorithm=hashes.SHA256(), length=48, salt=None,
441            info=b'MCUBoot_ECIES_v1', backend=default_backend()).derive(shared)
442        encryptor = Cipher(algorithms.AES(derived_key[:16]),
443                           modes.CTR(bytes([0] * 16)),
444                           backend=default_backend()).encryptor()
445        cipherkey = encryptor.update(plainkey) + encryptor.finalize()
446        mac = hmac.HMAC(derived_key[16:], hashes.SHA256(),
447                        backend=default_backend())
448        mac.update(cipherkey)
449        ciphermac = mac.finalize()
450        if isinstance(enckey, ecdsa.ECDSA256P1Public):
451            pubk = newpk.public_key().public_bytes(
452                encoding=Encoding.X962,
453                format=PublicFormat.UncompressedPoint)
454        else:
455            pubk = newpk.public_key().public_bytes(
456                encoding=Encoding.Raw,
457                format=PublicFormat.Raw)
458        return cipherkey, ciphermac, pubk
459
460    def create(self, key, public_key_format, enckey, dependencies=None,
461               sw_type=None, custom_tlvs=None, compression_tlvs=None,
462               compression_type=None, encrypt_keylen=128, clear=False,
463               fixed_sig=None, pub_key=None, vector_to_sign=None,
464               user_sha='auto', is_pure=False, keep_comp_size=False, dont_encrypt=False):
465        self.enckey = enckey
466
467        # key decides on sha, then pub_key; of both are none default is used
468        check_key = key if key is not None else pub_key
469        hash_algorithm, hash_tlv = key_and_user_sha_to_alg_and_tlv(check_key, user_sha, is_pure)
470
471        # Calculate the hash of the public key
472        if key is not None:
473            pub = key.get_public_bytes()
474            sha = hash_algorithm()
475            sha.update(pub)
476            pubbytes = sha.digest()
477        elif pub_key is not None:
478            if hasattr(pub_key, 'sign'):
479                print(os.path.basename(__file__) + ": sign the payload")
480            pub = pub_key.get_public_bytes()
481            sha = hash_algorithm()
482            sha.update(pub)
483            pubbytes = sha.digest()
484        else:
485            pubbytes = bytes(hashlib.sha256().digest_size)
486
487        protected_tlv_size = 0
488
489        if self.security_counter is not None:
490            # Size of the security counter TLV: header ('HH') + payload ('I')
491            #                                   = 4 + 4 = 8 Bytes
492            protected_tlv_size += TLV_SIZE + 4
493
494        if sw_type is not None:
495            if len(sw_type) > MAX_SW_TYPE_LENGTH:
496                msg = "'{}' is too long ({} characters) for sw_type. Its " \
497                      "maximum allowed length is 12 characters.".format(
498                       sw_type, len(sw_type))
499                raise click.UsageError(msg)
500
501            image_version = (str(self.version.major) + '.'
502                             + str(self.version.minor) + '.'
503                             + str(self.version.revision))
504
505            # The image hash is computed over the image header, the image
506            # itself and the protected TLV area. However, the boot record TLV
507            # (which is part of the protected area) should contain this hash
508            # before it is even calculated. For this reason the script fills
509            # this field with zeros and the bootloader will insert the right
510            # value later.
511            digest = bytes(hash_algorithm().digest_size)
512
513            # Create CBOR encoded boot record
514            boot_record = create_sw_component_data(sw_type, image_version,
515                                                   hash_tlv, digest,
516                                                   pubbytes)
517
518            protected_tlv_size += TLV_SIZE + len(boot_record)
519
520        if dependencies is not None:
521            # Size of a Dependency TLV = Header ('HH') + Payload('IBBHI')
522            # = 4 + 12 = 16 Bytes
523            dependencies_num = len(dependencies[DEP_IMAGES_KEY])
524            protected_tlv_size += (dependencies_num * 16)
525
526        if keep_comp_size:
527            compression_tlvs["COMP_DEC_SIZE"] = struct.pack(
528                self.get_struct_endian() + 'L', self.image_size)
529        if compression_tlvs is not None:
530            for value in compression_tlvs.values():
531                protected_tlv_size += TLV_SIZE + len(value)
532        if custom_tlvs is not None:
533            for value in custom_tlvs.values():
534                protected_tlv_size += TLV_SIZE + len(value)
535
536        if protected_tlv_size != 0:
537            # Add the size of the TLV info header
538            protected_tlv_size += TLV_INFO_SIZE
539
540        # At this point the image is already on the payload
541        #
542        # This adds the padding if image is not aligned to the 16 Bytes
543        # in encrypted mode
544        if self.enckey is not None and dont_encrypt is False:
545            pad_len = len(self.payload) % 16
546            if pad_len > 0:
547                pad = bytes(16 - pad_len)
548                if isinstance(self.payload, bytes):
549                    self.payload += pad
550                else:
551                    self.payload.extend(pad)
552
553        compression_flags = 0x0
554        if compression_tlvs is not None:
555            if compression_type in ["lzma2", "lzma2armthumb"]:
556                compression_flags = IMAGE_F['COMPRESSED_LZMA2']
557                if compression_type == "lzma2armthumb":
558                    compression_flags |= IMAGE_F['COMPRESSED_ARM_THUMB']
559        # This adds the header to the payload as well
560        if encrypt_keylen == 256:
561            self.add_header(enckey, protected_tlv_size, compression_flags, 256)
562        else:
563            self.add_header(enckey, protected_tlv_size, compression_flags)
564
565        prot_tlv = TLV(self.endian, TLV_PROT_INFO_MAGIC)
566
567        # Protected TLVs must be added first, because they are also included
568        # in the hash calculation
569        protected_tlv_off = None
570        if protected_tlv_size != 0:
571
572            e = STRUCT_ENDIAN_DICT[self.endian]
573
574            if self.security_counter is not None:
575                payload = struct.pack(e + 'I', self.security_counter)
576                prot_tlv.add('SEC_CNT', payload)
577
578            if sw_type is not None:
579                prot_tlv.add('BOOT_RECORD', boot_record)
580
581            if dependencies is not None:
582                for i in range(dependencies_num):
583                    payload = struct.pack(
584                        e + 'B3x' + 'BBHI',
585                        int(dependencies[DEP_IMAGES_KEY][i]),
586                        dependencies[DEP_VERSIONS_KEY][i].major,
587                        dependencies[DEP_VERSIONS_KEY][i].minor,
588                        dependencies[DEP_VERSIONS_KEY][i].revision,
589                        dependencies[DEP_VERSIONS_KEY][i].build
590                    )
591                    prot_tlv.add('DEPENDENCY', payload)
592
593            if compression_tlvs is not None:
594                for tag, value in compression_tlvs.items():
595                    prot_tlv.add(tag, value)
596            if custom_tlvs is not None:
597                for tag, value in custom_tlvs.items():
598                    prot_tlv.add(tag, value)
599
600            protected_tlv_off = len(self.payload)
601
602            self.payload += prot_tlv.get()
603
604        tlv = TLV(self.endian)
605
606        # These signature is done over sha of image. In case of
607        # EC signatures so called Pure algorithm, designated to be run
608        # over entire message is used with sha of image as message,
609        # so, for example, in case of ED25519 we have here SHAxxx-ED25519-SHA512.
610        sha = hash_algorithm()
611        sha.update(self.payload)
612        digest = sha.digest()
613        tlv.add(hash_tlv, digest)
614        self.image_hash = digest
615        # Unless pure, we are signing digest.
616        message = digest
617
618        if is_pure:
619            # Note that when Pure signature is used, hash TLV is not present.
620            message = bytes(self.payload)
621            e = STRUCT_ENDIAN_DICT[self.endian]
622            sig_pure = struct.pack(e + '?', True)
623            tlv.add('SIG_PURE', sig_pure)
624
625        if vector_to_sign == 'payload':
626            # Stop amending data to the image
627            # Just keep data vector which is expected to be signed
628            print(os.path.basename(__file__) + ': export payload')
629            return
630        elif vector_to_sign == 'digest':
631            self.payload = digest
632            print(os.path.basename(__file__) + ': export digest')
633            return
634
635        if key is not None or fixed_sig is not None:
636            if public_key_format == 'hash':
637                tlv.add('KEYHASH', pubbytes)
638            else:
639                tlv.add('PUBKEY', pub)
640
641            if key is not None and fixed_sig is None:
642                # `sign` expects the full image payload (hashing done
643                # internally), while `sign_digest` expects only the digest
644                # of the payload
645
646                if hasattr(key, 'sign'):
647                    print(os.path.basename(__file__) + ": sign the payload")
648                    sig = key.sign(bytes(self.payload))
649                else:
650                    print(os.path.basename(__file__) + ": sign the digest")
651                    sig = key.sign_digest(message)
652                tlv.add(key.sig_tlv(), sig)
653                self.signature = sig
654            elif fixed_sig is not None and key is None:
655                tlv.add(pub_key.sig_tlv(), fixed_sig['value'])
656                self.signature = fixed_sig['value']
657            else:
658                raise click.UsageError("Can not sign using key and provide fixed-signature at the same time")
659
660        # At this point the image was hashed + signed, we can remove the
661        # protected TLVs from the payload (will be re-added later)
662        if protected_tlv_off is not None:
663            self.payload = self.payload[:protected_tlv_off]
664
665        if enckey is not None and dont_encrypt is False:
666            if encrypt_keylen == 256:
667                plainkey = os.urandom(32)
668            else:
669                plainkey = os.urandom(16)
670
671            if isinstance(enckey, rsa.RSAPublic):
672                cipherkey = enckey._get_public().encrypt(
673                    plainkey, padding.OAEP(
674                        mgf=padding.MGF1(algorithm=hashes.SHA256()),
675                        algorithm=hashes.SHA256(),
676                        label=None))
677                self.enctlv_len = len(cipherkey)
678                tlv.add('ENCRSA2048', cipherkey)
679            elif isinstance(enckey, (ecdsa.ECDSA256P1Public,
680                                     x25519.X25519Public)):
681                cipherkey, mac, pubk = self.ecies_hkdf(enckey, plainkey)
682                enctlv = pubk + mac + cipherkey
683                self.enctlv_len = len(enctlv)
684                if isinstance(enckey, ecdsa.ECDSA256P1Public):
685                    tlv.add('ENCEC256', enctlv)
686                else:
687                    tlv.add('ENCX25519', enctlv)
688
689            if not clear:
690                nonce = bytes([0] * 16)
691                cipher = Cipher(algorithms.AES(plainkey), modes.CTR(nonce),
692                                backend=default_backend())
693                encryptor = cipher.encryptor()
694                img = bytes(self.payload[self.header_size:])
695                self.payload[self.header_size:] = \
696                    encryptor.update(img) + encryptor.finalize()
697
698        self.payload += prot_tlv.get()
699        self.payload += tlv.get()
700
701        self.check_trailer()
702
703    def get_struct_endian(self):
704        return STRUCT_ENDIAN_DICT[self.endian]
705
706    def get_signature(self):
707        return self.signature
708
709    def get_infile_data(self):
710        return self.infile_data
711
712    def add_header(self, enckey, protected_tlv_size, compression_flags, aes_length=128):
713        """Install the image header."""
714
715        flags = 0
716        if enckey is not None:
717            if aes_length == 128:
718                flags |= IMAGE_F['ENCRYPTED_AES128']
719            else:
720                flags |= IMAGE_F['ENCRYPTED_AES256']
721        if self.load_addr != 0:
722            # Indicates that this image should be loaded into RAM
723            # instead of run directly from flash.
724            flags |= IMAGE_F['RAM_LOAD']
725        if self.rom_fixed:
726            flags |= IMAGE_F['ROM_FIXED']
727        if self.non_bootable:
728            flags |= IMAGE_F['NON_BOOTABLE']
729
730        e = STRUCT_ENDIAN_DICT[self.endian]
731        fmt = (e +
732               # type ImageHdr struct {
733               'I' +     # Magic    uint32
734               'I' +     # LoadAddr uint32
735               'H' +     # HdrSz    uint16
736               'H' +     # PTLVSz   uint16
737               'I' +     # ImgSz    uint32
738               'I' +     # Flags    uint32
739               'BBHI' +  # Vers     ImageVersion
740               'I'       # Pad1     uint32
741               )  # }
742        assert struct.calcsize(fmt) == IMAGE_HEADER_SIZE
743        header = struct.pack(fmt,
744                             IMAGE_MAGIC,
745                             self.rom_fixed or self.load_addr,
746                             self.header_size,
747                             protected_tlv_size,  # TLV Info header +
748                                                  # Protected TLVs
749                             len(self.payload) - self.header_size,  # ImageSz
750                             flags | compression_flags,
751                             self.version.major,
752                             self.version.minor or 0,
753                             self.version.revision or 0,
754                             self.version.build or 0,
755                             0)  # Pad1
756        self.payload = bytearray(self.payload)
757        self.payload[:len(header)] = header
758
759    def _trailer_size(self, write_size, max_sectors, overwrite_only, enckey,
760                      save_enctlv, enctlv_len):
761        # NOTE: should already be checked by the argument parser
762        magic_size = 16
763        magic_align_size = align_up(magic_size, self.max_align)
764        if overwrite_only:
765            return self.max_align * 2 + magic_align_size
766        else:
767            if write_size not in set([1, 2, 4, 8, 16, 32]):
768                raise click.BadParameter("Invalid alignment: {}".format(
769                    write_size))
770            m = DEFAULT_MAX_SECTORS if max_sectors is None else max_sectors
771            trailer = m * 3 * write_size  # status area
772            if enckey is not None:
773                if save_enctlv:
774                    # TLV saved by the bootloader is aligned
775                    keylen = align_up(enctlv_len, self.max_align)
776                else:
777                    keylen = align_up(16, self.max_align)
778                trailer += keylen * 2  # encryption keys
779            trailer += self.max_align * 4  # image_ok/copy_done/swap_info/swap_size
780            trailer += magic_align_size
781            return trailer
782
783    def pad_to(self, size):
784        """Pad the image to the given size, with the given flash alignment."""
785        tsize = self._trailer_size(self.align, self.max_sectors,
786                                   self.overwrite_only, self.enckey,
787                                   self.save_enctlv, self.enctlv_len)
788        padding = size - (len(self.payload) + tsize)
789        pbytes = bytearray([self.erased_val] * padding)
790        pbytes += bytearray([self.erased_val] * (tsize - len(self.boot_magic)))
791        pbytes += self.boot_magic
792        if self.confirm and not self.overwrite_only:
793            magic_size = 16
794            magic_align_size = align_up(magic_size, self.max_align)
795            image_ok_idx = -(magic_align_size + self.max_align)
796            pbytes[image_ok_idx] = 0x01  # image_ok = 0x01
797        self.payload += pbytes
798
799    @staticmethod
800    def verify(imgfile, key):
801        ext = os.path.splitext(imgfile)[1][1:].lower()
802        try:
803            if ext == INTEL_HEX_EXT:
804                b = IntelHex(imgfile).tobinstr()
805            else:
806                with open(imgfile, 'rb') as f:
807                    b = f.read()
808        except FileNotFoundError:
809            raise click.UsageError(f"Image file {imgfile} not found")
810
811        magic, _, header_size, _, img_size = struct.unpack('IIHHI', b[:16])
812        version = struct.unpack('BBHI', b[20:28])
813
814        if magic != IMAGE_MAGIC:
815            return VerifyResult.INVALID_MAGIC, None, None, None
816
817        tlv_off = header_size + img_size
818        tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
819        magic, tlv_tot = struct.unpack('HH', tlv_info)
820        if magic == TLV_PROT_INFO_MAGIC:
821            tlv_off += tlv_tot
822            tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
823            magic, tlv_tot = struct.unpack('HH', tlv_info)
824
825        if magic != TLV_INFO_MAGIC:
826            return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None, None
827
828        # This is set by existence of TLV SIG_PURE
829        is_pure = False
830
831        prot_tlv_size = tlv_off
832        hash_region = b[:prot_tlv_size]
833        tlv_end = tlv_off + tlv_tot
834        tlv_off += TLV_INFO_SIZE  # skip tlv info
835
836        # First scan all TLVs in search of SIG_PURE
837        while tlv_off < tlv_end:
838            tlv = b[tlv_off:tlv_off + TLV_SIZE]
839            tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
840            if tlv_type == TLV_VALUES['SIG_PURE']:
841                is_pure = True
842                break
843            tlv_off += TLV_SIZE + tlv_len
844
845        digest = None
846        tlv_off = header_size + img_size
847        tlv_end = tlv_off + tlv_tot
848        tlv_off += TLV_INFO_SIZE  # skip tlv info
849        while tlv_off < tlv_end:
850            tlv = b[tlv_off:tlv_off + TLV_SIZE]
851            tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
852            if is_sha_tlv(tlv_type):
853                if not tlv_matches_key_type(tlv_type, key):
854                    return VerifyResult.KEY_MISMATCH, None, None, None
855                off = tlv_off + TLV_SIZE
856                digest = get_digest(tlv_type, hash_region)
857                if digest == b[off:off + tlv_len]:
858                    if key is None:
859                        return VerifyResult.OK, version, digest, None
860                else:
861                    return VerifyResult.INVALID_HASH, None, None, None
862            elif not is_pure and key is not None and tlv_type == TLV_VALUES[key.sig_tlv()]:
863                off = tlv_off + TLV_SIZE
864                tlv_sig = b[off:off + tlv_len]
865                payload = b[:prot_tlv_size]
866                try:
867                    if hasattr(key, 'verify'):
868                        key.verify(tlv_sig, payload)
869                    else:
870                        key.verify_digest(tlv_sig, digest)
871                    return VerifyResult.OK, version, digest, None
872                except InvalidSignature:
873                    # continue to next TLV
874                    pass
875            elif is_pure and key is not None and tlv_type in ALLOWED_PURE_SIG_TLVS:
876                off = tlv_off + TLV_SIZE
877                tlv_sig = b[off:off + tlv_len]
878                try:
879                    key.verify_digest(tlv_sig, hash_region)
880                    return VerifyResult.OK, version, None, tlv_sig
881                except InvalidSignature:
882                    # continue to next TLV
883                    pass
884            tlv_off += TLV_SIZE + tlv_len
885        return VerifyResult.INVALID_SIGNATURE, None, None, None
886