1"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
import itertools
import re
from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union
22
23
class ReadFileLineException(Exception):
    """Exception that records a position in a file.

    The exception message has the form ``in FILENAME at LINE_NUMBER``.
    The two components are also available as the `filename` and
    `line_number` attributes.
    """
    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        self.filename = filename
        self.line_number = line_number
        super().__init__('in {} at {}'.format(filename, line_number))
30
31
class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    The file is opened in binary mode if `binary` is true, in which case
    the lines are bytes objects. The file is closed when the context exits.

    The recorded line number is zero-based (the first line of the file is
    line 0). Before the first line is read it is the string 'entry', and
    after the last line has been consumed it is the string 'exit'.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        # The open file, set by __enter__ and closed by __exit__.
        self.file = None #type: Optional[IO]
        self.line_number = 'entry' #type: Union[int, str]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary
    def __enter__(self) -> 'read_file_lines':
        # Keep a reference to the file object so that __exit__ can close it.
        # pylint: disable=consider-using-with
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        self.generator = enumerate(self.file)
        return self
    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            # Remember where we are in case the caller's processing of
            # this line raises an exception.
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Release the file handle whether or not an exception occurred.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
70
71
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """
    #pylint: disable=too-many-instance-attributes

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.lifetimes = set() #type: Set[str]
        self.locations = set() #type: Set[str]
        self.persistence_levels = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        self.sign_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # Whether to include intermediate macros in enumerations. Intermediate
        # macros serve as category headers and are not valid values of their
        # type. See `is_internal_name`.
        # Always false in this class, may be set to true in derived classes.
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending with ``_FLAG`` or ``_MASK`` are always internal.
        Unless `self.include_intermediate` is true, names ending with
        ``_BASE`` or ``_NONE`` and names containing ``_CATEGORY_`` are
        internal as well.
        """
        if not self.include_intermediate:
            if name.endswith(('_BASE', '_NONE')) or '_CATEGORY_' in name:
                return True
        return name.endswith(('_FLAG', '_MASK'))

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)
        self.arguments_for['persistence'] = sorted(self.persistence_levels)
        self.arguments_for['location'] = sorted(self.locations)
        self.arguments_for['lifetime'] = sorted(self.lifetimes)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, dropping spaces around commas."""
        return cls._argument_split_re.split(arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        Errors raised while building the calls (for example an argument
        name with no known values) are re-raised as an `Exception` whose
        message identifies the macro, chained to the original error.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if not argspec:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            # Start with the first value of every argument, then vary one
            # argument at a time while the others keep their first value.
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this argument's own first value before moving on
                # to the next argument.
                arguments[i] = values[0]
        except Exception as e:
            # Only catch Exception, not BaseException: GeneratorExit (raised
            # at a yield when the generator is closed early) and
            # KeyboardInterrupt must propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def distribute_arguments_without_duplicates(
            self, seen: Set[str], name: str
    ) -> Iterator[str]:
        """Same as `distribute_arguments`, but don't repeat seen results.

        Each result that is yielded is also added to `seen`.
        """
        for result in self.distribute_arguments(name):
            if result not in seen:
                seen.add(result)
                yield result

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        seen = set() #type: Set[str]
        # `names` is deliberately consumed right away (argument unpacking),
        # so later changes to the collection do not affect the iteration.
        # The expressions themselves are still produced lazily.
        return itertools.chain(*(
            self.distribute_arguments_without_duplicates(seen, name)
            for name in names
        ))
215
216
class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        # Constructor macro name -> name of the corresponding predicate macro.
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        self.algorithms_from_hash = {} #type: Dict[str, str]

    @staticmethod
    def algorithm_tester(name: str) -> str:
        """The predicate for whether an algorithm is built from the given constructor.

        The given name must be the name of an algorithm constructor of the
        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
        an algorithm value. Return the corresponding predicate macro which
        is used as ``predicate(alg)`` to test whether ``alg`` can be built
        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
        ``PSA_ALG_IS_xxx``.
        """
        prefix = 'PSA_ALG_'
        assert name.startswith(prefix)
        midfix = 'IS_'
        suffix = name[len(prefix):]
        # Constructors whose predicate does not follow the usual naming.
        if suffix in ['DSA', 'ECDSA']:
            midfix += 'RANDOMIZED_'
        elif suffix == 'RSA_PSS':
            suffix += '_STANDARD_SALT'
        return prefix + midfix + suffix

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        # NOTE(review): re.match anchors at the start of `name`, and read_line
        # only passes names starting with 'PSA_ALG_', so the two name-based
        # branches below cannot match for those names. Left as is because
        # changing it would alter classification -- confirm the intent.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines if any.

        This function analyzes lines that contain a ``#define`` directive
        (up to non-significant whitespace) and skips all non-matching lines.
        The macro is then classified according to the prefix of its name
        (``PSA_ERROR_``, ``PSA_KEY_TYPE_``, ``PSA_ALG_``, ...) and to its
        parameter name, if any. Macros that fit none of the known
        categories are ignored.

        For any macro with a single parameter, the parameter name is
        recorded in `self.argspecs`, even if the macro is otherwise ignored.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        # Replace comments in the expansion by a space.
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # 13 == len('PSA_KEY_TYPE_'): map PSA_KEY_TYPE_xxx to
            # PSA_KEY_TYPE_IS_xxx.
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: Iterable[bytes]) -> None:
        """Parse a C header and record the PSA identifiers that it defines.

        `header_file` is an iterable of lines as bytes objects, for example
        a file opened in binary mode. Lines ending with a backslash are
        joined with the following line. Non-ASCII bytes are removed, then
        each resulting logical line is passed to `read_line`.
        """
        # Use an explicit iterator: continuation lines are consumed with
        # next() inside the loop, which requires an iterator rather than
        # an arbitrary iterable.
        lines = iter(header_file)
        for line in lines:
            m = self._continued_line_re.search(line)
            while m:
                # Drop the backslash-newline and append the next line.
                line = line[:m.start(0)]
                cont = next(lines, None)
                if cont is None:
                    # Backslash at the end of the last line: nothing left
                    # to join, so process what we have.
                    break
                line += cont
                m = self._continued_line_re.search(line)
            line = self._nonascii_re.sub(b'', line).decode('ascii')
            self.read_line(line)
349
350
class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        # Every PSA_xxx macro name seen in a header, including excluded ones.
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        # All the length arguments are tested with the same two values.
        for length_argument in ('mac_length', 'min_mac_length',
                                'tag_length', 'min_tag_length'):
            self.arguments_for[length_argument].extend(['1', '63'])

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        for names, unsupported_value in [
                (self.algorithms, '0xffffffff'),
                (self.ecc_curves, '0xff'),
                (self.dh_groups, '0xff'),
                (self.key_types, '0xffff'),
                (self.key_usage_flags, '0x80000000'),
        ]:
            names.add(unsupported_value)

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        for names, unknown_algorithm in [
                # 0x020000ff is PSA_ALG_ANY_HASH
                (self.hash_algorithms, '0x020000fe'),
                (self.mac_algorithms, '0x03007fff'),
                (self.ka_algorithms, '0x09fc0000'),
                (self.kdf_algorithms, '0x080000ff'),
        ]:
            names.add(unknown_algorithm)
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type.

        Raise KeyError if `type_word` is not one of the known type words.
        """
        names_by_type = {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }
        return names_by_type[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    }
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching macro name is recorded in `self.all_declared`.
        Unless the name is excluded, it is also added to the set selected
        by its type prefix, and its argument names, if any, are recorded
        in `self.argspecs`.
        """
        match = self._header_line_re.match(line)
        if match is None:
            return
        name, type_prefix, parameters = match.groups()
        self.all_declared.add(name)
        is_excluded = (self._excluded_name_re.search(name) or
                       name in self._excluded_names or
                       self.is_internal_name(name))
        if is_excluded:
            return
        destination = self.table_by_prefix.get(type_prefix)
        if destination is None:
            # Not a type of macro that we keep track of.
            return
        destination.add(name)
        if parameters:
            self.argspecs[name] = self._argument_split(parameters)

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for raw_line in lines:
                # Drop non-ASCII bytes so that the line decodes as ASCII.
                text = self._nonascii_re.sub(b'', raw_line).decode('ascii')
                self.parse_header_line(text)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        """Yield the identifiers in `expr` that are not in `self.all_declared`.

        An identifier here is a word of two or more characters that starts
        with an uppercase letter.
        """
        for identifier in self._macro_identifier_re.findall(expr):
            if identifier in self.all_declared:
                continue
            yield identifier

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        """Check that a test case argument only uses declared names.

        Return True if it does, otherwise raise an exception listing the
        undeclared names.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if not undeclared:
            return True
        raise Exception('Undeclared names in test case', undeclared)

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        # Remove all spaces, then put exactly one space after each comma.
        return argument.replace(' ', '').replace(',', ', ')

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests."""
        destinations = []
        if function.endswith('_algorithm'):
            destinations.append(self.algorithms)
            chained_key_agreement = (
                function == 'key_agreement_algorithm' and
                argument.startswith('PSA_ALG_KEY_AGREEMENT(')
            )
            if chained_key_agreement:
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        destinations.extend(self.table_by_test_function[function])
        if not self.accept_test_case_line(function, argument):
            return
        for destination in destinations:
            destination.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                match = self._test_case_line_re.match(line)
                if match is None:
                    continue
                function, argument = match.groups()
                self.add_test_case_line(function, argument)
529