1#-------------------------------------------------------------------------------
2# Copyright (c) 2021-2024, Arm Limited. All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5#
6#-------------------------------------------------------------------------------
7
8import hashlib
9from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
10from cryptography.hazmat.primitives import cmac
11from cryptography.hazmat.primitives.kdf.hkdf import HKDF
12from cryptography.hazmat.primitives import hashes
13from cryptography.hazmat.backends import default_backend
14import secrets
15import argparse
16import os
17import sys
18sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../bl2/ext/mcuboot/scripts"))
19import macro_parser
20import struct
21import pyhsslms
22
def struct_pack(objects, pad_to=0):
    """Concatenate byte strings into one packed buffer, optionally zero-padded.

    Every element of *objects* becomes a fixed-width ``Ns`` field of exactly
    its own length. If the packed size is smaller than *pad_to*, zero bytes
    are appended until the result is *pad_to* bytes long; a result that is
    already at least *pad_to* bytes is returned without padding.
    """
    fmt = "<" + "".join("{}s".format(len(item)) for item in objects)
    shortfall = pad_to - struct.calcsize(fmt)
    if shortfall > 0:
        fmt = fmt + "{}x".format(shortfall)
    return bytes(struct.pack(fmt, *objects))
33
def parse_version(version_string):
    """Encode a version string as the 8-byte little-endian image version.

    Accepts ``major[.minor[.revision[.build]]][+build]``; missing fields are
    0. The output layout is major (1 byte), minor (1 byte), revision
    (2 bytes), build (4 bytes). When both a fourth dotted field and a
    ``+build`` suffix are given, the dotted field wins.

    Raises:
        ValueError: if a field is not an integer, the string has more than
            four dotted fields or more than one ``+``, or a field does not
            fit in its byte width.
    """
    field_widths = (1, 1, 2, 4)
    version = [0, 0, 0, 0]

    core, *build = version_string.split("+")
    if len(build) > 1:
        raise ValueError("invalid version {!r}: more than one '+'"
                         .format(version_string))
    if build:
        version[3] = int(build[0])

    components = core.split(".")
    if len(components) > len(version):
        raise ValueError("invalid version {!r}: at most {} dotted fields allowed"
                         .format(version_string, len(version)))
    for index, component in enumerate(components):
        version[index] = int(component)

    # Reject negative or oversized fields with a clear message instead of an
    # opaque packing/overflow error.
    for value, width in zip(version, field_widths):
        if not 0 <= value < (1 << (8 * width)):
            raise ValueError("invalid version {!r}: field value {} does not fit in {} byte(s)"
                             .format(version_string, value, width))

    return struct.pack("<BBHI", *version)
47
def derive_encryption_key_cmac(security_counter, encrypt_key_file=None):
    """Derive the 32-byte BL2 encryption key using an AES-CMAC based KDF.

    The derivation must match the C-side implementation byte for byte.
    With ``fixed = "BL2_DECRYPTION_KEY" NUL || 0x00 || security_counter ||
    u32le(32)``:

        k0      = CMAC(key, fixed)
        block_i = CMAC(key, u32le(i) || fixed || k0)    for i = 1, 2
        result  = block_1 || block_2

    Args:
        security_counter: the image security counter as bytes (this script
            passes 4 little-endian bytes).
        encrypt_key_file: path of the raw input key. Defaults to
            ``args.encrypt_key_file`` from the command line.

    Returns:
        32 bytes of derived key material (an AES-256 key).
    """
    if encrypt_key_file is None:
        encrypt_key_file = args.encrypt_key_file
    with open(encrypt_key_file, "rb") as key_file:
        input_key = key_file.read()

    # Fixed input shared by every CMAC invocation. C keeps the label's NUL
    # terminator while a Python literal does not, so it is added explicitly;
    # the following zero byte separates the label from the context.
    fixed_input = b"".join([
        "BL2_DECRYPTION_KEY".encode("ascii") + bytes(1),
        bytes(1),
        security_counter,
        (32).to_bytes(4, byteorder="little"),  # encodes the 32-byte output length
    ])

    def _cmac(message):
        # One AES-CMAC over *message* with the input key (16-byte tag).
        mac = cmac.CMAC(algorithms.AES(input_key))
        mac.update(message)
        return mac.finalize()

    k0 = _cmac(fixed_input)

    # Each CMAC yields 16 bytes, so counters 1 and 2 give the 32-byte key.
    return b"".join(
        _cmac(counter.to_bytes(4, byteorder="little") + fixed_input + k0)
        for counter in (1, 2)
    )
73
def derive_encryption_key_hkdf(security_counter, encrypt_key_file=None):
    """Derive the 32-byte BL2 encryption key using HKDF-SHA256.

    The HKDF ``info`` is the NUL-terminated label ``BL2_DECRYPTION_KEY``
    followed directly by *security_counter*; no salt is used. This must
    match the C-side derivation byte for byte.

    Args:
        security_counter: the image security counter as bytes (this script
            passes 4 little-endian bytes).
        encrypt_key_file: path of the raw input key. Defaults to
            ``args.encrypt_key_file`` from the command line.

    Returns:
        32 bytes of derived key material.
    """
    if encrypt_key_file is None:
        encrypt_key_file = args.encrypt_key_file
    with open(encrypt_key_file, "rb") as key_file:
        input_key = key_file.read()

    # C keeps the label's NUL terminator while a Python literal does not, so
    # it is appended explicitly.
    info = "BL2_DECRYPTION_KEY".encode("ascii") + bytes(1) + security_counter

    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=info,
    )
    return hkdf.derive(input_key)
95
def sign_binary_blob(blob, sign_key_file=None):
    """Sign *blob* with the HSS/LMS private key and return the trimmed signature.

    Args:
        blob: the bytes to sign.
        sign_key_file: path of the pyhsslms private key. Defaults to
            ``args.sign_key_file`` from the command line.

    Returns:
        The signature with its 4-byte HSS prefix removed (1452 bytes).

    Raises:
        ValueError: if no signing key file is available.
        RuntimeError: if the trimmed signature is not 1452 bytes long.
    """
    # Fixed signature length the image layout reserves. NOTE(review): this
    # matches an LMS_SHA256_M32_H10 / LMOTS_SHA256_N32_W8 signature per
    # RFC 8554 -- presumably the parameter set of the key; confirm.
    expected_len = 1452

    if sign_key_file is None:
        sign_key_file = args.sign_key_file
    if not sign_key_file:
        raise ValueError("a signing key file (--sign_key_file) is required to sign the image")

    priv_key = pyhsslms.HssLmsPrivateKey(sign_key_file)
    # Remove the first 4 bytes since they carry HSS info, not the LMS
    # signature itself.
    sig = priv_key.sign(blob)[4:]
    if len(sig) != expected_len:
        raise RuntimeError("unexpected signature length {} (expected {})"
                           .format(len(sig), expected_len))
    return sig
103
def hash_binary_blob(blob):
    """Return the SHA-256 digest of *blob* as 32 raw bytes."""
    return hashlib.sha256(blob).digest()
108
def encrypt_binary_blob(blob, counter_val, encrypt_key):
    """Encrypt *blob* with AES in CTR mode.

    Args:
        blob: plaintext bytes.
        counter_val: the initial 16-byte AES-CTR counter block.
        encrypt_key: the AES key (the KDFs in this script produce 32 bytes).

    Returns:
        Ciphertext bytes, the same length as *blob*.
    """
    encryptor = Cipher(algorithms.AES(encrypt_key), modes.CTR(counter_val)).encryptor()
    # Always finalize the context: it completes the operation correctly even
    # though CTR mode produces no trailing bytes.
    return encryptor.update(blob) + encryptor.finalize()
112
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", help="the image to process", required=True)
parser.add_argument("--img_version", help="version of the image", required=True)
parser.add_argument("--img_security_counter", help="Security counter value for the image", required=True)
parser.add_argument("--encrypt_key_file", help="encryption key file", required=True)
# The image is always signed below, so the signing key is mandatory.
parser.add_argument("--sign_key_file", help="signing key file", required=True)
parser.add_argument("--img_output_file", help="image output file", required=True)
parser.add_argument("--hash_output_file", help="hash output file", required=False)
parser.add_argument("--signing_layout_file", help="signing layout file", required=True)
parser.add_argument("--header_size", help="size of the header", required=True)
# Restrict to known KDFs so a typo does not silently fall back to CMAC.
parser.add_argument("--kdf_alg", help="which KDF will be used", required=False,
                    default="cmac", choices=["cmac", "hkdf"])
args = parser.parse_args()

# Sizes of the fields in front of the encrypted payload. The output image is
#   counter_val | signature | version | security counter | AES-CTR(plaintext)
# and the plaintext itself starts with a 4-byte magic value.
CTR_BLOCK_SIZE = 16          # initial AES-CTR counter block
SIGNATURE_SIZE = 1452        # length enforced by sign_binary_blob()
VERSION_SIZE = 8             # length of parse_version() output
SECURITY_COUNTER_SIZE = 4
MAGIC_SIZE = 4
CLEARTEXT_HEADER_SIZE = CTR_BLOCK_SIZE + SIGNATURE_SIZE + VERSION_SIZE + SECURITY_COUNTER_SIZE

with open(args.input_file, "rb") as in_file:
    bl2_code = in_file.read()

# 12 random bytes followed by 4 zero bytes form the initial counter block.
counter_val = secrets.token_bytes(12) + (0).to_bytes(4, 'little')

version = parse_version(args.img_version)

# Parse the counter once (base auto-detected, e.g. "0x10" or "16") so the
# value fed to the KDF is exactly the value written into the image header.
security_counter = int(args.img_security_counter, 0).to_bytes(SECURITY_COUNTER_SIZE, 'little')

bl2_partition_size = macro_parser.evaluate_macro(args.signing_layout_file,
                                    ".*(RE_BL2_BIN_SIZE) = *(.*)",
                                    1, 2, True)['RE_BL2_BIN_SIZE']

header_size = int(args.header_size, 0)
header_pad_size = header_size - (CLEARTEXT_HEADER_SIZE + MAGIC_SIZE)
if header_pad_size < 0:
    print("Error: header size {} is smaller than the {} bytes of fixed header fields"
          .format(header_size, CLEARTEXT_HEADER_SIZE + MAGIC_SIZE), file=sys.stderr)
    sys.exit(1)

# Pad the plaintext so that cleartext header + ciphertext fills the whole
# BL2 partition.
plaintext = struct_pack([
    (0xDEADBEEF).to_bytes(MAGIC_SIZE, 'little'),
    bytes(header_pad_size),
    bl2_code,
    ],
    pad_to=bl2_partition_size - CLEARTEXT_HEADER_SIZE)

if args.kdf_alg == "hkdf":
    encrypt_key = derive_encryption_key_hkdf(security_counter)
else:
    encrypt_key = derive_encryption_key_cmac(security_counter)
ciphertext = encrypt_binary_blob(plaintext, counter_val, encrypt_key)

# The signature and the hash both cover version, counter and plaintext.
data_to_sign = struct_pack([
    version,
    security_counter,
    plaintext,
    ])

image_hash = hash_binary_blob(data_to_sign)
sig = sign_binary_blob(data_to_sign)

image = struct_pack([
    counter_val,
    sig,
    version,
    security_counter,
    ciphertext,
    ])

if len(image) > bl2_partition_size:
    print("Error: Signed image size {} exceeds BL2 partition size {}"
          .format(len(image), bl2_partition_size), file=sys.stderr)
    sys.exit(1)

with open(args.img_output_file, "wb") as img_out_file:
    img_out_file.write(image)

# --hash_output_file is optional; only write the hash when it was requested.
if args.hash_output_file:
    with open(args.hash_output_file, "wb") as hash_out_file:
        hash_out_file.write(image_hash)
177