1# SPDX-License-Identifier: Apache-2.0
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16
17import subprocess
18from click.testing import CliRunner
19from imgtool import main as imgtool_main
20from imgtool.main import imgtool
21
# Every key type the 'keygen' command knows about; the tests below are
# parametrized over this list.
KEY_TYPES = list(imgtool_main.keygens)
# Values passed to 'getpub --encoding'.
KEY_ENCODINGS = list(imgtool_main.valid_encodings)
# Values passed to 'getpubhash --encoding'.
PUB_HASH_ENCODINGS = list(imgtool_main.valid_hash_encodings)
# Values passed to 'getpriv --format'.
PVT_KEY_FORMATS = list(imgtool_main.valid_formats)

# Text expected in the 'openssl pkey -text' dump of a generated key, per
# key type.
OPENSSL_KEY_TYPES = {
    "rsa-2048": "Private-Key: (2048 bit, 2 primes)",
    "rsa-3072": "Private-Key: (3072 bit, 2 primes)",
    "ecdsa-p256": "Private-Key: (256 bit)",
    "ecdsa-p384": "Private-Key: (384 bit)",
    "ed25519": "ED25519 Private-Key:",
    "x25519": "X25519 Private-Key:",
}

# File-name suffixes for the artifacts written into the session directory.
GEN_KEY_EXT = ".key"
GEN_ANOTHER_KEY_EXT = ".another.key"
PUB_KEY_EXT = ".pub"
PUB_KEY_HASH_EXT = ".pubhash"
41
42
def tmp_name(tmp_path, key_type, suffix=""):
    """Return the path ``<tmp_path>/<key_type><suffix>``.

    No separator is inserted between *key_type* and *suffix*; callers
    include any leading dot in *suffix* themselves.
    """
    file_name = "".join((key_type, suffix))
    return tmp_path / file_name
45
46
@pytest.fixture(scope="session")
def tmp_path_persistent(tmp_path_factory):
    """Provide one temporary directory shared by the whole test session.

    ``test_keygen`` writes its keys here, and the later tests in this module
    read them back. The directory therefore has to outlive any single test
    function, which is why the fixture is session-scoped rather than
    ``tmp_path``.
    """
    keys_dir = tmp_path_factory.mktemp("keys")
    return keys_dir
50
51
@pytest.mark.parametrize("key_type", KEY_TYPES)
def test_keygen(key_type, tmp_path_persistent):
    """Generate keys by imgtool.

    Two independent keys of *key_type* are written into the session
    directory: the ``GEN_KEY_EXT`` one is reused by the later tests as the
    signing key, and the ``GEN_ANOTHER_KEY_EXT`` one as an unrelated
    "wrong" key.
    """
    runner = CliRunner()

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)
    gen_key2 = tmp_name(tmp_path_persistent, key_type, GEN_ANOTHER_KEY_EXT)
    # Distinct files, otherwise the second keygen would overwrite the first.
    assert gen_key != gen_key2

    for key_path in (gen_key, gen_key2):
        assert not key_path.exists()
        result = runner.invoke(
            imgtool, ["keygen", "--key", str(key_path), "--type", key_type]
        )
        # CliRunner swallows output and exceptions; show them on failure.
        assert result.exit_code == 0, (result.output, result.exception)
        assert key_path.exists()
        assert key_path.stat().st_size > 0

    # Two independently generated keys must not be identical.
    assert gen_key.read_bytes() != gen_key2.read_bytes()
83
84
@pytest.mark.parametrize("key_type", KEY_TYPES)
def test_keygen_type(key_type, tmp_path_persistent):
    """Check generated keys with the openssl CLI.

    ``openssl pkey -check`` must accept the key written by ``test_keygen``.
    Its text dump must also contain the string listed for *key_type* in
    ``OPENSSL_KEY_TYPES``. The test is skipped when no ``openssl``
    executable can be found.
    """
    # Every key type imgtool can generate needs an expected openssl string.
    assert key_type in OPENSSL_KEY_TYPES

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)

    try:
        result = subprocess.run(
            [
                "openssl", "pkey", "-in", str(gen_key),
                "-check", "-noout", "-text",
            ],
            capture_output=True,
            text=True,
            check=False,  # the return code is asserted below, with stderr
        )
    except FileNotFoundError:
        # A missing tool is an environment problem, not a key defect.
        pytest.skip("openssl executable not found")

    assert result.returncode == 0, result.stderr
    assert "Key is valid" in result.stdout, result.stdout
    assert OPENSSL_KEY_TYPES[key_type] in result.stdout, result.stdout
100
101
@pytest.mark.parametrize("key_type", KEY_TYPES)
@pytest.mark.parametrize("fmt", PVT_KEY_FORMATS)
def test_getpriv(key_type, fmt, tmp_path_persistent):
    """Get private key.

    Dumps the key generated by ``test_keygen`` in each private-key format.
    ``fmt`` avoids shadowing the ``format`` builtin; test IDs are built from
    the parameter values, so they are unchanged.
    """
    runner = CliRunner()

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)

    result = runner.invoke(
        imgtool,
        ["getpriv", "--key", str(gen_key), "--format", fmt],
    )
    # CliRunner swallows output and exceptions; show them on failure.
    assert result.exit_code == 0, (result.output, result.exception)
121
122
@pytest.mark.parametrize("key_type", KEY_TYPES)
@pytest.mark.parametrize("encoding", KEY_ENCODINGS)
def test_getpub(key_type, encoding, tmp_path_persistent):
    """Get public key.

    Exports the public part of the key generated by ``test_keygen`` to
    ``<key_type>.pub.<encoding>`` in the session directory. The result must
    be a new, non-empty file.
    """
    runner = CliRunner()

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)
    pub_key = tmp_name(
        tmp_path_persistent, key_type, PUB_KEY_EXT + "." + encoding
    )

    assert not pub_key.exists()
    result = runner.invoke(
        imgtool,
        [
            "getpub",
            "--key", str(gen_key),
            "--output", str(pub_key),
            "--encoding", encoding,
        ],
    )
    # CliRunner swallows output and exceptions; show them on failure.
    assert result.exit_code == 0, (result.output, result.exception)
    assert pub_key.exists()
    assert pub_key.stat().st_size > 0
149
150
@pytest.mark.parametrize("key_type", KEY_TYPES)
@pytest.mark.parametrize("encoding", PUB_HASH_ENCODINGS)
def test_getpubhash(key_type, encoding, tmp_path_persistent):
    """Get the hash of the public key.

    Writes the public-key hash of the key generated by ``test_keygen`` to
    ``<key_type>.pubhash.<encoding>`` in the session directory. The result
    must be a new, non-empty file.
    """
    runner = CliRunner()

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)
    pub_key_hash = tmp_name(
        tmp_path_persistent, key_type, PUB_KEY_HASH_EXT + "." + encoding
    )

    assert not pub_key_hash.exists()
    result = runner.invoke(
        imgtool,
        [
            "getpubhash",
            "--key", str(gen_key),
            "--output", str(pub_key_hash),
            "--encoding", encoding,
        ],
    )
    # CliRunner swallows output and exceptions; show them on failure.
    assert result.exit_code == 0, (result.output, result.exception)
    assert pub_key_hash.exists()
    assert pub_key_hash.stat().st_size > 0
178
179
@pytest.mark.parametrize("key_type", KEY_TYPES)
def test_sign_verify(key_type, tmp_path_persistent):
    """Test basic sign and verify.

    The test checks three things:

    - ``sign`` fails, and writes no output, when the image-layout options
      are missing.
    - A properly signed dummy image verifies with the signing key.
    - The same image does not verify with the unrelated
      ``GEN_ANOTHER_KEY_EXT`` key.
    """
    runner = CliRunner()

    gen_key = tmp_name(tmp_path_persistent, key_type, GEN_KEY_EXT)
    wrong_key = tmp_name(tmp_path_persistent, key_type, GEN_ANOTHER_KEY_EXT)
    # The directory is shared by the whole session. The image files are
    # therefore named per key type, so a signed image left behind by one
    # failed parametrization cannot trip the "not exists()" check below
    # in the next one.
    image = tmp_name(tmp_path_persistent, key_type, ".image.bin")
    image_signed = tmp_name(tmp_path_persistent, key_type, ".image.signed")

    image.write_bytes(b"\x00" * 1024)

    # not all required arguments are provided
    result = runner.invoke(
        imgtool,
        [
            "sign",
            "--key", str(gen_key),
            str(image),
            str(image_signed),
        ],
    )
    assert result.exit_code != 0, result.output
    assert not image_signed.exists()

    result = runner.invoke(
        imgtool,
        [
            "sign",
            "--key", str(gen_key),
            "--align", "16",
            "--version", "1.0.0",
            "--header-size", "0x400",
            "--slot-size", "0x10000",
            "--pad-header",
            str(image),
            str(image_signed),
        ],
    )
    # CliRunner swallows output and exceptions; show them on failure.
    assert result.exit_code == 0, (result.output, result.exception)
    assert image_signed.exists()
    assert image_signed.stat().st_size > 0

    # original key can be used to verify a signed image
    result = runner.invoke(
        imgtool,
        ["verify", "--key", str(gen_key), str(image_signed)],
    )
    assert result.exit_code == 0, (result.output, result.exception)

    # 'another' key is not valid to verify a signed image
    result = runner.invoke(
        imgtool,
        ["verify", "--key", str(wrong_key), str(image_signed)],
    )
    assert result.exit_code != 0, result.output
    image_signed.unlink()
254