1"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
19import itertools
20import re
21from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
22
23
class ReadFileLineException(Exception):
    """Exception carrying the file name and line position of a failure.

    The exception message has the form ``in FILENAME at LINE_NUMBER``.

    Attributes:
    * ``filename``: the file name passed to the constructor.
    * ``line_number``: the position passed to the constructor; either an
      integer or a string.
    """

    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        super().__init__('in {} at {}'.format(filename, line_number))
        self.filename = filename
        self.line_number = line_number
30
31
class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number:
    a `ReadFileLineException` is raised from the original exception.

    If `binary` is true, the file is opened in mode ``'rb'`` instead of
    ``'r'``.

    The `line_number` attribute is ``'entry'`` until the first line has been
    read, then the 1-based number of the line most recently yielded, and
    finally ``'exit'`` once the file has been read to the end.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        self.file = None #type: Optional[IO[str]]
        self.line_number = 'entry' #type: Union[int, str]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary

    def __enter__(self) -> 'read_file_lines':
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        # Count lines from 1 so that the reported position matches what
        # editors and compilers call the line number.
        self.generator = enumerate(self.file, 1)
        return self

    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            # Record the position before handing the line to the caller,
            # so that __exit__ can report it if processing the line fails.
            self.line_number = line_number
            yield content
        self.line_number = 'exit'

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            # Chain the original exception so that its traceback is kept.
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
73
74
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """
    #pylint: disable=too-many-instance-attributes

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.lifetimes = set() #type: Set[str]
        self.locations = set() #type: Set[str]
        self.persistence_levels = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.pake_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        self.sign_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # Whether to include intermediate macros in enumerations. Intermediate
        # macros serve as category headers and are not valid values of their
        # type. See `is_internal_name`.
        # Always false in this class, may be set to true in derived classes.
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending in ``_FLAG`` or ``_MASK`` are always internal. Names
        ending in ``_BASE`` or ``_NONE``, and names containing
        ``_CATEGORY_``, are internal unless `self.include_intermediate`
        is true.
        """
        if not self.include_intermediate:
            if name.endswith(('_BASE', '_NONE')) or '_CATEGORY_' in name:
                return True
        return name.endswith(('_FLAG', '_MASK'))

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.

        Each argument name below is mapped to the sorted content of the
        corresponding set of known constructors, replacing any previous
        list stored under that argument name.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)
        self.arguments_for['persistence'] = sorted(self.persistence_levels)
        self.arguments_for['location'] = sorted(self.locations)
        self.arguments_for['lifetime'] = sorted(self.lifetimes)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, dropping spaces around commas."""
        return cls._argument_split_re.split(arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        Errors encountered while building the calls (for example an
        argument name that is missing from `self.arguments_for`, or an
        argument with no known value) are reported as an `Exception`
        whose message names the macro, chained from the original error.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            # Start with the first known value of every argument...
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            # ... then vary one argument at a time over its other values,
            # restoring the first value before moving to the next argument.
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                arguments[i] = values[0]
        except Exception as e:
            # Only wrap genuine errors. GeneratorExit (raised at a yield
            # when the generator is closed before being exhausted) and
            # KeyboardInterrupt are not Exception subclasses and must
            # propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def distribute_arguments_without_duplicates(
            self, seen: Set[str], name: str
    ) -> Iterator[str]:
        """Same as `distribute_arguments`, but don't repeat seen results.

        Each result that is yielded is also added to `seen`.
        """
        for result in self.distribute_arguments(name):
            if result not in seen:
                seen.add(result)
                yield result

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.

        An expression that would be produced more than once is only
        generated the first time.
        """
        seen = set() #type: Set[str]
        # `names` is consumed immediately; the expressions themselves are
        # produced lazily as the returned iterator is consumed.
        generators = [
            self.distribute_arguments_without_duplicates(seen, name)
            for name in names
        ]
        return itertools.chain.from_iterable(generators)
219
220
class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        # Constructor macro name -> name of the corresponding predicate macro.
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        self.algorithms_from_hash = {} #type: Dict[str, str]

    @staticmethod
    def algorithm_tester(name: str) -> str:
        """The predicate for whether an algorithm is built from the given constructor.

        The given name must be the name of an algorithm constructor of the
        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
        an algorithm value. Return the corresponding predicate macro which
        is used as ``predicate(alg)`` to test whether ``alg`` can be built
        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
        ``PSA_ALG_IS_xxx``.
        """
        prefix = 'PSA_ALG_'
        assert name.startswith(prefix)
        midfix = 'IS_'
        suffix = name[len(prefix):]
        # Special cases where the predicate name is not simply
        # PSA_ALG_IS_xxx.
        if suffix in ['DSA', 'ECDSA']:
            midfix += 'RANDOMIZED_'
        elif suffix == 'RSA_PSS':
            suffix += '_STANDARD_SALT'
        return prefix + midfix + suffix

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        # NOTE(review): re.match anchors at the start of `name`, and
        # read_line only passes names that start with 'PSA_ALG_', so the
        # two name-based tests below cannot succeed for those calls.
        # Classification is effectively based on the expansion alone.
        # Confirm the intent before changing this.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines if any.

        This function analyzes lines that contain a "#define" directive
        matching `_define_directive_re` and skips all other lines. Macros
        whose name does not fall in one of the recognized ``PSA_xxx``
        families are not added to any set; however, the parameter name of
        every single-parameter macro is recorded in `self.argspecs`.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        # Remove comments from the expansion.
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        key_type_prefix = 'PSA_KEY_TYPE_'
        if (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith(key_type_prefix) and not parameter:
            self.key_types.add(name)
        elif name.startswith(key_type_prefix) and parameter == 'curve':
            # PSA_KEY_TYPE_xxx(curve) -> PSA_KEY_TYPE_IS_xxx
            self.key_types_from_curve[name] = \
                key_type_prefix + 'IS_' + name[len(key_type_prefix):]
        elif name.startswith(key_type_prefix) and parameter == 'group':
            # PSA_KEY_TYPE_xxx(group) -> PSA_KEY_TYPE_IS_xxx
            self.key_types_from_group[name] = \
                key_type_prefix + 'IS_' + name[len(key_type_prefix):]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: Iterable[bytes]) -> None:
        """Parse the lines of a C header and record the macros it defines.

        `header_file` is an iterable of lines as byte strings, for example
        a file opened in binary mode. Lines ending with a backslash-newline
        are joined with the following line before being parsed. Non-ASCII
        bytes are removed before each logical line is passed to `read_line`.
        """
        # Use a single iterator for both the loop and the continuation
        # lookahead, so that this also works when header_file is an
        # iterable that is not its own iterator (e.g. a list of lines).
        lines = iter(header_file)
        for line in lines:
            m = re.search(self._continued_line_re, line)
            while m:
                # If the input ends in the middle of a continued line,
                # treat the end of input as the end of the logical line
                # rather than leaking StopIteration to the caller.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)
353
354
class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        # Every PSA_xxx macro name seen by parse_header_line, including
        # the ones that are excluded from the per-type sets.
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'pake_algorithm': [self.pake_algorithms],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        # MAC lengths and AEAD tag lengths are tested with the same values.
        interesting_lengths = (
            1,  # minimum expressible
            4,  # minimum allowed by policy
            13, # an odd size in a plausible range
            14, # an even non-power-of-two size in a plausible range
            16, # same as full size for at least one algorithm
            63, # maximum expressible
        )
        for length_argument in ('mac_length', 'min_mac_length',
                                'tag_length', 'min_tag_length'):
            self.arguments_for[length_argument].extend(
                str(n) for n in interesting_lengths
            )

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        for names, unsupported_value in (
                (self.algorithms, '0xffffffff'),
                (self.ecc_curves, '0xff'),
                (self.dh_groups, '0xff'),
                (self.key_types, '0xffff'),
                (self.key_usage_flags, '0x80000000'),
        ):
            names.add(unsupported_value)

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        for names, unknown_algorithm in (
                # 0x020000ff is PSA_ALG_ANY_HASH
                (self.hash_algorithms, '0x020000fe'),
                (self.mac_algorithms, '0x03007fff'),
                (self.ka_algorithms, '0x09fc0000'),
                (self.kdf_algorithms, '0x080000ff'),
                (self.pake_algorithms, '0x0a0000ff'),
        ):
            names.add(unknown_algorithm)
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type.

        Raise `KeyError` if `type_word` is not one of the supported types.
        """
        names_by_type = {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }
        return names_by_type[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    }
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching macro name is recorded in `self.all_declared`.
        Names that are not excluded are additionally added to the set
        selected by their type prefix, if that prefix is known, and their
        argument names, if any, are recorded in `self.argspecs`.
        """
        found = self._header_line_re.match(line)
        if found is None:
            return
        name, type_prefix, argument_list = found.groups()
        self.all_declared.add(name)
        excluded = (
            self._excluded_name_re.search(name) or
            name in self._excluded_names or
            self.is_internal_name(name)
        )
        if excluded:
            return
        dest = self.table_by_prefix.get(type_prefix)
        if dest is None:
            return
        dest.add(name)
        if argument_list:
            self.argspecs[name] = self._argument_split(argument_list)

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx".

        The file is read in binary mode and non-ASCII bytes are removed
        from each line before it is parsed.
        """
        with read_file_lines(filename, binary=True) as lines:
            for raw_line in lines:
                ascii_line = self._nonascii_re.sub(rb'', raw_line)
                self.parse_header_line(ascii_line.decode('ascii'))

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        """Yield the macro-like identifiers in `expr` that are not in `self.all_declared`."""
        for name in self._macro_identifier_re.findall(expr):
            if name in self.all_declared:
                continue
            yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        """Check that a test case argument only uses declared names.

        Return true if the argument is acceptable. Raise an exception if
        the argument contains names that are not in `self.all_declared`.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        without_spaces = re.sub(r' +', r'', argument)
        return re.sub(r',', r', ', without_spaces)

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests."""
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            is_chained_key_agreement = (
                function == 'key_agreement_algorithm' and
                argument.startswith('PSA_ALG_KEY_AGREEMENT(')
            )
            if is_chained_key_agreement:
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets.extend(self.table_by_test_function[function])
        if not self.accept_test_case_line(function, argument):
            return
        for destination in sets:
            destination.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                found = self._test_case_line_re.match(line)
                if found is None:
                    continue
                function, argument = found.groups()
                self.add_test_case_line(function, argument)
551