"""Collect macro definitions from header files.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import re
from typing import Dict, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union


class ReadFileLineException(Exception):
    def __init__(self, filename: str, line_number: Union[int, str]) -> None:
        message = 'in {} at {}'.format(filename, line_number)
        super(ReadFileLineException, self).__init__(message)
        self.filename = filename
        self.line_number = line_number


class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.
    """
    def __init__(self, filename: str, binary: bool = False) -> None:
        self.filename = filename
        self.line_number = 'entry' #type: Union[int, str]
        self.generator = None #type: Optional[Iterable[Tuple[int, str]]]
        self.binary = binary
    def __enter__(self) -> 'read_file_lines':
        self.generator = enumerate(open(self.filename,
                                        'rb' if self.binary else 'r'))
        return self
    def __iter__(self) -> Iterator[str]:
        assert self.generator is not None
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value
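

# The following is an illustrative sketch, not part of the upstream module:
# it shows how the context manager above annotates an exception raised while
# processing a line. The file name 'example.h' and the parse_or_fail helper
# are hypothetical stand-ins for real input and real processing code.
def _example_read_with_context(filename: str = 'example.h') -> None:
    """Demonstrate how ReadFileLineException wraps a processing error."""
    def parse_or_fail(line: str) -> None:
        # Hypothetical processing step that rejects lines containing tabs.
        if '\t' in line:
            raise ValueError('unexpected tab')
    try:
        with read_file_lines(filename) as lines:
            for line in lines:
                parse_or_fail(line)
    except ReadFileLineException as e:
        # e.filename and e.line_number identify where processing failed.
        print('Failed in', e.filename, 'at line', e.line_number)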


class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """
    #pylint: disable=too-many-instance-attributes

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.lifetimes = set() #type: Set[str]
        self.locations = set() #type: Set[str]
        self.persistence_levels = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        self.sign_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]
        # Whether to include intermediate macros in enumerations. Intermediate
        # macros serve as category headers and are not valid values of their
        # type. See `is_internal_name`.
        # Always false in this class, may be set to true in derived classes.
        self.include_intermediate = False

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped."""
        if not self.include_intermediate:
            if name.endswith('_BASE') or name.endswith('_NONE'):
                return True
            if '_CATEGORY_' in name:
                return True
        return name.endswith('_FLAG') or name.endswith('_MASK')

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)
        self.arguments_for['persistence'] = sorted(self.persistence_levels)
        self.arguments_for['location'] = sorted(self.locations)
        self.arguments_for['lifetime'] = sorted(self.lifetimes)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments.

        The resulting format is consistent with
        `InputsForTest.normalize_argument`.
        """
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        return re.split(cls._argument_split_re, arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if argspec == []:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            # Dear Pylint, enumerate won't work here since we're modifying
            # the array.
            # pylint: disable=consider-using-enumerate
            for i in range(len(arguments)):
                for value in argument_lists[i][1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this argument to its default value before moving
                # on to the next argument position.
                arguments[i] = argument_lists[i][0]
        except BaseException as e:
            raise Exception('distribute_arguments({})'.format(name)) from e

    def distribute_arguments_without_duplicates(
            self, seen: Set[str], name: str
    ) -> Iterator[str]:
        """Same as `distribute_arguments`, but don't repeat seen results."""
        for result in self.distribute_arguments(name):
            if result not in seen:
                seen.add(result)
                yield result

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        seen = set() #type: Set[str]
        return itertools.chain(*(
            self.distribute_arguments_without_duplicates(seen, name)
            for name in names
        ))
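

# Illustrative sketch (not part of the upstream module): how the enumerator
# above expands parametrized constructors once its sets are populated. The
# macro names used here are real PSA names, but populating the sets by hand
# is only done for demonstration; normally a parser subclass such as
# PSAMacroCollector or InputsForTest fills them in.
def _example_enumerate_hmac_variants() -> List[str]:
    """Return the expressions generated for PSA_ALG_CMAC and PSA_ALG_HMAC."""
    enumerator = PSAMacroEnumerator()
    enumerator.hash_algorithms.update(['PSA_ALG_SHA_256', 'PSA_ALG_SHA_512'])
    # PSA_ALG_HMAC takes one argument, conventionally called hash_alg.
    enumerator.argspecs['PSA_ALG_HMAC'] = ['hash_alg']
    enumerator.gather_arguments()
    # Yields 'PSA_ALG_CMAC', 'PSA_ALG_HMAC(PSA_ALG_SHA_256)' and
    # 'PSA_ALG_HMAC(PSA_ALG_SHA_512)'.
    return list(enumerator.generate_expressions(['PSA_ALG_CMAC',
                                                 'PSA_ALG_HMAC']))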


class PSAMacroCollector(PSAMacroEnumerator):
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        super().__init__()
        self.include_intermediate = include_intermediate
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        self.algorithms_from_hash = {} #type: Dict[str, str]

    @staticmethod
    def algorithm_tester(name: str) -> str:
        """The predicate for whether an algorithm is built from the given constructor.

        The given name must be the name of an algorithm constructor of the
        form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build
        an algorithm value. Return the corresponding predicate macro which
        is used as ``predicate(alg)`` to test whether ``alg`` can be built
        as ``PSA_ALG_xxx(yyy)``. The predicate is usually called
        ``PSA_ALG_IS_xxx``.
        """
        prefix = 'PSA_ALG_'
        assert name.startswith(prefix)
        midfix = 'IS_'
        suffix = name[len(prefix):]
        if suffix in ['DSA', 'ECDSA']:
            midfix += 'RANDOMIZED_'
        elif suffix == 'RSA_PSS':
            suffix += '_STANDARD_SALT'
        return prefix + midfix + suffix

    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
        """Record the subtype of an algorithm constructor.

        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
        is of a subtype that is tracked in its own set, add it to the relevant
        set.
        """
        # This code is very ad hoc and fragile. It should be replaced by
        # something more robust.
        if re.match(r'MAC(?:_|\Z)', name):
            self.mac_algorithms.add(name)
        elif re.match(r'KDF(?:_|\Z)', name):
            self.kdf_algorithms.add(name)
        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
            self.hash_algorithms.add(name)
        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
            self.mac_algorithms.add(name)
        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
            self.aead_algorithms.add(name)
        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
            self.ka_algorithms.add(name)
        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
            self.kdf_algorithms.add(name)

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line):
        """Parse a C header line and record the PSA identifier it defines if any.

        This function analyzes lines that start with "#define PSA_"
        (up to non-significant whitespace) and skips all non-matching lines.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if parameter:
            self.argspecs[name] = [parameter]
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
             and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            self.record_algorithm_subtype(name, expansion)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            self.algorithms_from_hash[name] = self.algorithm_tester(name)
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usage_flags.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file):
        for line in header_file:
            m = re.search(self._continued_line_re, line)
            while m:
                cont = next(header_file)
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)
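

# Illustrative sketch (not part of the upstream module): feeding a header
# file to the collector above. read_file expects a binary file object, since
# it strips non-ASCII bytes itself; the path used here is the usual location
# of the PSA values header in an Mbed TLS tree, but any PSA header works.
def _example_collect_from_header(
        header_path: str = 'include/psa/crypto_values.h'
) -> PSAMacroCollector:
    """Collect PSA macros from one header and return the collector."""
    collector = PSAMacroCollector(include_intermediate=False)
    with open(header_path, 'rb') as header_file:
        collector.read_file(header_file)
    # After parsing, gather_arguments makes the collected algorithm names
    # available as argument values for parametrized constructors.
    collector.gather_arguments()
    return collector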


class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        self.arguments_for['mac_length'] += ['1', '63']
        self.arguments_for['min_mac_length'] += ['1', '63']
        self.arguments_for['tag_length'] += ['1', '63']
        self.arguments_for['min_tag_length'] += ['1', '63']

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type."""
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = set([
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    ])
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx"."""
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        return re.sub(r',', r', ', re.sub(r' +', r'', argument))

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests."""
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))
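

# Illustrative sketch (not part of the upstream module): a typical end-to-end
# use of InputsForTest, roughly the flow a test-generation script might
# follow. The default header path is just an example of where a PSA header
# usually lives in an Mbed TLS tree; pass whatever files exist in your tree.
def _example_inputs_for_test(
        header_paths: Iterable[str] = ('include/psa/crypto_values.h',),
        test_data_paths: Iterable[str] = ()
) -> Iterator[str]:
    """Yield expressions covering every known key type."""
    inputs = InputsForTest()
    for path in header_paths:
        inputs.parse_header(path)
    for path in test_data_paths:
        inputs.parse_test_cases(path)
    inputs.add_numerical_values()
    inputs.gather_arguments()
    # generate_expressions expands parametrized constructors such as
    # PSA_KEY_TYPE_ECC_KEY_PAIR(curve) with every collected curve family.
    return inputs.generate_expressions(sorted(inputs.key_types))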