1"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6#
7
8import itertools
9import re
10from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
11
12
13class ReadFileLineException(Exception):
14    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
15        message = 'in {} at {}'.format(filename, line_number)
16        super(ReadFileLineException, self).__init__(message)
17        self.filename = filename
18        self.line_number = line_number
19
20
21class read_file_lines:
22    # Dear Pylint, conventionally, a context manager class name is lowercase.
23    # pylint: disable=invalid-name,too-few-public-methods
24    """Context manager to read a text file line by line.
25
26    ```
27    with read_file_lines(filename) as lines:
28        for line in lines:
29            process(line)
30    ```
31    is equivalent to
32    ```
33    with open(filename, 'r') as input_file:
34        for line in input_file:
35            process(line)
36    ```
37    except that if process(line) raises an exception, then the read_file_lines
38    snippet annotates the exception with the file name and line number.
39    """
40    def __init__(self, filename: str, binary: bool = False) -> None:
41        self.filename = filename
42        self.file = None #type: Optional[IO[str]]
43        self.line_number = 'entry' #type: Union[int, str]
44        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
45        self.binary = binary
    def __enter__(self) -> 'read_file_lines':
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        # Number lines from 1, matching the convention used by text editors.
        self.generator = enumerate(self.file, 1)
        return self
    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value


class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
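
    A minimal usage sketch (the constructor and argument values below are
    filled in by hand purely for illustration; normally a derived class
    populates them from header files):
    ```
    enumerator = PSAMacroEnumerator()
    enumerator.argspecs['PSA_ALG_HMAC'] = ['hash_alg']
    enumerator.arguments_for['hash_alg'] = ['PSA_ALG_SHA_256']
    print(list(enumerator.generate_expressions(['PSA_ALG_HMAC'])))
    # prints ['PSA_ALG_HMAC(PSA_ALG_SHA_256)']
    ```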
75    """
76    #pylint: disable=too-many-instance-attributes
77
78    def __init__(self) -> None:
79        """Set up an empty set of known constructor macros.
80        """
81        self.statuses = set() #type: Set[str]
82        self.lifetimes = set() #type: Set[str]
83        self.locations = set() #type: Set[str]
84        self.persistence_levels = set() #type: Set[str]
85        self.algorithms = set() #type: Set[str]
86        self.ecc_curves = set() #type: Set[str]
87        self.dh_groups = set() #type: Set[str]
88        self.key_types = set() #type: Set[str]
89        self.key_usage_flags = set() #type: Set[str]
90        self.hash_algorithms = set() #type: Set[str]
91        self.mac_algorithms = set() #type: Set[str]
92        self.ka_algorithms = set() #type: Set[str]
93        self.kdf_algorithms = set() #type: Set[str]
94        self.pake_algorithms = set() #type: Set[str]
95        self.aead_algorithms = set() #type: Set[str]
96        self.sign_algorithms = set() #type: Set[str]
97        # macro name -> list of argument names
98        self.argspecs = {} #type: Dict[str, List[str]]
99        # argument name -> list of values
100        self.arguments_for = {
101            'mac_length': [],
102            'min_mac_length': [],
103            'tag_length': [],
104            'min_tag_length': [],
105        } #type: Dict[str, List[str]]
106        # Whether to include intermediate macros in enumerations. Intermediate
107        # macros serve as category headers and are not valid values of their
108        # type. See `is_internal_name`.
109        # Always false in this class, may be set to true in derived classes.
110        self.include_intermediate = False
111
112    def is_internal_name(self, name: str) -> bool:
113        """Whether this is an internal macro. Internal macros will be skipped."""
114        if not self.include_intermediate:
115            if name.endswith('_BASE') or name.endswith('_NONE'):
116                return True
117            if '_CATEGORY_' in name:
118                return True
119        return name.endswith('_FLAG') or name.endswith('_MASK')
120
121    def gather_arguments(self) -> None:
122        """Populate the list of values for macro arguments.
123
124        Call this after parsing all the inputs.
125        """
126        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
127        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
128        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
129        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
130        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
131        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
132        self.arguments_for['curve'] = sorted(self.ecc_curves)
133        self.arguments_for['group'] = sorted(self.dh_groups)
134        self.arguments_for['persistence'] = sorted(self.persistence_levels)
135        self.arguments_for['location'] = sorted(self.locations)
136        self.arguments_for['lifetime'] = sorted(self.lifetimes)
137
138    @staticmethod
139    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
140        """Format a macro call with arguments.
141
142        The resulting format is consistent with
143        `InputsForTest.normalize_argument`.
144        """
145        return name + '(' + ', '.join(arguments) + ')'
146
147    _argument_split_re = re.compile(r' *, *')
148    @classmethod
149    def _argument_split(cls, arguments: str) -> List[str]:
150        return re.split(cls._argument_split_re, arguments)
151
152    def distribute_arguments(self, name: str) -> Iterator[str]:
153        """Generate macro calls with each tested argument set.
154
155        If name is a macro without arguments, just yield "name".
156        If name is a macro with arguments, yield a series of
157        "name(arg1,...,argN)" where each argument takes each possible
158        value at least once.
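
        For instance (an illustrative case, not tied to any real macro),
        with ``argspecs['M'] == ['a', 'b']``,
        ``arguments_for['a'] == ['A1', 'A2']`` and
        ``arguments_for['b'] == ['B1', 'B2']``, this yields
        ``M(A1, B1)``, ``M(A2, B1)`` and ``M(A1, B2)``: each value appears
        at least once, but not every combination is generated.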
159        """
160        try:
161            if name not in self.argspecs:
162                yield name
163                return
164            argspec = self.argspecs[name]
165            if argspec == []:
166                yield name + '()'
167                return
168            argument_lists = [self.arguments_for[arg] for arg in argspec]
169            arguments = [values[0] for values in argument_lists]
170            yield self._format_arguments(name, arguments)
171            # Dear Pylint, enumerate won't work here since we're modifying
172            # the array.
173            # pylint: disable=consider-using-enumerate
174            for i in range(len(arguments)):
175                for value in argument_lists[i][1:]:
176                    arguments[i] = value
177                    yield self._format_arguments(name, arguments)
178                arguments[i] = argument_lists[i][0]
179        except BaseException as e:
180            raise Exception('distribute_arguments({})'.format(name)) from e
181
182    def distribute_arguments_without_duplicates(
183            self, seen: Set[str], name: str
184    ) -> Iterator[str]:
185        """Same as `distribute_arguments`, but don't repeat seen results."""
186        for result in self.distribute_arguments(name):
187            if result not in seen:
188                seen.add(result)
189                yield result
190
191    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
192        """Generate expressions covering values constructed from the given names.
193
194        `names` can be any iterable collection of macro names.
195
196        For example:
197        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
198          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
199          every known hash algorithm ``h``.
200        * ``macros.generate_expressions(macros.key_types)`` generates all
201          key types.
202        """
203        seen = set() #type: Set[str]
204        return itertools.chain(*(
205            self.distribute_arguments_without_duplicates(seen, name)
206            for name in names
207        ))
208
209
210class PSAMacroCollector(PSAMacroEnumerator):
211    """Collect PSA crypto macro definitions from C header files.
212    """
213
214    def __init__(self, include_intermediate: bool = False) -> None:
215        """Set up an object to collect PSA macro definitions.
216
217        Call the read_file method of the constructed object on each header file.
218
219        * include_intermediate: if true, include intermediate macros such as
220          PSA_XXX_BASE that do not designate semantic values.
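
        A minimal usage sketch (the header path below is just a placeholder):
        ```
        collector = PSAMacroCollector()
        with open('include/psa/crypto_values.h', 'rb') as header_file:
            collector.read_file(header_file)
        print(sorted(collector.statuses))
        ```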
221        """
222        super().__init__()
223        self.include_intermediate = include_intermediate
224        self.key_types_from_curve = {} #type: Dict[str, str]
225        self.key_types_from_group = {} #type: Dict[str, str]
226        self.algorithms_from_hash = {} #type: Dict[str, str]
227
228    @staticmethod
229    def algorithm_tester(name: str) -> str:
230        """The predicate for whether an algorithm is built from the given constructor.
231
232        The given name must be the name of an algorithm constructor of the
233        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
234        an algorithm value. Return the corresponding predicate macro which
235        is used as ``predicate(alg)`` to test whether ``alg`` can be built
236        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
237        ``PSA_ALG_IS_xxx``.
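
        For example, ``algorithm_tester('PSA_ALG_HMAC')`` returns
        ``'PSA_ALG_IS_HMAC'``, and ``algorithm_tester('PSA_ALG_ECDSA')``
        returns ``'PSA_ALG_IS_RANDOMIZED_ECDSA'``.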
238        """
239        prefix = 'PSA_ALG_'
240        assert name.startswith(prefix)
241        midfix = 'IS_'
242        suffix = name[len(prefix):]
243        if suffix in ['DSA', 'ECDSA']:
244            midfix += 'RANDOMIZED_'
245        elif suffix == 'RSA_PSS':
246            suffix += '_STANDARD_SALT'
247        return prefix + midfix + suffix
248
249    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
250        """Record the subtype of an algorithm constructor.
251
252        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
253        is of a subtype that is tracked in its own set, add it to the relevant
254        set.
255        """
256        # This code is very ad hoc and fragile. It should be replaced by
257        # something more robust.
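        # The hexadecimal patterns below rely on the PSA algorithm encoding:
        # values of the form 0x02xxxxxx are hashes, 0x03xxxxxx are MACs,
        # 0x05xxxxxx are AEADs, 0x08xxxxxx are key derivations and
        # 0x09xxxxxx are key agreements.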
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
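    # For example, "#define PSA_XXX 1" is captured as ('PSA_XXX', None, '1'),
    # and "#define PSA_XXX(param) expansion" as
    # ('PSA_XXX', 'param', 'expansion'). (Illustrative names, not real macros.)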
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines, if any.

        This function analyzes lines that contain a "#define" directive
        (up to non-significant whitespace) and skips all other lines.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: IO[bytes]) -> None:
        """Read a C header file and record the PSA macros it defines.

        The file must be open in binary mode. Continuation lines (ending
        with a backslash) are joined before being parsed.
        """
        for line in header_file:
            m = re.search(self._continued_line_re, line)
            while m:
                cont = next(header_file)
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)


class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
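
    A minimal usage sketch (the file paths below are placeholders):
    ```
    inputs = InputsForTest()
    inputs.parse_header('include/psa/crypto_values.h')
    inputs.parse_test_cases('tests/suites/test_suite_psa_crypto_metadata.data')
    inputs.add_numerical_values()
    inputs.gather_arguments()
    key_type_expressions = sorted(inputs.generate_expressions(inputs.key_types))
    ```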
350    """
351
352    def __init__(self) -> None:
353        super().__init__()
354        self.all_declared = set() #type: Set[str]
355        # Identifier prefixes
356        self.table_by_prefix = {
357            'ERROR': self.statuses,
358            'ALG': self.algorithms,
359            'ECC_CURVE': self.ecc_curves,
360            'DH_GROUP': self.dh_groups,
361            'KEY_LIFETIME': self.lifetimes,
362            'KEY_LOCATION': self.locations,
363            'KEY_PERSISTENCE': self.persistence_levels,
364            'KEY_TYPE': self.key_types,
365            'KEY_USAGE': self.key_usage_flags,
366        } #type: Dict[str, Set[str]]
367        # Test functions
368        self.table_by_test_function = {
369            # Any function ending in _algorithm also gets added to
370            # self.algorithms.
371            'key_type': [self.key_types],
372            'block_cipher_key_type': [self.key_types],
373            'stream_cipher_key_type': [self.key_types],
374            'ecc_key_family': [self.ecc_curves],
375            'ecc_key_types': [self.ecc_curves],
376            'dh_key_family': [self.dh_groups],
377            'dh_key_types': [self.dh_groups],
378            'hash_algorithm': [self.hash_algorithms],
379            'mac_algorithm': [self.mac_algorithms],
380            'cipher_algorithm': [],
381            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
382            'aead_algorithm': [self.aead_algorithms],
383            'key_derivation_algorithm': [self.kdf_algorithms],
384            'key_agreement_algorithm': [self.ka_algorithms],
385            'asymmetric_signature_algorithm': [self.sign_algorithms],
386            'asymmetric_signature_wildcard': [self.algorithms],
387            'asymmetric_encryption_algorithm': [],
388            'pake_algorithm': [self.pake_algorithms],
389            'other_algorithm': [],
390            'lifetime': [self.lifetimes],
391        } #type: Dict[str, List[Set[str]]]
392        mac_lengths = [str(n) for n in [
393            1,  # minimum expressible
394            4,  # minimum allowed by policy
395            13, # an odd size in a plausible range
396            14, # an even non-power-of-two size in a plausible range
397            16, # same as full size for at least one algorithm
398            63, # maximum expressible
399        ]]
400        self.arguments_for['mac_length'] += mac_lengths
401        self.arguments_for['min_mac_length'] += mac_lengths
402        aead_lengths = [str(n) for n in [
403            1,  # minimum expressible
404            4,  # minimum allowed by policy
405            13, # an odd size in a plausible range
406            14, # an even non-power-of-two size in a plausible range
407            16, # same as full size for at least one algorithm
408            63, # maximum expressible
409        ]]
410        self.arguments_for['tag_length'] += aead_lengths
411        self.arguments_for['min_tag_length'] += aead_lengths

    def add_numerical_values(self) -> None:
        """Add numerical values to the known identifiers.

        These values do not correspond to any supported mechanism.
        """
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        self.pake_algorithms.add('0x0a0000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type."""
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
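    # For example, a line defining PSA_KEY_TYPE_AES yields the groups
    # ('PSA_KEY_TYPE_AES', 'KEY_TYPE', None), and a line defining
    # PSA_ALG_HMAC(hash_alg) yields ('PSA_ALG_HMAC', 'ALG', 'hash_alg').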
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = set([
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    ])
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx"."""
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
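
        For example (an illustrative input),
        ``'PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM , 8 )'`` becomes
        ``'PSA_ALG_AEAD_WITH_SHORTENED_TAG(PSA_ALG_CCM, 8)'``.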
510        """
511        return re.sub(r',', r', ', re.sub(r' +', r'', argument))
512
513    def add_test_case_line(self, function: str, argument: str) -> None:
514        """Parse a test case data line, looking for algorithm metadata tests."""
515        sets = []
516        if function.endswith('_algorithm'):
517            sets.append(self.algorithms)
518            if function == 'key_agreement_algorithm' and \
519               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
520                # We only want *raw* key agreement algorithms as such, so
521                # exclude ones that are already chained with a KDF.
522                # Keep the expression as one to test as an algorithm.
523                function = 'other_algorithm'
524        sets += self.table_by_test_function[function]
525        if self.accept_test_case_line(function, argument):
526            for s in sets:
527                s.add(self.normalize_argument(argument))
528
529    # Regex matching a *.data line containing a test function call and
530    # its arguments. The actual definition is partly positional, but this
531    # regex is good enough in practice.
532    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
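    # For example, a line such as 'hash_algorithm:PSA_ALG_SHA_256:32' is
    # captured as ('hash_algorithm', 'PSA_ALG_SHA_256').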
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))