1#-------------------------------------------------------------------------------
2# Copyright (c) 2024, Arm Limited. All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5#
6#-------------------------------------------------------------------------------
7
8import argparse
9from cryptography.hazmat.primitives import cmac
10
11from provisioning_common_utils import *
12
def derive_encryption_key(input_key, provisioning_lcs, tp_mode, krtl_derivation_label):
    """Derive 32 bytes of key material from ``input_key`` with an AES-CMAC KDF.

    The KDF context is the SHA-256 digest of a boot-state structure whose
    first two words are ``provisioning_lcs`` and ``tp_mode`` (each encoded as
    4 little-endian bytes), followed by 36 zero bytes.

    The fixed KDF input is: the ASCII label plus a terminating null byte, one
    zero separator byte, the context digest and the value 32 as a 4-byte
    little-endian word. First ``k0`` is computed as the AES-CMAC of that
    fixed input. Then two 16-byte AES-CMAC outputs are produced. Each one
    covers a 4-byte little-endian counter (1, then 2), the fixed input and
    ``k0``. The two outputs are concatenated.

    All fields are combined with ``struct_pack`` (from provisioning_common_utils).

    Args:
        input_key: Raw AES key bytes used for every CMAC computation.
        provisioning_lcs: Integer encoded as the first boot-state word.
        tp_mode: Integer encoded as the second boot-state word.
        krtl_derivation_label: ASCII derivation label, e.g. "CM_PROVISIONING".

    Returns:
        bytes: 32 bytes of derived key material (two 16-byte CMAC outputs).
    """
    def aes_cmac(message):
        # One AES-CMAC computation keyed with input_key.
        mac = cmac.CMAC(algorithms.AES(input_key))
        mac.update(message)
        return mac.finalize()

    # Only the two leading boot-state words carry data; the remaining
    # 32 + 4 bytes of the boot state are all zero.
    boot_state = struct_pack([
        provisioning_lcs.to_bytes(4, byteorder='little'),
        tp_mode.to_bytes(4, byteorder='little'),
        bytes(32 + 4),
    ])
    context = hashlib.sha256(boot_state).digest()

    # C keeps the string's terminating null byte; Python's encode() drops it,
    # so it is appended manually to match the C-side input.
    label = krtl_derivation_label.encode('ascii') + bytes(1)
    # Fixed KDF input shared by the k0 step and every counter step.
    fixed_input = [label,
                   bytes(1), context,
                   (32).to_bytes(4, byteorder='little')]

    k0 = aes_cmac(struct_pack(fixed_input))

    # The KDF outputs 16 bytes per iteration, so two iterations (counters 1
    # and 2) are needed for a 32-byte AES-256 key.
    return b''.join(
        aes_cmac(struct_pack([counter.to_bytes(4, byteorder='little'),
                              *fixed_input,
                              k0]))
        for counter in range(1, 3)
    )
45
# Command-line entry: derive the selected provisioning key from the RTL key
# file and write the raw derived bytes to the output file.

# --tp_mode choice -> integer handed to derive_encryption_key as tp_mode.
_TP_MODE_VALUES = {
    "TCI": 0x111155AA,
    "PCI": 0x2222AA55,
}

# --key_select choice -> (provisioning_lcs value, derivation label).
_KEY_SELECTIONS = {
    "cm": (0, "CM_PROVISIONING"),
    "dm": (1, "DM_PROVISIONING"),
}

parser = argparse.ArgumentParser()
parser.add_argument("--tp_mode", help="the test or production mode",
                    choices=list(_TP_MODE_VALUES), required=True)
parser.add_argument("--krtl_file", help="the RTL key file", required=True)
parser.add_argument("--key_select", help="Which key to derive",
                    choices=list(_KEY_SELECTIONS), required=True)
parser.add_argument("--output_key_file", help="key output file", required=True)
args = parser.parse_args()

# argparse's `choices` guarantees both lookups below succeed.
tp_mode = _TP_MODE_VALUES[args.tp_mode]

with open(args.krtl_file, "rb") as key_source:
    input_key = key_source.read()

selected_lcs, selected_label = _KEY_SELECTIONS[args.key_select]
output_key = derive_encryption_key(input_key, selected_lcs, tp_mode, selected_label)

with open(args.output_key_file, "wb") as key_sink:
    key_sink.write(output_key)
68