1#!/usr/bin/env python3 2"""Test the program psa_constant_names. 3Gather constant names from header files and test cases. Compile a C program 4to print out their numerical values, feed these numerical values to 5psa_constant_names, and check that the output is the original name. 6Return 0 if all test cases pass, 1 if the output was not always as expected, 7or 1 (with a Python backtrace) if there was an operational error. 8""" 9 10# Copyright The Mbed TLS Contributors 11# SPDX-License-Identifier: Apache-2.0 12# 13# Licensed under the Apache License, Version 2.0 (the "License"); you may 14# not use this file except in compliance with the License. 15# You may obtain a copy of the License at 16# 17# http://www.apache.org/licenses/LICENSE-2.0 18# 19# Unless required by applicable law or agreed to in writing, software 20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22# See the License for the specific language governing permissions and 23# limitations under the License. 
import argparse
from collections import namedtuple
import os
import re
import subprocess
import sys
from typing import Iterable, List, Optional, Tuple

import scripts_path # pylint: disable=unused-import
from mbedtls_dev import c_build_helper
from mbedtls_dev.macro_collector import InputsForTest, PSAMacroEnumerator
from mbedtls_dev import typing_util

def gather_inputs(headers: Iterable[str],
                  test_suites: Iterable[str],
                  inputs_class=InputsForTest) -> PSAMacroEnumerator:
    """Read the list of inputs to test psa_constant_names with.

    Each header is fed to ``inputs.parse_header`` and each test suite data
    file to ``inputs.parse_test_cases``. Then ``add_numerical_values`` and
    ``gather_arguments`` are called on the collector.

    :param headers: paths of C header files to parse.
    :param test_suites: paths of test suite ``.data`` files to parse.
    :param inputs_class: class instantiated with no arguments to collect
        the inputs.
    :return: the populated ``inputs_class`` instance.
    """
    inputs = inputs_class()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    inputs.add_numerical_values()
    inputs.gather_arguments()
    return inputs

def run_c(type_word: str,
          expressions: Iterable[str],
          include_path: Optional[List[str]] = None,
          keep_c: bool = False) -> List[str]:
    """Generate and run a program to print out numerical values of C expressions.

    :param type_word: the kind of constant. ``'status'`` values are cast to
        ``long`` and printed with ``%ld``. Every other kind is cast to
        ``unsigned long`` and printed with ``0x%08lx``.
    :param expressions: the C expressions to evaluate.
    :param include_path: directories to search for headers. ``main``
        supplies the list collected by ``--include``, so this is a list of
        directories rather than a single path.
    :param keep_c: whether to keep the generated C file.
    :return: the printed value of each expression, as strings. Callers pair
        them positionally with ``expressions``.
    """
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    return c_build_helper.get_c_expression_values(
        cast_to, printf_format,
        expressions,
        caller='test_psa_constant_names.py for {} values'.format(type_word),
        file_label=type_word,
        header='#include <psa/crypto.h>',
        include_path=include_path,
        keep_c=keep_c
    )

NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr: str) -> str:
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace: every run of
    whitespace is removed.
    """
    return NORMALIZE_STRIP_RE.sub('', expr)

# Matches PSA_ALG_AEAD_WITH_SHORTENED_TAG(<CCM|CHACHA20_POLY1305|GCM>, 16),
# i.e. the whole expression "shortening" the tag to 16 bytes.
ALG_TRUNCATED_TO_SELF_RE = \
    re.compile(r'PSA_ALG_AEAD_WITH_SHORTENED_TAG\('
               r'PSA_ALG_(?:CCM|CHACHA20_POLY1305|GCM)'
               r', *16\)\Z')

def is_simplifiable(expr: str) -> bool:
    """Determine whether an expression is simplifiable.

    Simplifiable expressions can't be output in their input form, since
    the output will be the simple form. Therefore they must be excluded
    from testing.
    """
    return ALG_TRUNCATED_TO_SELF_RE.match(expr) is not None

def collect_values(inputs: InputsForTest,
                   type_word: str,
                   include_path: Optional[List[str]] = None,
                   keep_c: bool = False) -> Tuple[List[str], List[str]]:
    """Generate expressions using known macro names and calculate their values.

    Simplifiable expressions (see ``is_simplifiable``) are skipped.

    Return a tuple ``(expressions, values)`` of two parallel lists.
    ``expressions`` is the sorted list of C expressions. ``values`` holds
    the string representation of each expression's integer value, in the
    same order as produced by ``run_c``.
    """
    names = inputs.get_names(type_word)
    expressions = sorted(expr
                         for expr in inputs.generate_expressions(names)
                         if not is_simplifiable(expr))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values

class Tests:
    """An object representing tests and their results."""

    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options) -> None:
        """Store the parsed command line options and reset the counters."""
        self.options = options
        self.count = 0
        self.errors = [] #type: List[Tests.Error]

    def run_one(self, inputs: InputsForTest, type_word: str) -> None:
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.

        Every output line whose normalized form differs from the normalized
        input expression is recorded in ``self.errors``.

        :raises subprocess.CalledProcessError: if the program exits with a
            non-zero status.
        :raises RuntimeError: if the program does not print exactly one
            line per value passed to it.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        output_bytes = subprocess.check_output([self.options.program,
                                                type_word] + values)
        # splitlines() gives [] for empty output (split('\n') gives ['']),
        # and it also copes with CRLF line endings.
        outputs = output_bytes.decode('ascii').strip().splitlines()
        # zip() would silently drop unmatched cases while self.count still
        # includes them, so they would effectively count as passes.
        if len(outputs) != len(values):
            raise RuntimeError(
                '{}: expected {} output lines for {} values, got {}'
                .format(self.options.program, len(values), type_word,
                        len(outputs)))
        self.count += len(expressions)
        for expr, value, line in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, line))
            if normalize(expr) != normalize(line):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=line))

    def run_all(self, inputs: InputsForTest) -> None:
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out: typing_util.Writable) -> None:
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')

HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']

def main():
    """Parse the command line, run all tests and report the results.

    Exit with status 1 if any test case failed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # With action='append', values given with -I are appended after the
    # default 'include', so options.include[0] is always 'include'.
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()