1#!/usr/bin/env python3
2"""Test the program psa_constant_names.
3Gather constant names from header files and test cases. Compile a C program
4to print out their numerical values, feed these numerical values to
5psa_constant_names, and check that the output is the original name.
6Return 0 if all test cases pass, 1 if the output was not always as expected,
7or 1 (with a Python backtrace) if there was an operational error.
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
12
13import argparse
14from collections import namedtuple
15import os
16import re
17import subprocess
18import sys
19from typing import Iterable, List, Optional, Tuple
20
21import scripts_path # pylint: disable=unused-import
22from mbedtls_dev import c_build_helper
23from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
24from mbedtls_dev import typing_util
25
def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Collect the inputs that psa_constant_names will be tested with.

    A fresh ``inputs_class`` instance is created. Every path in ``headers``
    is passed to its ``parse_header`` method, then every path in
    ``test_suites`` is passed to ``parse_test_cases``. Finally
    ``add_numerical_values`` and ``gather_arguments`` are called, and the
    instance is returned.
    """
    collector = inputs_class()
    for header_path in headers:
        collector.parse_header(header_path)
    for suite_path in test_suites:
        collector.parse_test_cases(suite_path)
    # Post-processing steps run only after every input file has been parsed.
    collector.add_numerical_values()
    collector.gather_arguments()
    return collector
38
def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[str] = None,
          keep_c: bool = False) -> List[str]:
    """Build and run a C program that prints the value of each expression.

    The program includes ``<psa/crypto.h>`` and is produced by
    ``c_build_helper.get_c_expression_values``, whose result is returned.
    Values of type ``status`` are cast to ``long`` and printed in signed
    decimal. Every other type is cast to ``unsigned long`` and printed as
    zero-padded hexadecimal.

    ``include_path`` and ``keep_c`` are forwarded unchanged to the helper.
    NOTE(review): ``Tests.run_one`` passes ``options.include``, which main()
    builds as a list, so the ``Optional[str]`` annotation looks too narrow.
    """
    is_status = type_word == 'status'
    cast_to = 'long' if is_status else 'unsigned long'
    printf_format = '%ld' if is_status else '0x%08lx'
    caller = 'test_psa_constant_names.py for {} values'.format(type_word)
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format, expressions,
        caller=caller,
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c,
    )
59
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace. Every run of
    whitespace is removed, so ``'F(A, B)'`` and ``'F(A,B)'`` normalize to
    the same string.
    """
    # Use the precompiled pattern's own sub() instead of passing it to re.sub().
    return NORMALIZE_STRIP_RE.sub('', expr)
67
ALG_TRUNCATED_TO_SELF_RE = \
    re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
               r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
               r', *16\)\Z')

def is_simplifiable(expr: str) -> bool:
    """Determine whether an expression is simplifiable.

    Simplifiable expressions can't be output in their input form, since
    the output will be the simple form. Therefore they must be excluded
    from testing.

    Currently this covers PSA_ALG_AEAD_WITH_SHORTENED_TAG applied to CCM,
    ChaCha20-Poly1305 or GCM with a tag length of 16. As the regex name
    suggests, such a "truncation" yields the algorithm itself. The whole
    expression must match; trailing text is not allowed.
    """
    return ALG_TRUNCATED_TO_SELF_RE.match(expr) is not None
83
def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[str] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Every expression that ``inputs`` generates from the names of type
    ``type_word`` is kept, except those rejected by ``is_simplifiable``.
    Their values are computed with ``run_c``.

    Return a pair ``(expressions, values)`` of two lists, not a list of
    pairs. ``expressions`` is the sorted list of C expressions, and
    ``values`` holds the string representations of their integer values as
    returned by ``run_c``. ``Tests.run_one`` pairs the two lists up
    position by position.
    """
    # NOTE(review): Tests.run_one passes options.include, which argparse
    # builds as a list (action='append'), so Optional[str] looks too narrow
    # for include_path. run_c declares the same type, so both annotations
    # would need to change together.
    names = inputs.get_names(type_word)
    # Simplifiable expressions are printed back in a different form, so they
    # cannot be checked by round-tripping and are left out.
    expressions = sorted(expr
                         for expr in inputs.generate_expressions(names)
                         if not is_simplifiable(expr))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values
100
class Tests:
    """An object representing tests and their results."""

    # One mismatch: the type word, the original expression, its numerical
    # value as passed to the program, and the line the program printed
    # for it (empty if the program printed no line for it).
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        # options: the argparse namespace built by main(). This class reads
        # its include, keep_c, program and show attributes.
        self.options = options
        self.count = 0 # number of expressions checked so far
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Each expression is compared, ignoring whitespace, with the line the
        program prints for its value. Mismatches are appended to
        ``self.errors``. An expression for which the program printed no
        line at all also counts as a mismatch, with an empty output.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        # Split into lines without stripping the whole output first:
        # strip() would discard empty trailing lines and thereby drop the
        # outputs for the last values.
        outputs = output_bytes.decode('ascii').splitlines()
        # zip() stops at its shortest argument, so an expression without a
        # matching output line would be counted but never checked. Pad with
        # empty outputs so that such expressions are reported as failures.
        missing = len(expressions) - len(outputs)
        if missing > 0:
            outputs.extend([''] * missing)
        self.count += len(expressions)
        for expr, value, output in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
            if normalize(expr) != normalize(output):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=output))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs.

        Call ``run_one`` once for each supported type word, in a fixed order.
        """
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write one line per recorded error to ``out``.
        Then write the total number of test cases, followed by
        ", N FAIL" if there were errors or " PASS" otherwise.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')
157
# Header files, relative to the first --include directory, that are parsed
# for constant names (see main()).
HEADERS = [
    'psa/crypto.h',
    'psa/crypto_extra.h',
    'psa/crypto_values.h',
]
# Test case data files that are also parsed for inputs (see main()).
TEST_SUITES = [
    'tests/suites/test_suite_psa_crypto_metadata.data',
]
160
def main():
    """Parse the command line, run all the tests and report the results.

    Exit with status 1 if any expected name was not reproduced.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument
    # With action='append' and a non-empty default, directories given with
    # -I are added after 'include' rather than replacing it.
    add('--include', '-I',
        action='append', default=['include'],
        help='Directory for header files')
    add('--keep-c',
        action='store_true', dest='keep_c', default=False,
        help='Keep the intermediate C file')
    add('--no-keep-c',
        action='store_false', dest='keep_c',
        help='Don\'t keep the intermediate C file (default)')
    add('--program',
        default='programs/psa/psa_constant_names',
        help='Program to test')
    add('--show',
        action='store_true',
        help='Show tested values on stdout')
    add('--no-show',
        action='store_false', dest='show',
        help='Don\'t show tested values (default)')
    options = parser.parse_args()

    # The headers to scan are looked up in the first include directory only.
    header_dir = options.include[0]
    header_paths = [os.path.join(header_dir, name) for name in HEADERS]
    gathered = gather_inputs(header_paths, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(gathered)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)
189
if __name__ == "__main__":
    # Run the checks only when executed as a script, not when imported.
    main()
192