1"""Knowledge about the PSA key store as implemented in Mbed TLS.
2
3Note that if you need to make a change that affects how keys are
4stored, this may indicate that the key store is changing in a
5backward-incompatible way! Think carefully about backward compatibility
6before changing how test data is constructed or validated.
7"""
8
9# Copyright The Mbed TLS Contributors
10# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
11#
12
13import re
14import struct
15from typing import Dict, List, Optional, Set, Union
16import unittest
17
18from . import c_build_helper
19from . import build_tree
20
21
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    An `Expr` is built either from a Python integer, whose value is known
    immediately, or from a string containing a C expression. Plain C integer
    literals are parsed directly in Python. Any other expression is evaluated
    in a batch by `c_build_helper.get_c_expression_values`, with
    ``<psa/crypto.h>`` included (see `update_cache`).

    Results are memoized in the class-wide `value_cache`. Expressions still
    awaiting evaluation are collected in the class-wide `unknown_values`
    set, so that one call to `update_cache` evaluates all of them together.
    """

    def __init__(self, content: Union[int, str]):
        if isinstance(content, int):
            # Render as zero-padded hex: 4 digits up to 0xffff, 8 beyond.
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            normalized = self.normalize(content)
            # Only queue expressions that genuinely need the C helper:
            # plain literals are parsed in Python, and cached values are
            # already known.
            if (normalized not in self.value_cache and
                    self._parse_literal(content.strip()) is None):
                self.unknown_values.add(normalized)
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in a single call to
        `c_build_helper.get_c_expression_values`, each as an
        ``unsigned long`` printed with ``%lu``. The results are stored in
        `value_cache` and `unknown_values` is emptied. Nothing happens if
        no expression is pending.
        """
        expressions = sorted(self.unknown_values)
        if not expressions:
            # Avoid a pointless C build when everything is already known.
            return
        includes = ['include']
        # A TF-PSA-Crypto tree needs one extra include directory. Paths are
        # presumably relative to the current working directory; confirm
        # against c_build_helper.
        if build_tree.looks_like_tf_psa_crypto_root('.'):
            includes.append('drivers/builtin/include')
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=includes) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    @staticmethod
    def _parse_literal(string: str) -> Optional[int]:
        """Return the value of a plain, unsuffixed C integer literal.

        Hexadecimal (``0x...``), octal (leading ``0``) and decimal literals
        are recognized following C rules rather than Python rules, so
        ``010`` is 8. Return None for anything else, including literals with
        suffixes such as ``u`` or ``UL``, so that the caller falls back to
        evaluating the string in C.
        """
        if re.match(r'0x[0-9a-f]+\Z', string, re.I):
            return int(string, 16)
        if re.match(r'0[0-7]*\Z', string):
            # Covers plain '0' too; C treats a leading zero as octal.
            return int(string, 8)
        if re.match(r'[1-9][0-9]*\Z', string):
            return int(string, 10)
        return None

    def value(self) -> int:
        """Return the numerical value of the expression.

        Plain integer literals are parsed directly. Any other expression is
        looked up in `value_cache`, and `update_cache` is run first if the
        value is not cached yet. The result is memoized on the instance.
        """
        if self.value_if_known is None:
            literal = self._parse_literal(self.string.strip())
            if literal is not None:
                self.value_if_known = literal
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    # Re-register in case the pending set no longer holds
                    # it (it was skipped at construction time, or cleared).
                    self.unknown_values.add(normalized)
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
78
# Type alias for values accepted by `as_expr` and by the expression-valued
# parameters of `Key` (lifetime, type, usage, alg, alg2).
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""
81
def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` to an `Expr`.

    An `Expr` argument is handed back unchanged (the very same object).
    An integer, or a string holding a C expression, is wrapped in a newly
    constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
93
94
class Key:
    """Representation of a PSA crypto key object and its storage encoding.

    Expression-valued attributes (lifetime, type, usage, alg, alg2) are kept
    as `Expr` objects. They are resolved to numbers on demand, e.g. when the
    key is encoded by `bytes`.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        """Describe a key. All arguments are keyword-only.

        `version` defaults to `LATEST_VERSION`. `lifetime`, `type`, `usage`,
        `alg` and `alg2` may each be an integer, a string containing a C
        expression, or an `Expr`. `id` is stored on the object but is not
        part of the encoding produced by `bytes`. `material` is the raw key
        material.
        """
        self.version = self.LATEST_VERSION if version is None else version
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.bits = bits #type: int
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr
        self.material = material #type: bytes

    # Signature at the very start of every encoding produced by `bytes`.
    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Pack the given arguments into a byte string according to the given format.

        This function is similar to `struct.pack`, but with the following differences:
        * All integer values are encoded with standard sizes and in
          little-endian representation. `fmt` must not include an endianness
          prefix.
        * Arguments can be `Expr` objects instead of integers.
        * Only integer-valued elements are supported.
        """
        return struct.pack('<' + fmt, # little-endian, standard sizes
                           *[arg.value() if isinstance(arg, Expr) else arg
                             for arg in args])

    def bytes(self) -> bytes:
        """Return the representation of the key in storage as a byte array.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, this does not include any wrapping made
        by the PSA-storage-over-stdio-file implementation.

        Note that if you need to make a change in this function,
        this may indicate that the key store is changing in a
        backward-incompatible way! Think carefully about backward
        compatibility before making any change here.

        Raises NotImplementedError if `version` is not a supported storage
        format version (only version 0 is implemented).
        """
        # Reject unknown versions up front, so that an unsupported (or
        # unpackable) version fails with a clear message rather than an
        # unrelated error from packing the header.
        if self.version != 0:
            raise NotImplementedError(
                'Key storage format version {} is not supported'
                .format(self.version))
        header = self.MAGIC + self.pack('L', self.version)
        # Version 0 layout (little-endian): lifetime u32, type u16, bits u16,
        # usage u32, alg u32, alg2 u32, then the material length as u32
        # followed by the material itself.
        attributes = self.pack('LHHLLL',
                               self.lifetime, self.type, self.bits,
                               self.usage, self.alg, self.alg2)
        material = self.pack('L', len(self.material)) + self.material
        return header + attributes + material

    def hex(self) -> str:
        """Return the representation of the key as a hexadecimal string.

        This is the hexadecimal representation of `self.bytes`.
        """
        return self.bytes().hex()

    def location_value(self) -> int:
        """The numerical value of the location encoded in the key's lifetime.

        This is the lifetime value with its low 8 bits shifted out.
        """
        return self.lifetime.value() >> 8
173
174
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """Smoke tests checking the storage encoding produced by `Key`."""

    def _check_encoding(self, key, expected_hex):
        """Assert that `key` encodes to `expected_hex`, both as bytes and as hex."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        """Every attribute is given as a plain integer."""
        numeric_key = Key(version=0,
                          id=1, lifetime=0x00000001,
                          type=0x2400, bits=128,
                          usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                          material=b'@ABCDEFGHIJKLMNO')
        self._check_encoding(
            numeric_key,
            '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f')

    def test_names(self):
        """Attributes are given as C macro names, with the largest material."""
        n_bytes = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        named_key = Key(version=0,
                        id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                        type='PSA_KEY_TYPE_RAW_DATA', bits=n_bytes * 8,
                        usage=0, alg=0, alg2=0,
                        material=bytes(n_bytes))
        prefix_hex = '505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000'
        self._check_encoding(named_key, prefix_hex + '00' * n_bytes)

    def test_defaults(self):
        """Version, id and lifetime are left at their default values."""
        default_key = Key(type=0x1001, bits=8,
                          usage=0, alg=0, alg2=0,
                          material=b'\x2a')
        self._check_encoding(
            default_key,
            '505341004b455900000000000100000001100800000000000000000000000000010000002a')
207