#!/usr/bin/env python3
"""Test the program psa_constant_names.
Gather constant names from header files and test cases. Compile a C program
to print out their numerical values, feed these numerical values to
psa_constant_names, and check that the output is the original name.
Return 0 if all test cases pass, 1 if the output was not always as expected,
or 1 (with a Python backtrace) if there was an operational error.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

import argparse
from collections import namedtuple
import os
import re
import subprocess
import sys
from typing import Iterable, List, Optional, Tuple

import scripts_path # pylint: disable=unused-import
from mbedtls_dev import c_build_helper
from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
from mbedtls_dev import typing_util

def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Read the list of inputs to test psa_constant_names with.

    headers: paths of C header files to scan for constant names.
    test_suites: paths of test suite data files to scan for test cases.
    inputs_class: class instantiated with no arguments to collect the
        inputs; it must provide parse_header, parse_test_cases,
        add_numerical_values and gather_arguments.

    Return the populated ``inputs_class`` instance.
    """
    inputs = inputs_class()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    # Post-processing passes run only once every source has been parsed.
    inputs.add_numerical_values()
    inputs.gather_arguments()
    return inputs

def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[List[str]] = None,
          keep_c: bool = False) -> List[str]:
    """Generate and run a program to print out numerical values of C expressions.

    type_word: the kind of constant being evaluated. ``'status'`` values are
        cast to ``long`` and printed in signed decimal; every other kind is
        cast to ``unsigned long`` and printed as ``0x``-prefixed hexadecimal
        padded to at least 8 digits.
    expressions: the C expressions to evaluate.
    include_path: directories to search for header files. Within this
        script it is the list collected by the ``--include`` option.
    keep_c: if true, keep the intermediate C file.

    Return the printed values as strings, as produced by
    ``c_build_helper.get_c_expression_values``. The caller assumes there is
    one value per expression, in the same order.
    """
    if type_word == 'status':
        cast_to, printf_format = 'long', '%ld'
    else:
        cast_to, printf_format = 'unsigned long', '0x%08lx'
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller='test_psa_constant_names.py for {} values'.format(type_word),
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )

NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace: every run of
    whitespace is removed.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)

# Matches PSA_ALG_AEAD_WITH_SHORTENED_TAG(<CCM|CHACHA20_POLY1305|GCM>, 16).
# The name says such a truncation is a no-op ("truncated to self"),
# presumably because 16 is the default tag length of these AEADs.
ALG_TRUNCATED_TO_SELF_RE = \
    re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
               r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
               r', *16\)\Z')

def is_simplifiable(expr: str) -> bool:
    """Determine whether an expression is simplifiable.

    Simplifiable expressions can't be output in their input form, since
    the output will be the simple form. Therefore they must be excluded
    from testing.
    """
    return ALG_TRUNCATED_TO_SELF_RE.match(expr) is not None

def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[List[str]] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Return a pair ``(expressions, values)`` of lists. ``expressions`` is the
    sorted list of non-simplifiable expressions for ``type_word``.
    ``values`` holds the string representations of their integer values,
    as returned by ``run_c``; the caller expects one value per expression,
    in the same order.
    """
    names = inputs.get_names(type_word)
    expressions = sorted(expr
                         for expr in inputs.generate_expressions(names)
                         if not is_simplifiable(expr))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values

class Tests:
    """An object representing tests and their results."""

    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        # options: the argparse namespace built by main() (program, include,
        # keep_c, show).
        self.options = options
        # Total number of expressions tested so far.
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Each mismatch between an expression and the program's output (after
        ``normalize``) is appended to ``self.errors``.

        Raises subprocess.CalledProcessError if the program exits with a
        non-zero status, and RuntimeError if it does not print exactly one
        line per value (the results could not be matched up otherwise).
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        # splitlines() yields [] (not ['']) for empty output.
        outputs = output_bytes.decode('ascii').strip().splitlines()
        if len(outputs) != len(values):
            # zip() below would silently drop the unmatched entries and hide
            # failures, so refuse to guess how the lines line up.
            raise RuntimeError('{}: expected {} lines of output from {}, got {}'
                               .format(type_word, len(values),
                                       self.options.program, len(outputs)))
        self.count += len(expressions)
        for expr, value, text in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, text))
            if normalize(expr) != normalize(text):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=text))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total, followed by PASS or the number of failures.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')

# Paths relative to the current directory (presumably the top of the source
# tree -- TODO confirm how this script is invoked).
HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']

def main():
    """Parse the command line, run all tests, report, and exit 1 on failure."""
    parser = argparse.ArgumentParser(description=__doc__)
    # With action='append' and a non-empty default, argparse keeps the
    # default entries and appends -I values after them, so 'include' is
    # always options.include[0].
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()