1# Copyright 2018 Nordic Semiconductor ASA
2# Copyright 2017-2020 Linaro Limited
3# Copyright 2019-2024 Arm Limited
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License");
8# you may not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11#     http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS,
15# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
19"""
20Image signing and management.
21"""
22
23from . import version as versmod
24from .boot_record import create_sw_component_data
25import click
26import copy
27from enum import Enum
28import array
29from intelhex import IntelHex
30import hashlib
31import array
32import os.path
33import struct
34from enum import Enum
35
36import click
37from cryptography.exceptions import InvalidSignature
38from cryptography.hazmat.backends import default_backend
39from cryptography.hazmat.primitives import hashes, hmac
40from cryptography.hazmat.primitives.asymmetric import ec, padding
41from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
42from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
43from cryptography.hazmat.primitives.kdf.hkdf import HKDF
44from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
45from intelhex import IntelHex
46
47from . import version as versmod, keys
48from .boot_record import create_sw_component_data
49from .keys import rsa, ecdsa, x25519
50
51from collections import namedtuple
52
# Magic number packed as the first field of the image header and checked
# again by Image.verify().
IMAGE_MAGIC = 0x96f3b83d
# Size in bytes of the packed image header (asserted in Image.add_header);
# also the default ``header_size`` of an Image.
IMAGE_HEADER_SIZE = 32
BIN_EXT = "bin"
# File extension (without the dot) that selects Intel HEX handling.
INTEL_HEX_EXT = "hex"
# Default number of sectors used when sizing the trailer status area.
DEFAULT_MAX_SECTORS = 128
# Default (and minimum automatically chosen) value of Image.max_align.
DEFAULT_MAX_ALIGN = 8
# Keys of the ``dependencies`` dictionary consumed by Image.create().
DEP_IMAGES_KEY = "images"
DEP_VERSIONS_KEY = "versions"
# Upper bound enforced on the ``sw_type`` argument of Image.create().
MAX_SW_TYPE_LENGTH = 12  # Bytes
62
# Image header flags.  Image.add_header() ORs the applicable bits together
# into the "Flags" field of the image header.
IMAGE_F = {
        'PIC':                   0x0000001,
        'ENCRYPTED_AES128':      0x0000004,
        'ENCRYPTED_AES256':      0x0000008,
        'NON_BOOTABLE':          0x0000010,
        'RAM_LOAD':              0x0000020,
        'ROM_FIXED':             0x0000100,
        'COMPRESSED_LZMA1':      0x0000200,
        'COMPRESSED_LZMA2':      0x0000400,
        'COMPRESSED_ARM_THUMB':  0x0000800,
}
75
# TLV record names mapped to their numeric type codes.  TLV.add() looks
# names up here when it is given a string kind.
TLV_VALUES = {
        'KEYHASH': 0x01,
        'PUBKEY': 0x02,
        'SHA256': 0x10,
        'SHA384': 0x11,
        'SHA512': 0x12,
        'RSA2048': 0x20,
        'ECDSASIG': 0x22,
        'RSA3072': 0x23,
        'ED25519': 0x24,
        'SIG_PURE': 0x25,
        'ENCRSA2048': 0x30,
        'ENCKW': 0x31,
        'ENCEC256': 0x32,
        'ENCX25519': 0x33,
        'DEPENDENCY': 0x40,
        'SEC_CNT': 0x50,
        'BOOT_RECORD': 0x60,
        'DECOMP_SIZE': 0x70,
        'DECOMP_SHA': 0x71,
        'DECOMP_SIGNATURE': 0x72,
}

# Size in bytes of the header preceding each TLV record (type + length).
TLV_SIZE = 4
# Size in bytes of the info header preceding a TLV area (magic + total size).
TLV_INFO_SIZE = 4
# Magic of the unprotected TLV area (default for TLV objects).
TLV_INFO_MAGIC = 0x6907
# Magic of the protected TLV area, which is included in the image hash.
TLV_PROT_INFO_MAGIC = 0x6908

# Inclusive range of integer type codes TLV.add() accepts for custom TLVs.
TLV_VENDOR_RES_MIN = 0x00a0
TLV_VENDOR_RES_MAX = 0xfffe
106
# Maps the user-facing endianness name to the ``struct`` byte-order prefix.
STRUCT_ENDIAN_DICT = {
        'little': '<',
        'big':    '>'
}

# Outcome codes returned as the first element of Image.verify()'s result.
VerifyResult = Enum('VerifyResult',
                    ['OK', 'INVALID_MAGIC', 'INVALID_TLV_INFO_MAGIC', 'INVALID_HASH', 'INVALID_SIGNATURE',
                     'KEY_MISMATCH'])
115
116
def align_up(num, align):
    """Round ``num`` up to the nearest multiple of ``align``.

    ``align`` must be a non-zero power of two; this is checked with an
    assertion.
    """
    mask = align - 1
    assert align != 0 and (align & mask) == 0
    return (num + mask) & ~mask
120
121
class TLV:
    """Accumulates TLV (type-length-value) records for one TLV area.

    Records are appended to ``buf``; ``get()`` returns them prefixed with
    the area info header (magic and total size).
    """

    def __init__(self, endian, magic=TLV_INFO_MAGIC):
        self.magic = magic
        self.buf = bytearray()
        self.endian = endian

    def __len__(self):
        """Return the area size: info header plus all records added so far."""
        return len(self.buf) + TLV_INFO_SIZE

    def add(self, kind, payload):
        """
        Append one TLV record.

        Kind is either a string found in TLV_VALUES above, or an integer
        custom type between TLV_VENDOR_RES_MIN and TLV_VENDOR_RES_MAX
        (inclusive).  An integer outside that range raises
        click.UsageError.
        """
        prefix = STRUCT_ENDIAN_DICT[self.endian]
        if not isinstance(kind, int):
            # Named type: one-byte type code, one zero byte, 16-bit length.
            record_header = struct.pack(prefix + 'BBH',
                                        TLV_VALUES[kind], 0, len(payload))
        elif TLV_VENDOR_RES_MIN <= kind <= TLV_VENDOR_RES_MAX:
            # Custom type: 16-bit type code, 16-bit length.
            record_header = struct.pack(prefix + 'HH', kind, len(payload))
        else:
            msg = "Invalid custom TLV type value '0x{:04x}', allowed " \
                  "value should be between 0x{:04x} and 0x{:04x}".format(
                    kind, TLV_VENDOR_RES_MIN, TLV_VENDOR_RES_MAX)
            raise click.UsageError(msg)
        self.buf += record_header
        self.buf += payload

    def get(self):
        """Return the serialized area, or empty bytes if no record was added."""
        if not self.buf:
            return bytes()
        info_header = struct.pack(STRUCT_ENDIAN_DICT[self.endian] + 'HH',
                                  self.magic, len(self))
        return info_header + bytes(self.buf)
154
155
# Pair of (user-facing SHA name, hashlib constructor).
SHAAndAlgT = namedtuple('SHAAndAlgT', ['sha', 'alg'])

# Lookup from a hash TLV type code to its SHA name and hashlib constructor.
TLV_SHA_TO_SHA_AND_ALG = {
    TLV_VALUES['SHA256'] : SHAAndAlgT('256', hashlib.sha256),
    TLV_VALUES['SHA384'] : SHAAndAlgT('384', hashlib.sha384),
    TLV_VALUES['SHA512'] : SHAAndAlgT('512', hashlib.sha512),
}


# Lookup from the user-requested SHA name to (hashlib constructor, TLV name).
# 'auto' maps to SHA256 here; it is used as-is only when no key is given.
USER_SHA_TO_ALG_AND_TLV = {
    'auto'   : (hashlib.sha256, 'SHA256'),
    '256'    : (hashlib.sha256, 'SHA256'),
    '384'    : (hashlib.sha384, 'SHA384'),
    '512'    : (hashlib.sha512, 'SHA512')
}
171
172
def is_sha_tlv(tlv):
    """Return True if ``tlv`` is a type code listed in TLV_SHA_TO_SHA_AND_ALG."""
    return tlv in TLV_SHA_TO_SHA_AND_ALG
175
176
def tlv_sha_to_sha(tlv):
    """Return the SHA name ('256', '384' or '512') for a hash TLV type code.

    Raises KeyError if ``tlv`` is not a known hash TLV type.
    """
    entry = TLV_SHA_TO_SHA_AND_ALG[tlv]
    return entry.sha
179
180
# SHA names allowed for each key type.  The first entry of each list is the
# one selected when the user requests 'auto'.
ALLOWED_KEY_SHA = {
    keys.ECDSA384P1         : ['384'],
    keys.ECDSA384P1Public   : ['384'],
    keys.ECDSA256P1         : ['256'],
    keys.RSA                : ['256'],
    keys.RSAPublic          : ['256'],
    # These two list '256' first for compatibility; '512' would be the
    # correct default.
    keys.Ed25519            : ['256', '512'],
    keys.X25519             : ['256', '512']
}
192
def key_and_user_sha_to_alg_and_tlv(key, user_sha):
    """Match a key and the user-requested SHA to a hash algorithm and TLV name.

    The returned tuple contains the hashlib constructor and the TLV name.
    The function is designed to either succeed or fail completely, because
    an incompatible key/SHA pair prevents any further work.

    With ``key`` set to None, any ``user_sha`` known to
    USER_SHA_TO_ALG_AND_TLV is accepted (an unknown one raises KeyError).
    Otherwise ``user_sha`` must be 'auto' or one of the names listed for the
    key's type in ALLOWED_KEY_SHA; 'auto' selects the first listed name.

    Raises click.UsageError if the key type has no entry in ALLOWED_KEY_SHA
    or if ``user_sha`` is not allowed for the key.
    """
    if key is None:
        # If key is none, we allow whatever user has selected for sha
        return USER_SHA_TO_ALG_AND_TLV[user_sha]

    # If key is not None, then we have to filter hash to only allowed
    try:
        allowed = ALLOWED_KEY_SHA[type(key)]
    except KeyError:
        raise click.UsageError("Could not find allowed hash algorithms for {}"
                               .format(type(key)))
    if user_sha == 'auto':
        return USER_SHA_TO_ALG_AND_TLV[allowed[0]]

    if user_sha in allowed:
        return USER_SHA_TO_ALG_AND_TLV[user_sha]

    raise click.UsageError("Key {} can not be used with --sha {}; allowed sha are one of {}"
                           .format(key.sig_type(), user_sha, allowed))
220
221
def get_digest(tlv_type, hash_region):
    """Hash ``hash_region`` with the algorithm selected by a hash TLV type.

    Returns the raw digest bytes.  Raises KeyError if ``tlv_type`` is not a
    key of TLV_SHA_TO_SHA_AND_ALG.
    """
    hasher = TLV_SHA_TO_SHA_AND_ALG[tlv_type].alg(hash_region)
    return hasher.digest()
227
228
def tlv_matches_key_type(tlv_type, key):
    """Check if the provided key matches the hash TLV record in the image.

    Returns True when key_and_user_sha_to_alg_and_tlv() accepts the key
    together with the SHA named by ``tlv_type``, and False when that lookup
    fails for any reason.
    """
    try:
        # The result is not needed: the lookup either finds a match or
        # raises, so success means True and an exception means False.
        key_and_user_sha_to_alg_and_tlv(key, tlv_sha_to_sha(tlv_type))
    except Exception:
        # Deliberately best-effort, but do not swallow KeyboardInterrupt or
        # SystemExit as a bare ``except:`` would.
        return False
    return True
241
242
243class Image:
244
245    def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
246                 pad_header=False, pad=False, confirm=False, align=1,
247                 slot_size=0, max_sectors=DEFAULT_MAX_SECTORS,
248                 overwrite_only=False, endian="little", load_addr=0,
249                 rom_fixed=None, erased_val=None, save_enctlv=False,
250                 security_counter=None, max_align=None,
251                 non_bootable=False):
252
253        if load_addr and rom_fixed:
254            raise click.UsageError("Can not set rom_fixed and load_addr at the same time")
255
256        self.image_hash = None
257        self.image_size = None
258        self.signature = None
259        self.version = version or versmod.decode_version("0")
260        self.header_size = header_size
261        self.pad_header = pad_header
262        self.pad = pad
263        self.confirm = confirm
264        self.align = align
265        self.slot_size = slot_size
266        self.max_sectors = max_sectors
267        self.overwrite_only = overwrite_only
268        self.endian = endian
269        self.base_addr = None
270        self.load_addr = 0 if load_addr is None else load_addr
271        self.rom_fixed = rom_fixed
272        self.erased_val = 0xff if erased_val is None else int(erased_val, 0)
273        self.payload = []
274        self.infile_data = []
275        self.enckey = None
276        self.save_enctlv = save_enctlv
277        self.enctlv_len = 0
278        self.max_align = max(DEFAULT_MAX_ALIGN, align) if max_align is None else int(max_align)
279        self.non_bootable = non_bootable
280
281        if self.max_align == DEFAULT_MAX_ALIGN:
282            self.boot_magic = bytes([
283                0x77, 0xc2, 0x95, 0xf3,
284                0x60, 0xd2, 0xef, 0x7f,
285                0x35, 0x52, 0x50, 0x0f,
286                0x2c, 0xb6, 0x79, 0x80, ])
287        else:
288            lsb = self.max_align & 0x00ff
289            msb = (self.max_align & 0xff00) >> 8
290            align = bytes([msb, lsb]) if self.endian == "big" else bytes([lsb, msb])
291            self.boot_magic = align + bytes([0x2d, 0xe1,
292                                             0x5d, 0x29, 0x41, 0x0b,
293                                             0x8d, 0x77, 0x67, 0x9c,
294                                             0x11, 0x0f, 0x1f, 0x8a, ])
295
296        if security_counter == 'auto':
297            # Security counter has not been explicitly provided,
298            # generate it from the version number
299            self.security_counter = ((self.version.major << 24)
300                                     + (self.version.minor << 16)
301                                     + self.version.revision)
302        else:
303            self.security_counter = security_counter
304
305    def __repr__(self):
306        return "<Image version={}, header_size={}, security_counter={}, \
307                base_addr={}, load_addr={}, align={}, slot_size={}, \
308                max_sectors={}, overwrite_only={}, endian={} format={}, \
309                payloadlen=0x{:x}>".format(
310                    self.version,
311                    self.header_size,
312                    self.security_counter,
313                    self.base_addr if self.base_addr is not None else "N/A",
314                    self.load_addr,
315                    self.align,
316                    self.slot_size,
317                    self.max_sectors,
318                    self.overwrite_only,
319                    self.endian,
320                    self.__class__.__name__,
321                    len(self.payload))
322
323    def load(self, path):
324        """Load an image from a given file"""
325        ext = os.path.splitext(path)[1][1:].lower()
326        try:
327            if ext == INTEL_HEX_EXT:
328                ih = IntelHex(path)
329                self.infile_data = ih.tobinarray()
330                self.payload = copy.copy(self.infile_data)
331                self.base_addr = ih.minaddr()
332            else:
333                with open(path, 'rb') as f:
334                    self.infile_data = f.read()
335                    self.payload = copy.copy(self.infile_data)
336        except FileNotFoundError:
337            raise click.UsageError("Input file not found")
338        self.image_size = len(self.payload)
339
340        # Add the image header if needed.
341        if self.pad_header and self.header_size > 0:
342            if self.base_addr:
343                # Adjust base_addr for new header
344                self.base_addr -= self.header_size
345            self.payload = bytes([self.erased_val] * self.header_size) + \
346                self.payload
347
348        self.check_header()
349
350    def load_compressed(self, data, compression_header):
351        """Load an image from buffer"""
352        self.payload = compression_header + data
353        self.image_size = len(self.payload)
354
355        # Add the image header if needed.
356        if self.pad_header and self.header_size > 0:
357            if self.base_addr:
358                # Adjust base_addr for new header
359                self.base_addr -= self.header_size
360            self.payload = bytes([self.erased_val] * self.header_size) + \
361                self.payload
362
363        self.check_header()
364
365    def save(self, path, hex_addr=None):
366        """Save an image from a given file"""
367        ext = os.path.splitext(path)[1][1:].lower()
368        if ext == INTEL_HEX_EXT:
369            # input was in binary format, but HEX needs to know the base addr
370            if self.base_addr is None and hex_addr is None:
371                raise click.UsageError("No address exists in input file "
372                                       "neither was it provided by user")
373            h = IntelHex()
374            if hex_addr is not None:
375                self.base_addr = hex_addr
376            h.frombytes(bytes=self.payload, offset=self.base_addr)
377            if self.pad:
378                trailer_size = self._trailer_size(self.align, self.max_sectors,
379                                                  self.overwrite_only,
380                                                  self.enckey,
381                                                  self.save_enctlv,
382                                                  self.enctlv_len)
383                trailer_addr = (self.base_addr + self.slot_size) - trailer_size
384                if self.confirm and not self.overwrite_only:
385                    magic_align_size = align_up(len(self.boot_magic),
386                                                self.max_align)
387                    image_ok_idx = -(magic_align_size + self.max_align)
388                    flag = bytearray([self.erased_val] * self.max_align)
389                    flag[0] = 0x01  # image_ok = 0x01
390                    h.puts(trailer_addr + trailer_size + image_ok_idx,
391                           bytes(flag))
392                h.puts(trailer_addr + (trailer_size - len(self.boot_magic)),
393                       bytes(self.boot_magic))
394            h.tofile(path, 'hex')
395        else:
396            if self.pad:
397                self.pad_to(self.slot_size)
398            with open(path, 'wb') as f:
399                f.write(self.payload)
400
401    def check_header(self):
402        if self.header_size > 0 and not self.pad_header:
403            if any(v != 0 for v in self.payload[0:self.header_size]):
404                raise click.UsageError("Header padding was not requested and "
405                                       "image does not start with zeros")
406
407    def check_trailer(self):
408        if self.slot_size > 0:
409            tsize = self._trailer_size(self.align, self.max_sectors,
410                                       self.overwrite_only, self.enckey,
411                                       self.save_enctlv, self.enctlv_len)
412            padding = self.slot_size - (len(self.payload) + tsize)
413            if padding < 0:
414                msg = "Image size (0x{:x}) + trailer (0x{:x}) exceeds " \
415                      "requested size 0x{:x}".format(
416                          len(self.payload), tsize, self.slot_size)
417                raise click.UsageError(msg)
418
419    def ecies_hkdf(self, enckey, plainkey):
420        if isinstance(enckey, ecdsa.ECDSA256P1Public):
421            newpk = ec.generate_private_key(ec.SECP256R1(), default_backend())
422            shared = newpk.exchange(ec.ECDH(), enckey._get_public())
423        else:
424            newpk = X25519PrivateKey.generate()
425            shared = newpk.exchange(enckey._get_public())
426        derived_key = HKDF(
427            algorithm=hashes.SHA256(), length=48, salt=None,
428            info=b'MCUBoot_ECIES_v1', backend=default_backend()).derive(shared)
429        encryptor = Cipher(algorithms.AES(derived_key[:16]),
430                           modes.CTR(bytes([0] * 16)),
431                           backend=default_backend()).encryptor()
432        cipherkey = encryptor.update(plainkey) + encryptor.finalize()
433        mac = hmac.HMAC(derived_key[16:], hashes.SHA256(),
434                        backend=default_backend())
435        mac.update(cipherkey)
436        ciphermac = mac.finalize()
437        if isinstance(enckey, ecdsa.ECDSA256P1Public):
438            pubk = newpk.public_key().public_bytes(
439                encoding=Encoding.X962,
440                format=PublicFormat.UncompressedPoint)
441        else:
442            pubk = newpk.public_key().public_bytes(
443                encoding=Encoding.Raw,
444                format=PublicFormat.Raw)
445        return cipherkey, ciphermac, pubk
446
    def create(self, key, public_key_format, enckey, dependencies=None,
               sw_type=None, custom_tlvs=None, compression_tlvs=None,
               compression_type=None, encrypt_keylen=128, clear=False,
               fixed_sig=None, pub_key=None, vector_to_sign=None, user_sha='auto'):
        """Build the final image: header, protected TLVs, hash, signature
        and, optionally, encryption of the payload.

        The image must already be loaded into ``self.payload``.

        key: signing key, or None to leave the image unsigned.
        public_key_format: 'hash' adds a KEYHASH TLV (hash of the public
            key); any other value adds a PUBKEY TLV with the public key.
        enckey: public key used to wrap the random AES key, or None for no
            encryption.
        dependencies: dict with DEP_IMAGES_KEY and DEP_VERSIONS_KEY lists;
            one DEPENDENCY TLV is added per entry.
        sw_type: if given, a BOOT_RECORD TLV is added; limited to
            MAX_SW_TYPE_LENGTH characters.
        custom_tlvs, compression_tlvs: dicts of tag -> value added to the
            protected TLV area.
        compression_type: "lzma2" or "lzma2armthumb" sets the matching
            header flags (only when compression_tlvs is given).
        encrypt_keylen: 256 selects a 32-byte AES key, anything else a
            16-byte key.
        clear: when true, the key TLV is still added but the payload is not
            encrypted.
        fixed_sig: dict whose 'value' entry is a precomputed signature; used
            together with ``pub_key`` instead of ``key``.
        pub_key: public key used for hash selection and key TLVs when
            ``key`` is None.
        vector_to_sign: 'payload' or 'digest' stops after hashing so the
            data can be exported for external signing.
        user_sha: requested hash ('auto', '256', '384' or '512').

        Raises click.UsageError for an incompatible key/SHA pair, a too
        long ``sw_type``, ``key`` and ``fixed_sig`` both given, or an image
        that does not fit the slot.
        """
        self.enckey = enckey

        # key decides on sha, then pub_key; if both are None the default is
        # used
        check_key = key if key is not None else pub_key
        hash_algorithm, hash_tlv = key_and_user_sha_to_alg_and_tlv(check_key, user_sha)

        # Calculate the hash of the public key
        if key is not None:
            pub = key.get_public_bytes()
            sha = hash_algorithm()
            sha.update(pub)
            pubbytes = sha.digest()
        elif pub_key is not None:
            if hasattr(pub_key, 'sign'):
                print(os.path.basename(__file__) + ": sign the payload")
            pub = pub_key.get_public_bytes()
            sha = hash_algorithm()
            sha.update(pub)
            pubbytes = sha.digest()
        else:
            # No key at all: use an all-zero placeholder of SHA256 size.
            pubbytes = bytes(hashlib.sha256().digest_size)

        # The protected TLV size must be known before the header is written,
        # so it is computed up front from the records added further below.
        protected_tlv_size = 0

        if self.security_counter is not None:
            # Size of the security counter TLV: header ('HH') + payload ('I')
            #                                   = 4 + 4 = 8 Bytes
            protected_tlv_size += TLV_SIZE + 4

        if sw_type is not None:
            if len(sw_type) > MAX_SW_TYPE_LENGTH:
                msg = "'{}' is too long ({} characters) for sw_type. Its " \
                      "maximum allowed length is 12 characters.".format(
                       sw_type, len(sw_type))
                raise click.UsageError(msg)

            image_version = (str(self.version.major) + '.'
                             + str(self.version.minor) + '.'
                             + str(self.version.revision))

            # The image hash is computed over the image header, the image
            # itself and the protected TLV area. However, the boot record TLV
            # (which is part of the protected area) should contain this hash
            # before it is even calculated. For this reason the script fills
            # this field with zeros and the bootloader will insert the right
            # value later.
            digest = bytes(hash_algorithm().digest_size)

            # Create CBOR encoded boot record
            boot_record = create_sw_component_data(sw_type, image_version,
                                                   hash_tlv, digest,
                                                   pubbytes)

            protected_tlv_size += TLV_SIZE + len(boot_record)

        if dependencies is not None:
            # Size of a Dependency TLV = Header ('HH') + Payload('IBBHI')
            # = 4 + 12 = 16 Bytes
            dependencies_num = len(dependencies[DEP_IMAGES_KEY])
            protected_tlv_size += (dependencies_num * 16)

        if compression_tlvs is not None:
            for value in compression_tlvs.values():
                protected_tlv_size += TLV_SIZE + len(value)
        if custom_tlvs is not None:
            for value in custom_tlvs.values():
                protected_tlv_size += TLV_SIZE + len(value)

        if protected_tlv_size != 0:
            # Add the size of the TLV info header
            protected_tlv_size += TLV_INFO_SIZE

        # At this point the image is already on the payload
        #
        # This adds padding if the image is not aligned to 16 bytes
        # in encrypted mode
        if self.enckey is not None:
            pad_len = len(self.payload) % 16
            if pad_len > 0:
                pad = bytes(16 - pad_len)
                if isinstance(self.payload, bytes):
                    self.payload += pad
                else:
                    self.payload.extend(pad)

        compression_flags = 0x0
        if compression_tlvs is not None:
            if compression_type in ["lzma2", "lzma2armthumb"]:
                compression_flags = IMAGE_F['COMPRESSED_LZMA2']
                if compression_type == "lzma2armthumb":
                    compression_flags |= IMAGE_F['COMPRESSED_ARM_THUMB']
        # This adds the header to the payload as well
        if encrypt_keylen == 256:
            self.add_header(enckey, protected_tlv_size, compression_flags, 256)
        else:
            self.add_header(enckey, protected_tlv_size, compression_flags)

        prot_tlv = TLV(self.endian, TLV_PROT_INFO_MAGIC)

        # Protected TLVs must be added first, because they are also included
        # in the hash calculation
        protected_tlv_off = None
        if protected_tlv_size != 0:

            e = STRUCT_ENDIAN_DICT[self.endian]

            if self.security_counter is not None:
                payload = struct.pack(e + 'I', self.security_counter)
                prot_tlv.add('SEC_CNT', payload)

            if sw_type is not None:
                prot_tlv.add('BOOT_RECORD', boot_record)

            if dependencies is not None:
                for i in range(dependencies_num):
                    # Image id (one byte + 3 pad bytes) followed by the
                    # required version (major, minor, revision, build).
                    payload = struct.pack(
                        e + 'B3x' + 'BBHI',
                        int(dependencies[DEP_IMAGES_KEY][i]),
                        dependencies[DEP_VERSIONS_KEY][i].major,
                        dependencies[DEP_VERSIONS_KEY][i].minor,
                        dependencies[DEP_VERSIONS_KEY][i].revision,
                        dependencies[DEP_VERSIONS_KEY][i].build
                    )
                    prot_tlv.add('DEPENDENCY', payload)

            if compression_tlvs is not None:
                for tag, value in compression_tlvs.items():
                    prot_tlv.add(tag, value)
            if custom_tlvs is not None:
                for tag, value in custom_tlvs.items():
                    prot_tlv.add(tag, value)

            # Remember where the protected TLVs start so they can be
            # stripped again after hashing/signing.
            protected_tlv_off = len(self.payload)
            self.payload += prot_tlv.get()

        tlv = TLV(self.endian)

        # The signature is done over the sha of the image. In the case of
        # EC signatures the so-called Pure algorithm, designed to be run
        # over the entire message, is used with the sha of the image as the
        # message, so, for example, for ED25519 we have here
        # SHAxxx-ED25519-SHA512.
        sha = hash_algorithm()
        sha.update(self.payload)
        digest = sha.digest()
        message = digest;
        tlv.add(hash_tlv, digest)
        self.image_hash = digest

        if vector_to_sign == 'payload':
            # Stop amending data to the image
            # Just keep data vector which is expected to be signed
            print(os.path.basename(__file__) + ': export payload')
            return
        elif vector_to_sign == 'digest':
            self.payload = digest
            print(os.path.basename(__file__) + ': export digest')
            return

        if key is not None or fixed_sig is not None:
            if public_key_format == 'hash':
                tlv.add('KEYHASH', pubbytes)
            else:
                tlv.add('PUBKEY', pub)

            if key is not None and fixed_sig is None:
                # `sign` expects the full image payload (hashing done
                # internally), while `sign_digest` expects only the digest
                # of the payload

                if hasattr(key, 'sign'):
                    print(os.path.basename(__file__) + ": sign the payload")
                    sig = key.sign(bytes(self.payload))
                else:
                    print(os.path.basename(__file__) + ": sign the digest")
                    sig = key.sign_digest(message)
                tlv.add(key.sig_tlv(), sig)
                self.signature = sig
            elif fixed_sig is not None and key is None:
                # Externally produced signature; the TLV type comes from
                # pub_key.
                tlv.add(pub_key.sig_tlv(), fixed_sig['value'])
                self.signature = fixed_sig['value']
            else:
                raise click.UsageError("Can not sign using key and provide fixed-signature at the same time")

        # At this point the image was hashed + signed, we can remove the
        # protected TLVs from the payload (will be re-added later)
        if protected_tlv_off is not None:
            self.payload = self.payload[:protected_tlv_off]

        if enckey is not None:
            # Random AES key for this image; it is wrapped with enckey and
            # stored in a TLV.
            if encrypt_keylen == 256:
                plainkey = os.urandom(32)
            else:
                plainkey = os.urandom(16)

            if isinstance(enckey, rsa.RSAPublic):
                cipherkey = enckey._get_public().encrypt(
                    plainkey, padding.OAEP(
                        mgf=padding.MGF1(algorithm=hashes.SHA256()),
                        algorithm=hashes.SHA256(),
                        label=None))
                self.enctlv_len = len(cipherkey)
                tlv.add('ENCRSA2048', cipherkey)
            elif isinstance(enckey, (ecdsa.ECDSA256P1Public,
                                     x25519.X25519Public)):
                cipherkey, mac, pubk = self.ecies_hkdf(enckey, plainkey)
                enctlv = pubk + mac + cipherkey
                self.enctlv_len = len(enctlv)
                if isinstance(enckey, ecdsa.ECDSA256P1Public):
                    tlv.add('ENCEC256', enctlv)
                else:
                    tlv.add('ENCX25519', enctlv)

            if not clear:
                # Encrypt everything after the header with AES-CTR and an
                # all-zero nonce; the header itself stays in clear.
                nonce = bytes([0] * 16)
                cipher = Cipher(algorithms.AES(plainkey), modes.CTR(nonce),
                                backend=default_backend())
                encryptor = cipher.encryptor()
                img = bytes(self.payload[self.header_size:])
                self.payload[self.header_size:] = \
                    encryptor.update(img) + encryptor.finalize()

        # Re-append the (unencrypted) protected TLVs, then the regular TLVs.
        self.payload += prot_tlv.get()
        self.payload += tlv.get()

        self.check_trailer()
676
677    def get_struct_endian(self):
678        return STRUCT_ENDIAN_DICT[self.endian]
679
680    def get_signature(self):
681        return self.signature
682
683    def get_infile_data(self):
684        return self.infile_data
685
686    def add_header(self, enckey, protected_tlv_size, compression_flags, aes_length=128):
687        """Install the image header."""
688
689        flags = 0
690        if enckey is not None:
691            if aes_length == 128:
692                flags |= IMAGE_F['ENCRYPTED_AES128']
693            else:
694                flags |= IMAGE_F['ENCRYPTED_AES256']
695        if self.load_addr != 0:
696            # Indicates that this image should be loaded into RAM
697            # instead of run directly from flash.
698            flags |= IMAGE_F['RAM_LOAD']
699        if self.rom_fixed:
700            flags |= IMAGE_F['ROM_FIXED']
701        if self.non_bootable:
702            flags |= IMAGE_F['NON_BOOTABLE']
703
704        e = STRUCT_ENDIAN_DICT[self.endian]
705        fmt = (e +
706               # type ImageHdr struct {
707               'I' +     # Magic    uint32
708               'I' +     # LoadAddr uint32
709               'H' +     # HdrSz    uint16
710               'H' +     # PTLVSz   uint16
711               'I' +     # ImgSz    uint32
712               'I' +     # Flags    uint32
713               'BBHI' +  # Vers     ImageVersion
714               'I'       # Pad1     uint32
715               )  # }
716        assert struct.calcsize(fmt) == IMAGE_HEADER_SIZE
717        header = struct.pack(fmt,
718                             IMAGE_MAGIC,
719                             self.rom_fixed or self.load_addr,
720                             self.header_size,
721                             protected_tlv_size,  # TLV Info header +
722                                                  # Protected TLVs
723                             len(self.payload) - self.header_size,  # ImageSz
724                             flags | compression_flags,
725                             self.version.major,
726                             self.version.minor or 0,
727                             self.version.revision or 0,
728                             self.version.build or 0,
729                             0)  # Pad1
730        self.payload = bytearray(self.payload)
731        self.payload[:len(header)] = header
732
733    def _trailer_size(self, write_size, max_sectors, overwrite_only, enckey,
734                      save_enctlv, enctlv_len):
735        # NOTE: should already be checked by the argument parser
736        magic_size = 16
737        magic_align_size = align_up(magic_size, self.max_align)
738        if overwrite_only:
739            return self.max_align * 2 + magic_align_size
740        else:
741            if write_size not in set([1, 2, 4, 8, 16, 32]):
742                raise click.BadParameter("Invalid alignment: {}".format(
743                    write_size))
744            m = DEFAULT_MAX_SECTORS if max_sectors is None else max_sectors
745            trailer = m * 3 * write_size  # status area
746            if enckey is not None:
747                if save_enctlv:
748                    # TLV saved by the bootloader is aligned
749                    keylen = align_up(enctlv_len, self.max_align)
750                else:
751                    keylen = align_up(16, self.max_align)
752                trailer += keylen * 2  # encryption keys
753            trailer += self.max_align * 4  # image_ok/copy_done/swap_info/swap_size
754            trailer += magic_align_size
755            return trailer
756
757    def pad_to(self, size):
758        """Pad the image to the given size, with the given flash alignment."""
759        tsize = self._trailer_size(self.align, self.max_sectors,
760                                   self.overwrite_only, self.enckey,
761                                   self.save_enctlv, self.enctlv_len)
762        padding = size - (len(self.payload) + tsize)
763        pbytes = bytearray([self.erased_val] * padding)
764        pbytes += bytearray([self.erased_val] * (tsize - len(self.boot_magic)))
765        pbytes += self.boot_magic
766        if self.confirm and not self.overwrite_only:
767            magic_size = 16
768            magic_align_size = align_up(magic_size, self.max_align)
769            image_ok_idx = -(magic_align_size + self.max_align)
770            pbytes[image_ok_idx] = 0x01  # image_ok = 0x01
771        self.payload += pbytes
772
773    @staticmethod
774    def verify(imgfile, key):
775        ext = os.path.splitext(imgfile)[1][1:].lower()
776        try:
777            if ext == INTEL_HEX_EXT:
778                b = IntelHex(imgfile).tobinstr()
779            else:
780                with open(imgfile, 'rb') as f:
781                    b = f.read()
782        except FileNotFoundError:
783            raise click.UsageError(f"Image file {imgfile} not found")
784
785        magic, _, header_size, _, img_size = struct.unpack('IIHHI', b[:16])
786        version = struct.unpack('BBHI', b[20:28])
787
788        if magic != IMAGE_MAGIC:
789            return VerifyResult.INVALID_MAGIC, None, None
790
791        tlv_off = header_size + img_size
792        tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
793        magic, tlv_tot = struct.unpack('HH', tlv_info)
794        if magic == TLV_PROT_INFO_MAGIC:
795            tlv_off += tlv_tot
796            tlv_info = b[tlv_off:tlv_off + TLV_INFO_SIZE]
797            magic, tlv_tot = struct.unpack('HH', tlv_info)
798
799        if magic != TLV_INFO_MAGIC:
800            return VerifyResult.INVALID_TLV_INFO_MAGIC, None, None
801
802        prot_tlv_size = tlv_off
803        hash_region = b[:prot_tlv_size]
804        digest = None
805        tlv_end = tlv_off + tlv_tot
806        tlv_off += TLV_INFO_SIZE  # skip tlv info
807        while tlv_off < tlv_end:
808            tlv = b[tlv_off:tlv_off + TLV_SIZE]
809            tlv_type, _, tlv_len = struct.unpack('BBH', tlv)
810            if is_sha_tlv(tlv_type):
811                if not tlv_matches_key_type(tlv_type, key):
812                    return VerifyResult.KEY_MISMATCH, None, None
813                off = tlv_off + TLV_SIZE
814                digest = get_digest(tlv_type, hash_region)
815                if digest == b[off:off + tlv_len]:
816                    if key is None:
817                        return VerifyResult.OK, version, digest
818                else:
819                    return VerifyResult.INVALID_HASH, None, None
820            elif key is not None and tlv_type == TLV_VALUES[key.sig_tlv()]:
821                off = tlv_off + TLV_SIZE
822                tlv_sig = b[off:off + tlv_len]
823                payload = b[:prot_tlv_size]
824                try:
825                    if hasattr(key, 'verify'):
826                        key.verify(tlv_sig, payload)
827                    else:
828                        key.verify_digest(tlv_sig, digest)
829                    return VerifyResult.OK, version, digest
830                except InvalidSignature:
831                    # continue to next TLV
832                    pass
833            tlv_off += TLV_SIZE + tlv_len
834        return VerifyResult.INVALID_SIGNATURE, None, None
835