1# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD
2#
3# SPDX-License-Identifier: GPL-2.0-or-later
4
5import binascii
6import configparser
7import os
8import sys
9from getpass import getpass
10
11try:
12    import pkcs11
13    from .exceptions import handle_exceptions
14except ImportError:
15    raise ImportError(
16        "python-pkcs11 package is not installed. "
17        "Please install it using the required packages with command: "
18        "pip install 'esptool[hsm]'"
19    )
20
21import cryptography.hazmat.primitives.asymmetric.ec as EC
22import cryptography.hazmat.primitives.asymmetric.rsa as RSA
23
24import ecdsa
25
26
def read_hsm_config(configfile):
    """Read and validate the ``[hsm_config]`` section of an HSM config file.

    The section must define ``pkcs11_lib``, ``slot`` and ``label``. If it has
    no ``credentials`` option, the user is prompted (without echo) for the HSM
    PIN, which is then stored in the returned section as ``credentials``.

    :param configfile: path of the config file, passed to ``ConfigParser.read``.
    :returns: the ``configparser.SectionProxy`` for the ``hsm_config`` section.
    :raises configparser.NoSectionError: if ``[hsm_config]`` is missing. This is
        also what happens when the file cannot be read, because
        ``ConfigParser.read`` silently skips files it cannot open.
    :raises configparser.NoOptionError: if a required option is missing.
    """
    config = configparser.ConfigParser()
    config.read(configfile)

    section = "hsm_config"
    if not config.has_section(section):
        raise configparser.NoSectionError(section)

    for option in ("pkcs11_lib", "slot", "label"):
        if not config.has_option(section, option):
            raise configparser.NoOptionError(option, section)

    # If the config file does not contain the "credentials" option,
    # prompt the user for the HSM PIN.
    if not config.has_option(section, "credentials"):
        hsm_pin = getpass("Please enter the PIN of your HSM:\n")
        # The default BasicInterpolation rejects a bare "%" in set() and would
        # expand "%(...)s" on read; escaping "%" as "%%" stores the PIN so that
        # reading it back yields exactly what the user typed.
        config.set(section, "credentials", hsm_pin.replace("%", "%%"))

    return config[section]
47
48
def establish_session(config):
    """Open a read/write session, logged in with the user PIN, on the configured HSM slot.

    :param config: mapping providing ``pkcs11_lib`` (path to the PKCS#11
        library), ``slot`` (slot ID as an int-convertible string) and
        ``credentials`` (user PIN), e.g. the section returned by
        ``read_hsm_config``.
    :returns: the opened ``pkcs11`` session.

    Exits the process with status 1 if the library file does not exist, if no
    slot with a present token has the requested ID, or if a PKCS#11 error
    occurs while loading the library or opening the session.
    """
    print("Trying to establish a session with the HSM.")
    lib_path = config["pkcs11_lib"]
    if not os.path.exists(lib_path):
        print(f"LIB file does not exist at {lib_path}")
        sys.exit(1)
    slot_id = int(config["slot"])

    try:
        lib = pkcs11.lib(lib_path)
        for slot in lib.get_slots(token_present=True):
            if slot.slot_id == slot_id:
                break
        else:
            # Without this branch a non-matching ID would silently fall through
            # to the last enumerated slot (or hit a NameError if there were no
            # slots at all), opening a session on the wrong token.
            print(f"No HSM slot with ID {slot_id} and a present token was found.")
            sys.exit(1)

        token = slot.get_token()
        session = token.open(rw=True, user_pin=config["credentials"])
        print(f"Session creation successful with HSM slot {slot_id}.")
        return session

    except pkcs11.exceptions.PKCS11Error as e:
        handle_exceptions(e)
        print("Session establishment failed")
        sys.exit(1)
70
71
def get_privkey_info(session, config):
    """Look up the private key object stored on the HSM under ``config["label"]``.

    :param session: an open ``pkcs11`` session.
    :param config: mapping providing the ``label`` of the private key.
    :returns: the private key object returned by ``session.get_key``.

    On a PKCS#11 error the error is reported and the process exits with status 1.
    """
    try:
        wanted_label = config["label"]
        key_handle = session.get_key(
            object_class=pkcs11.constants.ObjectClass.PRIVATE_KEY,
            label=wanted_label,
        )
    except pkcs11.exceptions.PKCS11Error as err:
        handle_exceptions(err)
        print("Failed to get the private key")
        sys.exit(1)

    print(f"Got private key metadata with label {wanted_label}.")
    return key_handle
84
85
def get_pubkey(session, config):
    """Read the public key from the HSM and convert it to a ``cryptography`` key.

    The key is looked up by the ``label_pubkey`` option when the config has
    one, otherwise by ``label``. RSA keys are rebuilt from their public
    exponent and modulus. EC keys are decoded from the ``EC_POINT`` attribute
    as a SECP256R1 point. Any other key type terminates the process.

    :param session: an open ``pkcs11`` session.
    :param config: mapping providing ``label`` and optionally ``label_pubkey``.
    :returns: an RSA or EC public key object from ``cryptography``.

    On a PKCS#11 error the error is reported and the process exits with status 1.
    """
    print("Trying to extract public key from the HSM.")
    try:
        has_dedicated_label = "label_pubkey" in config
        if not has_dedicated_label:
            print(
                "Config option 'label_pubkey' not found, "
                "using config option 'label' for public key."
            )
        label = config["label_pubkey"] if has_dedicated_label else config["label"]

        hsm_key = session.get_key(
            object_class=pkcs11.constants.ObjectClass.PUBLIC_KEY,
            label=label,
        )

        if hsm_key.key_type == pkcs11.mechanisms.KeyType.RSA:
            # Attributes are big-endian byte strings; exponent is read before modulus.
            result = RSA.RSAPublicNumbers(
                int.from_bytes(hsm_key[pkcs11.Attribute.PUBLIC_EXPONENT], byteorder="big"),
                int.from_bytes(hsm_key[pkcs11.Attribute.MODULUS], byteorder="big"),
            ).public_key()
        elif hsm_key.key_type == pkcs11.mechanisms.KeyType.EC:
            # EC_POINT is unwrapped from its DER OCTET STRING to get the encoded point.
            raw_point, _trailing = ecdsa.der.remove_octet_string(
                hsm_key[pkcs11.Attribute.EC_POINT]
            )
            result = EC.EllipticCurvePublicKey.from_encoded_point(
                EC.SECP256R1(), raw_point
            )
        else:
            print("Incorrect public key algorithm")
            sys.exit(1)

        print(f"Got public key with label {label}.")
        return result

    except pkcs11.exceptions.PKCS11Error as err:
        handle_exceptions(err)
        print("Failed to extract the public key")
        sys.exit(1)
128
129
def sign_payload(private_key, payload):
    """Sign ``payload`` on the HSM with ``private_key``.

    The signing mechanism (and its parameters) is chosen from the key type via
    ``get_mechanism``. For EC keys the raw signature returned by the HSM is
    treated as ``r || s`` (two equal-length big-endian halves) and re-encoded
    as a DER ``SEQUENCE { INTEGER r, INTEGER s }``. RSA signatures are returned
    unchanged.

    :param private_key: ``pkcs11`` private key object (see ``get_privkey_info``).
    :param payload: bytes to sign.
    :returns: the signature bytes (DER-encoded for EC keys).

    Exits the process with status 1 on a PKCS#11 error or if the HSM returns an
    empty signature.
    """
    # Bound before the ``try`` so the error handler can always refer to it,
    # even if reading the key type fails before a mechanism has been chosen.
    mechanism = None
    try:
        print("Signing payload using the HSM.")
        key_type = private_key.key_type
        mechanism, mechanism_params = get_mechanism(key_type)
        signature = private_key.sign(
            data=payload, mechanism=mechanism, mechanism_param=mechanism_params
        )

        if not signature:
            # An empty result is never a valid signature; previously it either
            # crashed while parsing (EC) or was returned as-is (RSA).
            print("Payload Signing Failed")
            sys.exit(1)
        print("Signature generation successful.")

        if key_type == pkcs11.mechanisms.KeyType.EC:
            # Split r || s in the middle rather than assuming a 256-bit curve;
            # for P-256 this is the same 32/32 split.
            half = len(signature) // 2
            r = int.from_bytes(signature[:half], byteorder="big")
            s = int.from_bytes(signature[half:], byteorder="big")

            # DER encoding in case of ECDSA signatures
            signature = ecdsa.der.encode_sequence(
                ecdsa.der.encode_integer(r), ecdsa.der.encode_integer(s)
            )

        return signature

    except pkcs11.exceptions.PKCS11Error as e:
        if mechanism is None:
            handle_exceptions(e)
        else:
            handle_exceptions(e, mechanism)
        print("Payload Signing Failed")
        sys.exit(1)
157
158
def get_mechanism(key_type):
    """Map an HSM key type to the PKCS#11 signing mechanism and its parameters.

    * RSA -> ``SHA256_RSA_PKCS_PSS`` with the parameter tuple
      ``(SHA256, MGF SHA256, 32)`` (hash, mask generation function and what
      python-pkcs11 takes as the PSS salt length).
    * EC  -> ``ECDSA_SHA256`` with no parameters.

    Any other key type prints an error and exits the process with status 1.

    :returns: ``(mechanism, mechanism_params)``.
    """
    key_types = pkcs11.mechanisms.KeyType
    mechanisms = pkcs11.mechanisms.Mechanism

    if key_type == key_types.RSA:
        pss_params = (mechanisms.SHA256, pkcs11.MGF.SHA256, 32)
        return mechanisms.SHA256_RSA_PKCS_PSS, pss_params

    if key_type == key_types.EC:
        return mechanisms.ECDSA_SHA256, None

    print("Invalid signing key mechanism")
    sys.exit(1)
171
172
def close_connection(session):
    """Close an HSM session and report the outcome.

    :param session: the session object to close (its ``close()`` is called).

    On a PKCS#11 error the error is reported and the process exits with status 1.
    """
    try:
        session.close()
    except pkcs11.exceptions.PKCS11Error as err:
        handle_exceptions(err)
        print("Failed to close the HSM session")
        sys.exit(1)
    else:
        print("Connection closed successfully")
181