1#!/usr/bin/env python3
2"""Test the program psa_constant_names.
3Gather constant names from header files and test cases. Compile a C program
4to print out their numerical values, feed these numerical values to
5psa_constant_names, and check that the output is the original name.
6Return 0 if all test cases pass, 1 if the output was not always as expected,
7or 1 (with a Python backtrace) if there was an operational error.
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0
12#
13# Licensed under the Apache License, Version 2.0 (the "License"); you may
14# not use this file except in compliance with the License.
15# You may obtain a copy of the License at
16#
17# http://www.apache.org/licenses/LICENSE-2.0
18#
19# Unless required by applicable law or agreed to in writing, software
20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22# See the License for the specific language governing permissions and
23# limitations under the License.
24
25import argparse
26from collections import namedtuple
27import os
28import re
29import subprocess
30import sys
31from typing import Iterable, List, Optional, Tuple
32
33import scripts_path # pylint: disable=unused-import
34from mbedtls_dev import c_build_helper
35from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
36from mbedtls_dev import typing_util
37
def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Build the collection of inputs that psa_constant_names is tested with.

    A fresh ``inputs_class`` instance is fed every file in ``headers`` (via
    ``parse_header``) and then every file in ``test_suites`` (via
    ``parse_test_cases``). Once all files have been read, numerical values
    are added and macro arguments are gathered.

    Return the populated instance.
    """
    collector = inputs_class()
    for header_path in headers:
        collector.parse_header(header_path)
    for data_path in test_suites:
        collector.parse_test_cases(data_path)
    # These two steps run only after every input file has been parsed.
    collector.add_numerical_values()
    collector.gather_arguments()
    return collector
50
def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[str] = None,
          keep_c: bool = False) -> List[str]:
    """Build and run a C program that prints the values of C expressions.

    ``type_word`` selects how each value is printed: ``'status'`` values are
    cast to ``long`` and printed in decimal, every other type is cast to
    ``unsigned long`` and printed as ``0x`` followed by at least 8 hex digits.

    ``include_path`` and ``keep_c`` are handed unchanged to
    ``c_build_helper.get_c_expression_values``, whose result is returned.

    NOTE(review): main() passes the list collected by ``--include`` as
    ``include_path`` although the annotation says ``Optional[str]`` --
    confirm which one c_build_helper expects.
    """
    is_status = (type_word == 'status')
    # Status codes need a signed representation; other types are shown in hex.
    cast_to, printf_format = (('long', '%ld') if is_status
                              else ('unsigned long', '0x%08lx'))
    caller = 'test_psa_constant_names.py for {} values'.format(type_word)
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller=caller,
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )
71
# Matches any run of whitespace; normalize() deletes every such run.
NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Return ``expr`` with trivial differences removed.

    The only difference currently treated as trivial is whitespace, so the
    result is ``expr`` with all whitespace characters deleted.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)
79
# An AEAD algorithm (CCM, ChaCha20-Poly1305 or GCM) wrapped in
# PSA_ALG_AEAD_WITH_SHORTENED_TAG with a tag length of 16.
ALG_TRUNCATED_TO_SELF_RE = re.compile(
    r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
    r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
    r', *16\)\Z'
)

def is_simplifiable(expr: str) -> bool:
    """Tell whether ``expr`` has a simpler equivalent form.

    The program under test prints the simple form of such an expression,
    so its output can never equal the input. Callers must therefore leave
    simplifiable expressions out of the test.
    """
    # match() anchors at the start and \Z at the end: the whole expression
    # has to be the shortened-tag construct.
    return ALG_TRUNCATED_TO_SELF_RE.match(expr) is not None
95
def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[str] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions from known macro names and compute their values.

    The expressions are built from the names that ``inputs`` knows for
    ``type_word``; simplifiable expressions are dropped and the rest are
    sorted. Their values are obtained by compiling and running a C program
    (see ``run_c``), to which ``include_path`` and ``keep_c`` are forwarded.

    Return a pair ``(expressions, values)`` of two lists: the expressions,
    and the strings that ``run_c`` returned for them.
    """
    names = inputs.get_names(type_word)
    expressions = [candidate
                   for candidate in inputs.generate_expressions(names)
                   if not is_simplifiable(candidate)]
    expressions.sort()
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values
112
class Tests:
    """Run psa_constant_names on generated inputs and record the mismatches.

    Attributes:
        options: object holding the run settings. The attributes read by
            this class are ``include``, ``keep_c``, ``program`` and ``show``.
        count: number of expressions tested so far.
        errors: one ``Error`` record per expression for which the program's
            output differed from the expression.
    """

    # Record of one failed comparison: the type word, the expression that was
    # expected, the numerical value passed to the program, and what it printed.
    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        self.options = options
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the values of the expressions generated for this
        type, and compare each output line with the expression it came from
        (ignoring whitespace). ``inputs`` determines which names exist and
        what arguments to pass to macros that take arguments.

        Every expression is added to ``self.count``; every mismatch is
        appended to ``self.errors``. An expression for which the program
        printed no line at all is recorded as a mismatch with empty output.

        Raises subprocess.CalledProcessError if the program exits with a
        nonzero status, and UnicodeDecodeError if its output is not ASCII.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        outputs = output_bytes.decode('ascii').strip().split('\n')
        self.count += len(expressions)
        # zip() stops at its shortest argument. Without padding, a program
        # that prints too few lines (or whose last lines are empty and get
        # removed by strip()) would leave the trailing expressions unchecked
        # while they are still counted as tested.
        missing = len(expressions) - len(outputs)
        if missing > 0:
            outputs.extend([''] * missing)
        for expr, value, actual in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, actual))
            if normalize(expr) != normalize(actual):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=actual))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs.

        Each supported type word is tested in turn with ``run_one``.
        """
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write one line per recorded error to ``out``, followed by a summary
        line giving the number of test cases and either the number of
        failures or ``PASS``.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')
169
# Header files to parse; main() looks them up under the first include
# directory.
HEADERS = [
    'psa/crypto.h',
    'psa/crypto_extra.h',
    'psa/crypto_values.h',
]
# Test data files to parse; main() passes these paths on unchanged.
TEST_SUITES = [
    'tests/suites/test_suite_psa_crypto_metadata.data',
]
172
def main():
    """Parse the command line, run every test and print the report.

    The report is written to standard output. The process exits with
    status 1 if at least one mismatch was recorded.

    Note that ``--include`` uses ``action='append'`` with a non-empty
    default, so directories given on the command line are added after
    ``'include'``, and the headers are always read from the first entry.
    """
    arg_parser = argparse.ArgumentParser(description=__doc__)
    arg_parser.add_argument('--include', '-I',
                            action='append', default=['include'],
                            help='Directory for header files')
    arg_parser.add_argument('--keep-c',
                            action='store_true', dest='keep_c', default=False,
                            help='Keep the intermediate C file')
    arg_parser.add_argument('--no-keep-c',
                            action='store_false', dest='keep_c',
                            help="Don't keep the intermediate C file (default)")
    arg_parser.add_argument('--program',
                            default='programs/psa/psa_constant_names',
                            help='Program to test')
    arg_parser.add_argument('--show',
                            action='store_true',
                            help='Show tested values on stdout')
    arg_parser.add_argument('--no-show',
                            action='store_false', dest='show',
                            help="Don't show tested values (default)")
    options = arg_parser.parse_args()

    first_include_dir = options.include[0]
    header_paths = [os.path.join(first_include_dir, name) for name in HEADERS]
    runner = Tests(options)
    runner.run_all(gather_inputs(header_paths, TEST_SUITES))
    runner.report(sys.stdout)
    if runner.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()
204