#!/usr/bin/env python3
"""Test the program psa_constant_names.
Gather constant names from header files and test cases. Compile a C program
to print out their numerical values, feed these numerical values to
psa_constant_names, and check that the output is the original name.
Return 0 if all test cases pass, 1 if the output was not always as expected,
or 1 (with a Python backtrace) if there was an operational error.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from collections import namedtuple
import itertools
import os
import platform
import re
import subprocess
import sys
import tempfile

class ReadFileLineException(Exception):
    """Exception annotating another exception with a file position.

    Raised by read_file_lines when the body of the ``with`` statement
    raises. The original exception is chained as ``__cause__``.
    """
    def __init__(self, filename, line_number):
        message = 'in {} at {}'.format(filename, line_number)
        super().__init__(message)
        self.filename = filename
        self.line_number = line_number

class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.

    The line number is 1-based. Before the first line is read it is the
    string 'entry', and after the last line has been consumed it is 'exit'.
    If ``binary`` is true, the file is opened in binary mode and the lines
    are bytes objects.
    """
    def __init__(self, filename, binary=False):
        self.filename = filename
        self.file = None
        self.line_number = 'entry'
        self.generator = None
        self.binary = binary
    def __enter__(self):
        # Keep a reference to the file object so that __exit__ can close it.
        self.file = open(self.filename, 'rb' if self.binary else 'r')
        # Number lines from 1, matching what editors and compilers report.
        self.generator = enumerate(self.file, 1)
        return self
    def __iter__(self):
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always release the file handle, whether or not the body raised.
        if self.file is not None:
            self.file.close()
            self.file = None
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value

class Inputs:
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self):
        self.all_declared = set()
        # Sets of names per type
        self.statuses = {'PSA_SUCCESS'}
        self.algorithms = {'0xffffffff'}
        self.ecc_curves = {'0xff'}
        self.dh_groups = {'0xff'}
        self.key_types = {'0xffff'}
        self.key_usage_flags = {'0x80000000'}
        # Hard-coded value for unknown algorithms
        self.hash_algorithms = {'0x020000fe'}
        self.mac_algorithms = {'0x0300ffff'}
        self.ka_algorithms = {'0x09fc0000'}
        self.kdf_algorithms = {'0x080000ff'}
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.
        self.aead_algorithms = set()
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        }
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'other_algorithm': [],
        }
        # macro name -> list of argument names
        self.argspecs = {}
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': ['1', '63'],
            'tag_length': ['1', '63'],
        }

    def get_names(self, type_word):
        """Return the set of known names of values of the given type."""
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    def gather_arguments(self):
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name, arguments):
        """Format a macro call with arguments."""
        return name + '(' + ', '.join(arguments) + ')'

    def distribute_arguments(self, name):
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        Any error while building the calls (for example an argument name
        with no known values) is re-raised as an Exception naming the macro,
        with the original error chained as its cause.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            # Dear Pylint, enumerate won't work here since we're modifying
            # the array.
            # pylint: disable=consider-using-enumerate
            for i in range(len(arguments)):
                for value in argument_lists[i][1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this argument's own first value before varying
                # the next argument.
                arguments[i] = argument_lists[i][0]
        except Exception as e:
            # Only wrap ordinary errors: GeneratorExit and KeyboardInterrupt
            # must propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def generate_expressions(self, names):
        """Return an iterator over the expressions to test for all names."""
        return itertools.chain.from_iterable(
            map(self.distribute_arguments, names))

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments):
        """Split a comma-separated macro parameter list into names."""
        return re.split(cls._argument_split_re, arguments)

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
    }
    def parse_header_line(self, line):
        """Parse a C header line, looking for "#define PSA_xxx"."""
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        # Record every declared name, even excluded ones, so that test case
        # arguments can be checked against the full set of declarations.
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names:
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    def parse_header(self, filename):
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                # Drop non-ASCII bytes so that the line can be decoded
                # regardless of the header's encoding.
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr):
        """Yield the identifiers in expr that no parsed header declares."""
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function, argument):
        """Check that all names used in a test case argument are declared.

        Return True if the argument is acceptable. Raise an Exception
        listing the undeclared names otherwise.
        """
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    def add_test_case_line(self, function, argument):
        """Parse a test case data line, looking for algorithm metadata tests."""
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(argument)

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename):
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))

def gather_inputs(headers, test_suites, inputs_class=Inputs):
    """Read the list of inputs to test psa_constant_names with."""
    inputs = inputs_class()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    inputs.gather_arguments()
    return inputs

def remove_file_if_exists(filename):
    """Remove the specified file, ignoring errors."""
    if not filename:
        return
    try:
        os.remove(filename)
    except OSError:
        pass

def run_c(type_word, expressions, include_path=None, keep_c=False):
    """Generate and run a program to print out numerical values for expressions.

    Write a C program that prints each expression cast to an integer type,
    compile it with the compiler named by the CC environment variable
    (default: cc), run it and return its output as a list of lines.

    The C file is removed after a successful compilation unless keep_c is
    true. The executable is always removed.
    """
    if include_path is None:
        include_path = []
    if type_word == 'status':
        # Status codes are signed.
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    c_name = None
    exe_name = None
    try:
        c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type_word),
                                        suffix='.c',
                                        dir='programs/psa')
        exe_suffix = '.exe' if platform.system() == 'Windows' else ''
        # Same path as the C file, with '.c' replaced by the exe suffix.
        exe_name = c_name[:-2] + exe_suffix
        remove_file_if_exists(exe_name)
        with os.fdopen(c_fd, 'w', encoding='ascii') as c_file:
            c_file.write('/* Generated by test_psa_constant_names.py for {} values */'
                         .format(type_word))
            c_file.write('''
#include <stdio.h>
#include <psa/crypto.h>
int main(void)
{
''')
            for expr in expressions:
                c_file.write('    printf("{}\\n", ({}) {});\n'
                             .format(printf_format, cast_to, expr))
            c_file.write('''    return 0;
}
''')
        cc = os.getenv('CC', 'cc')
        subprocess.check_call([cc] +
                              ['-I' + include_dir
                               for include_dir in include_path] +
                              ['-o', exe_name, c_name])
        if keep_c:
            sys.stderr.write('List of {} tests kept at {}\n'
                             .format(type_word, c_name))
        else:
            os.remove(c_name)
        output = subprocess.check_output([exe_name])
        return output.decode('ascii').strip().split('\n')
    finally:
        remove_file_if_exists(exe_name)

NORMALIZE_STRIP_RE = re.compile(r'\s+')
def normalize(expr):
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace.
    """
    return re.sub(NORMALIZE_STRIP_RE, '', expr)

def collect_values(inputs, type_word, include_path=None, keep_c=False):
    """Generate expressions using known macro names and calculate their values.

    Return a pair (expressions, values) of lists of the same length, where
    expressions is the sorted list of expressions and values[i] is a
    string representation of the integer value of expressions[i].
    """
    names = inputs.get_names(type_word)
    expressions = sorted(inputs.generate_expressions(names))
    values = run_c(type_word, expressions,
                   include_path=include_path, keep_c=keep_c)
    return expressions, values

class Tests:
    """An object representing tests and their results."""

    Error = namedtuple('Error',
                       ['type', 'expression', 'value', 'output'])

    def __init__(self, options):
        self.options = options
        self.count = 0
        self.errors = []

    def run_one(self, inputs, type_word):
        """Test psa_constant_names for the specified type.

        Run the program on the names for this type.
        Use the inputs to figure out what arguments to pass to macros that
        take arguments.
        """
        expressions, values = collect_values(inputs, type_word,
                                             include_path=self.options.include,
                                             keep_c=self.options.keep_c)
        raw_output = subprocess.check_output([self.options.program, type_word] +
                                             values)
        outputs = raw_output.decode('ascii').strip().split('\n')
        self.count += len(expressions)
        for expr, value, output in zip(expressions, values, outputs):
            if self.options.show:
                sys.stdout.write('{} {}\t{}\n'.format(type_word, value, output))
            if normalize(expr) != normalize(output):
                self.errors.append(self.Error(type=type_word,
                                              expression=expr,
                                              value=value,
                                              output=output))

    def run_all(self, inputs):
        """Run psa_constant_names on all the gathered inputs."""
        for type_word in ['status', 'algorithm', 'ecc_curve', 'dh_group',
                          'key_type', 'key_usage']:
            self.run_one(inputs, type_word)

    def report(self, out):
        """Describe each case where the output is not as expected.

        Write the errors to ``out``.
        Also write a total.
        """
        for error in self.errors:
            out.write('For {} "{}", got "{}" (value: {})\n'
                      .format(error.type, error.expression,
                              error.output, error.value))
        out.write('{} test cases'.format(self.count))
        if self.errors:
            out.write(', {} FAIL\n'.format(len(self.errors)))
        else:
            out.write(' PASS\n')

HEADERS = ['psa/crypto.h', 'psa/crypto_extra.h', 'psa/crypto_values.h']
TEST_SUITES = ['tests/suites/test_suite_psa_crypto_metadata.data']

def main():
    """Command line entry point: gather inputs, run all tests and report.

    Exit with status 1 if any test case failed.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help='Don\'t keep the intermediate C file (default)')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--show',
                        action='store_true',
                        help='Show tested values on stdout')
    parser.add_argument('--no-show',
                        action='store_false', dest='show',
                        help='Don\'t show tested values (default)')
    options = parser.parse_args()
    headers = [os.path.join(options.include[0], h) for h in HEADERS]
    inputs = gather_inputs(headers, TEST_SUITES)
    tests = Tests(options)
    tests.run_all(inputs)
    tests.report(sys.stdout)
    if tests.errors:
        sys.exit(1)

if __name__ == '__main__':
    main()