1"""Knowledge about the PSA key store as implemented in Mbed TLS.
2
3Note that if you need to make a change that affects how keys are
4stored, this may indicate that the key store is changing in a
5backward-incompatible way! Think carefully about backward compatibility
6before changing how test data is constructed or validated.
7"""
8
9# Copyright The Mbed TLS Contributors
10# SPDX-License-Identifier: Apache-2.0
11#
12# Licensed under the Apache License, Version 2.0 (the "License"); you may
13# not use this file except in compliance with the License.
14# You may obtain a copy of the License at
15#
16# http://www.apache.org/licenses/LICENSE-2.0
17#
18# Unless required by applicable law or agreed to in writing, software
19# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
20# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21# See the License for the specific language governing permissions and
22# limitations under the License.
23
24import re
25import struct
26from typing import Dict, List, Optional, Set, Union
27import unittest
28
29from . import c_build_helper
30
31
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An `Expr` built from an `int` knows its value immediately. An `Expr`
    built from a string keeps the C source text. If that text is a plain
    integer literal, its value is computed directly in Python. Any other
    expression is evaluated in a batch by `update_cache`, which builds a C
    program including ``psa/crypto.h``. Computed values are stored in the
    class-level `value_cache`, which is shared by all instances.
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Zero-padded hexadecimal: 4 digits for values up to 0xffff,
            # at least 8 digits otherwise.
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            self.value_if_known = None
            normalized = self.normalize(content)
            # Queue the expression for the next batch evaluation, unless its
            # value is available without the C compiler anyway.
            if (self._literal_value(content) is None and
                    normalized not in self.value_cache):
                self.unknown_values.add(normalized)

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in one batch through
        `c_build_helper.get_c_expression_values`, and the pending set is then
        cleared. Does nothing when no expression is pending, so no C program
        is built needlessly.
        """
        if not self.unknown_values:
            return
        expressions = sorted(self.unknown_values)
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=['include']) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _literal_value(string: str) -> Optional[int]:
        """Return the value of `string` if it is a plain C integer literal.

        Recognizes unsuffixed hexadecimal (``0x``/``0X``), octal (leading
        ``0``) and decimal literals, ignoring surrounding whitespace. Returns
        None for anything else, including literals with suffixes such as
        ``u`` or ``UL``; those are left to the C compiler.
        """
        literal = string.strip()
        if re.match(r'0x[0-9a-f]+\Z', literal, re.I):
            return int(literal, 16)
        if re.match(r'0[0-7]*\Z', literal):
            # C octal literal. Python's int(..., 0) rejects leading zeros,
            # so the base must be given explicitly.
            return int(literal, 8)
        if re.match(r'[1-9][0-9]*\Z', literal):
            return int(literal, 10)
        return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        Plain integer literals are converted directly. Other expressions are
        looked up in `value_cache`, running `update_cache` first if the value
        is missing. The result is memoized in `value_if_known`.
        """
        if self.value_if_known is None:
            literal = self._literal_value(self.string)
            if literal is not None:
                self.value_if_known = literal
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    # Make sure this expression is part of the batch even if
                    # the pending set no longer contains it.
                    self.unknown_values.add(normalized)
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
85
# A plain integer, the source text of a C expression, or an existing
# `Expr`; `as_expr` turns any of these into an `Expr`.
Exprable = Union[str, int, Expr]
"""Anything convertible to a C expression whose numerical value is known or computable."""
88
def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` into an `Expr`.

    An object that is already an `Expr` is returned as-is (the same object,
    not a copy). An integer, or a string containing a C expression, is
    wrapped in a freshly constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
100
101
class Key:
    """Representation of a PSA crypto key object and its storage encoding.

    The attributes `lifetime`, `type`, `usage`, `alg` and `alg2` are held as
    `Expr` objects, so each may be given as an integer or as a C expression.
    `bytes()` produces the storage encoding; only format version 0 is
    implemented.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        # Omitting `version` selects the latest storage format.
        self.version = self.LATEST_VERSION if version is None else version
        # `id` is stored on the object but is not part of the encoding
        # produced by `bytes()`.
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    # Fixed 8-byte prefix of the encoding produced by `bytes()`.
    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Pack the given arguments into a byte string according to the given format.

        This function is similar to `struct.pack`, but with the following differences:
        * All integer values are encoded with standard sizes and in
          little-endian representation. `fmt` must not include an endianness
          prefix.
        * Arguments can be `Expr` objects instead of integers.
        * Only integer-valued elements are supported.
        """
        return struct.pack('<' + fmt, # little-endian, standard sizes
                           *[arg.value() if isinstance(arg, Expr) else arg
                             for arg in args])

    def bytes(self) -> bytes:
        """Return the representation of the key in storage as a byte array.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, this does not include any wrapping made
        by the PSA-storage-over-stdio-file implementation.

        Note that if you need to make a change in this function,
        this may indicate that the key store is changing in a
        backward-incompatible way! Think carefully about backward
        compatibility before making any change here.

        Raises NotImplementedError if `self.version` is not a supported
        storage format version (only 0 is supported).
        """
        # Layout (little-endian, standard sizes):
        #   MAGIC (8 bytes) | version (u32)
        # followed, for version 0, by:
        #   lifetime (u32) | type (u16) | bits (u16) | usage (u32) |
        #   alg (u32) | alg2 (u32) | material length (u32) | material
        header = self.MAGIC + self.pack('L', self.version)
        if self.version == 0:
            attributes = self.pack('LHHLLL',
                                   self.lifetime, self.type, self.bits,
                                   self.usage, self.alg, self.alg2)
            material = self.pack('L', len(self.material)) + self.material
        else:
            raise NotImplementedError(
                'Unsupported key storage format version: {}'.format(self.version))
        return header + attributes + material

    def hex(self) -> str:
        """Return the representation of the key as a hexadecimal string.

        This is the (lowercase) hexadecimal representation of `self.bytes()`.
        """
        return self.bytes().hex()

    def location_value(self) -> int:
        """The numerical value of the location encoded in the key's lifetime.

        This is the lifetime value shifted right by 8 bits, i.e. with its
        low 8 bits discarded.
        """
        return self.lifetime.value() >> 8
180
181
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def check_encoding(self, key, expected_hex: str) -> None:
        """Check both `key.bytes()` and `key.hex()` against `expected_hex`."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        """Encode a key whose attributes are all given as integers."""
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self.check_encoding(key, expected_hex)

    def test_names(self):
        """Encode a key whose lifetime and type are given as C names."""
        length = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=length*8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * length)
        expected_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000' + '00' * length
        self.check_encoding(key, expected_hex)

    def test_defaults(self):
        """Check the default version, id and lifetime."""
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        self.assertEqual(key.version, Key.LATEST_VERSION)
        self.assertIsNone(key.id)
        expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
        self.check_encoding(key, expected_hex)

    def test_location_value(self):
        """The location is the lifetime without its low 8 bits."""
        key = Key(lifetime=0x00800001, type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        self.assertEqual(key.location_value(), 0x8000)

    def test_unsupported_version(self):
        """Encoding an unknown storage format version is rejected."""
        key = Key(version=1, lifetime=0x00000001, type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        with self.assertRaises(NotImplementedError):
            key.bytes()
214