1"""General key class."""
2
3# SPDX-License-Identifier: Apache-2.0
4
5import binascii
6import io
7import os
8import sys
9from cryptography.hazmat.primitives.hashes import Hash, SHA256
10
# Banner printed as the very first line of every generated array snippet
# (C and Rust alike, via KeyClass._emit_to_output) so that readers of the
# generated source know it must be regenerated rather than hand-edited.
AUTOGEN_MESSAGE = (
    "/* Autogenerated by imgtool.py, "
    "do not edit. */"
)
12
13
class FileHandler(object):
    """Context manager yielding a file object for a path or an open stream.

    If ``file`` is a path (``str``, ``bytes`` or ``os.PathLike``), it is
    opened with the extra positional and keyword arguments given to the
    constructor, and closed again when the ``with`` block exits.  Any other
    object is treated as a stream owned by the caller: it is returned
    unchanged and left open on exit.
    """

    def __init__(self, file, *args, **kwargs):
        self.file_in = file
        self.args = args
        self.kwargs = kwargs
        self.file = None
        # True only while we hold a file object that __enter__ opened itself.
        self._owned = False

    def __enter__(self):
        if isinstance(self.file_in, (str, bytes, os.PathLike)):
            self.file = open(self.file_in, *self.args, **self.kwargs)
            self._owned = True
        else:
            self.file = self.file_in
            self._owned = False
        return self.file

    def __exit__(self, *exc_info):
        # Close only what we opened.  Caller-provided streams such as
        # sys.stdout must remain usable after the with-block.  Ownership is
        # tracked explicitly instead of comparing objects with ``!=``.
        if self._owned:
            self._owned = False
            self.file.close()
30
31
class KeyClass(object):
    """Base class providing output helpers for imgtool key types.

    This class does not itself define ``shortname()``,
    ``get_public_bytes()``, ``get_public_pem()`` or
    ``get_private_bytes(minimal, format)``.  The emitters below call them,
    so subclasses are expected to supply them.

    Every ``file`` argument may be one of two things:

    - a path, which is opened here and closed again;
    - an already-open stream, which is used as-is and left open.

    When ``file`` is omitted or None, ``sys.stdout`` is looked up at call
    time rather than at import time.  Later redirections (for example
    ``contextlib.redirect_stdout``) are therefore honoured.
    """

    @staticmethod
    def _resolve_output(file):
        """Return ``file``, or the current ``sys.stdout`` if it is None."""
        return sys.stdout if file is None else file

    def _emit(self, header, trailer, encoded_bytes, indent, file=None,
              len_format=None):
        """Write ``encoded_bytes`` as a source-code array definition.

        ``header`` opens the declaration and ``trailer`` closes it.  The
        bytes themselves are rendered by ``_emit_to_output``.  If
        ``len_format`` is given, it is ``str.format``-ed with the byte
        count and written on its own line after the trailer.  ``file`` is
        a path or text stream; None means the current stdout.
        """
        with FileHandler(self._resolve_output(file), 'w') as out:
            self._emit_to_output(header, trailer, encoded_bytes, indent,
                                 out, len_format)

    def _emit_to_output(self, header, trailer, encoded_bytes, indent, file,
                        len_format):
        """Print an array literal for ``encoded_bytes`` to the open ``file``.

        The output is written in this order:

        1. the autogenerated banner line;
        2. ``header``;
        3. one ``0xNN,`` item per byte value, eight per line, with each
           line prefixed by ``indent``;
        4. ``trailer``;
        5. the optional length line.
        """
        cells = ["0x{:02x},".format(b) for b in encoded_bytes]
        # Each row starts on a fresh line, so the header and the rows
        # together form one printed chunk.
        rows = "".join("\n" + indent + " ".join(cells[i:i + 8])
                       for i in range(0, len(cells), 8))
        print(AUTOGEN_MESSAGE, file=file)
        print(header + rows, file=file)
        print(trailer, file=file)
        if len_format is not None:
            print(len_format.format(len(cells)), file=file)

    def _emit_raw(self, encoded_bytes, file):
        """Write ``encoded_bytes`` unmodified to ``file``.

        ``file`` may be any of the following:

        - a path, opened here in ``'wb'`` mode;
        - a binary stream such as ``io.BytesIO``;
        - a text stream exposing its binary layer as ``.buffer``, such as
          ``sys.stdout``;
        - None, meaning the current stdout.
        """
        with FileHandler(self._resolve_output(file), 'wb') as out:
            # ``.buffer`` is not part of the TextIOBase API and may be
            # missing.  In that case the object is assumed to accept bytes
            # directly.
            binary = getattr(out, 'buffer', None)
            if binary is None:
                out.write(encoded_bytes)
            else:
                # Flush buffered text first.  Otherwise it could appear
                # after bytes written straight to the underlying buffer.
                out.flush()
                binary.write(encoded_bytes)

    def _public_key_hash(self):
        """Return the SHA-256 digest of ``get_public_bytes()`` as bytes."""
        digest = Hash(SHA256())
        digest.update(self.get_public_bytes())
        return digest.finalize()

    def emit_c_public(self, file=None):
        """Emit the public key as C arrays.

        Writes ``<shortname>_pub_key[]`` and ``<shortname>_pub_key_len``.
        """
        self._emit(
            header="const unsigned char {}_pub_key[] = {{"
                   .format(self.shortname()),
            trailer="};",
            encoded_bytes=self.get_public_bytes(),
            indent="    ",
            len_format="const unsigned int {}_pub_key_len = {{}};"
                       .format(self.shortname()),
            file=file)

    def emit_c_public_hash(self, file=None):
        """Emit the SHA-256 hash of the public key as C arrays.

        Writes ``<shortname>_pub_key_hash[]`` and
        ``<shortname>_pub_key_hash_len``.
        """
        self._emit(
            header="const unsigned char {}_pub_key_hash[] = {{"
                   .format(self.shortname()),
            trailer="};",
            encoded_bytes=self._public_key_hash(),
            indent="    ",
            len_format="const unsigned int {}_pub_key_hash_len = {{}};"
                       .format(self.shortname()),
            file=file)

    def emit_raw_public(self, file=None):
        """Write the raw public key bytes to ``file``."""
        self._emit_raw(self.get_public_bytes(), file=file)

    def emit_raw_public_hash(self, file=None):
        """Write the raw SHA-256 digest of the public key to ``file``."""
        self._emit_raw(self._public_key_hash(), file=file)

    def emit_rust_public(self, file=None):
        """Emit the public key as a Rust ``static <SHORTNAME>_PUB_KEY`` slice."""
        self._emit(
            header="static {}_PUB_KEY: &[u8] = &["
                   .format(self.shortname().upper()),
            trailer="];",
            encoded_bytes=self.get_public_bytes(),
            indent="    ",
            file=file)

    def emit_public_pem(self, file=None):
        """Write the PEM-encoded public key as UTF-8 text.

        No extra trailing newline is added.
        """
        with FileHandler(self._resolve_output(file), 'w') as out:
            print(str(self.get_public_pem(), 'utf-8'), file=out, end='')

    def emit_private(self, minimal, format, file=None):
        """Emit the private key as ``enc_priv_key[]`` / ``enc_priv_key_len``.

        ``minimal`` and ``format`` are passed through unchanged to
        ``get_private_bytes``.  The name ``format`` shadows the builtin but
        is kept so that keyword callers keep working.
        """
        self._emit(
            header="const unsigned char enc_priv_key[] = {",
            trailer="};",
            encoded_bytes=self.get_private_bytes(minimal, format),
            indent="    ",
            len_format="const unsigned int enc_priv_key_len = {};",
            file=file)
116