1#!/usr/bin/env python3
2"""Test the program psa_constant_names.
3Gather constant names from header files and test cases. Compile a C program
4to print out their numerical values, feed these numerical values to
5psa_constant_names, and check that the output is the original name.
6Return 0 if all test cases pass, 1 if the output was not always as expected,
7or 1 (with a Python backtrace) if there was an operational error.
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0
12#
13# Licensed under the Apache License, Version 2.0 (the "License"); you may
14# not use this file except in compliance with the License.
15# You may obtain a copy of the License at
16#
17# http://www.apache.org/licenses/LICENSE-2.0
18#
19# Unless required by applicable law or agreed to in writing, software
20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22# See the License for the specific language governing permissions and
23# limitations under the License.
24
25import argparse
26from collections import namedtuple
27import itertools
28import os
29import platform
30import re
31import subprocess
32import sys
33import tempfile
34
class ReadFileLineException(Exception):
    """Error raised while processing one line of a file.

    The exception message has the form ``in FILENAME at LINE_NUMBER``.
    The two values are also stored on the instance as the ``filename``
    and ``line_number`` attributes.
    """

    def __init__(self, filename, line_number):
        self.filename = filename
        self.line_number = line_number
        super().__init__('in {} at {}'.format(filename, line_number))
41
class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    The line number is the zero-based index of the line being processed.
    It is the string 'entry' before the first line has been read and
    'exit' once the file has been read to the end.

    The underlying file is closed when the ``with`` block is left, whether
    or not an exception occurred.
    """
    def __init__(self, filename, binary=False):
        """Remember which file to read; the file is opened by __enter__.

        If binary is true, the file is opened in binary mode and the lines
        are bytes objects; otherwise they are strings.
        """
        self.filename = filename
        self.line_number = 'entry'
        # Kept so that __exit__ can close the file.
        self.file = None
        self.generator = None
        self.binary = binary
    def __enter__(self):
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        self.generator = enumerate(self.file)
        return self
    def __iter__(self):
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Release the file handle before reporting any error.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
79
class Inputs:
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self):
        # Every PSA_xxx macro name seen in a header, including excluded ones.
        self.all_declared = set()
        # Sets of names per type. Each set starts with one hard-coded entry.
        self.statuses = {'PSA_SUCCESS'}
        self.algorithms = {'0xffffffff'}
        self.ecc_curves = {'0xff'}
        self.dh_groups = {'0xff'}
        self.key_types = {'0xffff'}
        self.key_usage_flags = {'0x80000000'}
        # Hard-coded value for unknown algorithms
        self.hash_algorithms = {'0x020000fe'}
        self.mac_algorithms = {'0x0300ffff'}
        self.ka_algorithms = {'0x09fc0000'}
        self.kdf_algorithms = {'0x080000ff'}
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.
        self.aead_algorithms = set()
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        }
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'other_algorithm': [],
        }
        # macro name -> list of argument names
        self.argspecs = {}
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': ['1', '63'],
            'tag_length': ['1', '63'],
        }

    def get_names(self, type_word):
        """Return the set of known names of values of the given type.

        Raise KeyError if type_word is not one of the known type words.
        """
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    def gather_arguments(self):
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name, arguments):
        """Format a macro call with arguments."""
        return name + '(' + ', '.join(arguments) + ')'

    def distribute_arguments(self, name):
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        Any error while building the calls is re-raised as an Exception
        whose message names the macro, chained to the original error.
        """
        # Catch Exception rather than BaseException: closing the generator
        # early raises GeneratorExit at a yield, and that (like
        # KeyboardInterrupt) must propagate unchanged.
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if not argspec:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            # Start with the first value of every argument.
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            # Vary one argument at a time, keeping the others at their
            # first value.
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this argument's own first value before moving on
                # to the next argument.
                arguments[i] = values[0]
        except Exception as e:
            raise Exception('distribute_arguments({})'.format(name)) from e

    def generate_expressions(self, names):
        """Return an iterator over the macro calls to test for all names.

        This chains the results of distribute_arguments() for each name.
        """
        return itertools.chain.from_iterable(
            map(self.distribute_arguments, names))

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments):
        """Split a comma-separated argument list into a list of names."""
        return cls._argument_split_re.split(arguments)

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
    }
    def parse_header_line(self, line):
        """Parse a C header line, looking for "#define PSA_xxx".

        Every matching name is recorded in all_declared. Names that are
        not excluded and whose prefix is in table_by_prefix are also added
        to the set for their type, together with their argument names if
        the macro has a non-empty parameter list.
        """
        m = self._header_line_re.match(line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if self._excluded_name_re.search(name) or \
           name in self._excluded_names:
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    def parse_header(self, filename):
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                # Drop non-ASCII bytes so that the line can be decoded
                # as ASCII.
                line = self._nonascii_re.sub(rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr):
        """Yield the identifiers in expr that were not seen in any header.

        An identifier here is an uppercase letter followed by one or more
        word characters.
        """
        for name in self._macro_identifier_re.findall(expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function, argument):
        """Check that a test case argument only uses declared names.

        Return True if so; raise an Exception listing the undeclared
        names otherwise. The function argument is not used here.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    def add_test_case_line(self, function, argument):
        """Parse a test case data line, looking for algorithm metadata tests.

        Add argument to the sets associated with the test function.
        Raise KeyError if function is not in table_by_test_function.
        """
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(argument)

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename):
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = self._test_case_line_re.match(line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))
296
def gather_inputs(headers, test_suites, inputs_class=Inputs):
    """Build the collection of inputs to test psa_constant_names with.

    Create an instance of inputs_class, feed it every file in headers and
    then every file in test_suites, finalize the argument lists and return
    the instance.
    """
    collected = inputs_class()
    for header_path in headers:
        collected.parse_header(header_path)
    for data_path in test_suites:
        collected.parse_test_cases(data_path)
    collected.gather_arguments()
    return collected
306
def remove_file_if_exists(filename):
    """Delete filename on a best-effort basis.

    Do nothing if filename is empty or None. Any OSError raised by the
    removal (for example because the file does not exist) is ignored.
    """
    if filename:
        try:
            os.remove(filename)
        except OSError:
            # Best effort only: a file that can't be removed is not an error.
            pass
315
def run_c(type_word, expressions, include_path=None, keep_c=False):
    """Generate and run a program to print out numerical values for expressions.

    Write a temporary C file in programs/psa that prints each expression
    on its own line, compile it with the compiler named by the CC
    environment variable (default: cc), run it and return its output as a
    list of strings, one per output line.

    Values are printed as signed decimal numbers if type_word is 'status'
    and as zero-padded hexadecimal numbers otherwise.

    include_path is a list of directories passed to the compiler with -I.
    If keep_c is true, keep the generated C file and report its location on
    stderr; otherwise remove it once it has been compiled. The C file is
    left in place if compilation fails. The executable is always removed.

    Raise subprocess.CalledProcessError if the compiler or the generated
    program exits with a nonzero status.
    """
    if include_path is None:
        include_path = []
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    exe_name = None
    try:
        c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type_word),
                                        suffix='.c',
                                        dir='programs/psa')
        exe_suffix = '.exe' if platform.system() == 'Windows' else ''
        # Same base name as the C file, with ".c" stripped.
        exe_name = c_name[:-2] + exe_suffix
        remove_file_if_exists(exe_name)
        # The with statement closes the file even if a write fails.
        with os.fdopen(c_fd, 'w', encoding='ascii') as c_file:
            c_file.write('/* Generated by test_psa_constant_names.py for {} values */'
                         .format(type_word))
            c_file.write('''
#include <stdio.h>
#include <psa/crypto.h>
int main(void)
{
''')
            for expr in expressions:
                c_file.write('    printf("{}\\n", ({}) {});\n'
                             .format(printf_format, cast_to, expr))
            c_file.write('''    return 0;
}
''')
        cc = os.getenv('CC', 'cc')
        subprocess.check_call([cc] +
                              ['-I' + directory for directory in include_path] +
                              ['-o', exe_name, c_name])
        if keep_c:
            sys.stderr.write('List of {} tests kept at {}\n'
                             .format(type_word, c_name))
        else:
            os.remove(c_name)
        output = subprocess.check_output([exe_name])
        return output.decode('ascii').strip().split('\n')
    finally:
        remove_file_if_exists(exe_name)
364
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr):
    """Return expr with all whitespace removed.

    Two C expressions that only differ by whitespace normalize to the
    same string, so they can be compared with ==.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)
372
def collect_values(inputs, type_word, include_path=None, keep_c=False):
    """Generate expressions using known macro names and calculate their values.

    Return a pair (expressions, values) of lists: expressions is the sorted
    list of expressions generated for type_word, and values is the output
    of run_c() for these expressions, i.e. one string representation of an
    integer value per line printed by the generated program.

    include_path and keep_c are passed on to run_c().
    """
    names = inputs.get_names(type_word)
    expressions = sorted(inputs.generate_expressions(names))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values
384
class Tests:
    """An object representing tests and their results."""

    # One mismatch between an expression and the program's output for
    # the value of that expression.
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options):
        """Start with no tests run.

        options is the object holding the command line options; run_one()
        reads its include, keep_c, program and show attributes.
        """
        self.options = options
        # Number of expressions tested so far.
        self.count = 0
        # List of Error tuples, one per expression with unexpected output.
        self.errors = []

    def run_one(self, inputs, type_word):
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        An expression for which the program printed no output line is
        recorded as an error with an empty output.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        raw_output = subprocess.check_output([self.options.program, type_word] +
                                             values)
        outputs = raw_output.decode('ascii').strip().split('\n')
        # zip() stops at the shortest sequence. Pad the outputs so that an
        # expression without a corresponding output line is reported as a
        # failure rather than silently skipped.
        missing = len(expressions) - len(outputs)
        if missing > 0:
            outputs.extend([''] * missing)
        self.count += len(expressions)
        for expr, value, output in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
            if normalize(expr) != normalize(output):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=output))

    def run_all(self, inputs):
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out):
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')
440
# Header files to scan for macro definitions. main() looks them up
# relative to the first include directory.
HEADERS = [
    'psa/crypto.h',
    'psa/crypto_extra.h',
    'psa/crypto_values.h',
]
# Test case data files to scan for additional expressions to test.
TEST_SUITES = [
    'tests/suites/test_suite_psa_crypto_metadata.data',
]
443
def main():
    """Parse the command line, run all the tests and report the results.

    The report is written to standard output. Exit with status 1 if at
    least one expression was not printed back as expected.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    # NOTE(review): headers are only looked up in the first include
    # directory. With action='append' and a list default, that first entry
    # looks like it is always the default 'include' -- confirm this is
    # intended.
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)
472
# Run the tests only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
475