1"""Collect macro definitions from header files. 2""" 3 4# Copyright The Mbed TLS Contributors 5# SPDX-License-Identifier: Apache-2.0 6# 7# Licensed under the Apache License, Version 2.0 (the "License"); you may 8# not use this file except in compliance with the License. 9# You may obtain a copy of the License at 10# 11# http://www.apache.org/licenses/LICENSE-2.0 12# 13# Unless required by applicable law or agreed to in writing, software 14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16# See the License for the specific language governing permissions and 17# limitations under the License. 18 19import itertools 20import re 21from typing import Dict, IO, Iterable, Iterator, List, Optional, Pattern, Set, Tuple, Union 22 23 24class ReadFileLineException(Exception): 25 def __init__(self, filename: str, line_number: Union[int, str]) -> None: 26 message = 'in {} at {}'.format(filename, line_number) 27 super(ReadFileLineException, self).__init__(message) 28 self.filename = filename 29 self.line_number = line_number 30 31 32class read_file_lines: 33 # Dear Pylint, conventionally, a context manager class name is lowercase. 34 # pylint: disable=invalid-name,too-few-public-methods 35 """Context manager to read a text file line by line. 36 37 ``` 38 with read_file_lines(filename) as lines: 39 for line in lines: 40 process(line) 41 ``` 42 is equivalent to 43 ``` 44 with open(filename, 'r') as input_file: 45 for line in input_file: 46 process(line) 47 ``` 48 except that if process(line) raises an exception, then the read_file_lines 49 snippet annotates the exception with the file name and line number. 50 """ 51 def __init__(self, filename: str, binary: bool = False) -> None: 52 self.filename = filename 53 self.file = None #type: Optional[IO[str]] 54 self.line_number = 'entry' #type: Union[int, str] 55 self.generator = None #type: Optional[Iterable[Tuple[int, str]]] 56 self.binary = binary 57 def __enter__(self) -> 'read_file_lines': 58 self.file = open(self.filename, 'rb' if self.binary else 'r') 59 self.generator = enumerate(self.file) 60 return self 61 def __iter__(self) -> Iterator[str]: 62 assert self.generator is not None 63 for line_number, content in self.generator: 64 self.line_number = line_number 65 yield content 66 self.line_number = 'exit' 67 def __exit__(self, exc_type, exc_value, exc_traceback) -> None: 68 if self.file is not None: 69 self.file.close() 70 if exc_type is not None: 71 raise ReadFileLineException(self.filename, self.line_number) \ 72 from exc_value 73 74 75class PSAMacroEnumerator: 76 """Information about constructors of various PSA Crypto types. 77 78 This includes macro names as well as information about their arguments 79 when applicable. 80 81 This class only provides ways to enumerate expressions that evaluate to 82 values of the covered types. Derived classes are expected to populate 83 the set of known constructors of each kind, as well as populate 84 `self.arguments_for` for arguments that are not of a kind that is 85 enumerated here. 86 """ 87 #pylint: disable=too-many-instance-attributes 88 89 def __init__(self) -> None: 90 """Set up an empty set of known constructor macros. 
91 """ 92 self.statuses = set() #type: Set[str] 93 self.lifetimes = set() #type: Set[str] 94 self.locations = set() #type: Set[str] 95 self.persistence_levels = set() #type: Set[str] 96 self.algorithms = set() #type: Set[str] 97 self.ecc_curves = set() #type: Set[str] 98 self.dh_groups = set() #type: Set[str] 99 self.key_types = set() #type: Set[str] 100 self.key_usage_flags = set() #type: Set[str] 101 self.hash_algorithms = set() #type: Set[str] 102 self.mac_algorithms = set() #type: Set[str] 103 self.ka_algorithms = set() #type: Set[str] 104 self.kdf_algorithms = set() #type: Set[str] 105 self.pake_algorithms = set() #type: Set[str] 106 self.aead_algorithms = set() #type: Set[str] 107 self.sign_algorithms = set() #type: Set[str] 108 # macro name -> list of argument names 109 self.argspecs = {} #type: Dict[str, List[str]] 110 # argument name -> list of values 111 self.arguments_for = { 112 'mac_length': [], 113 'min_mac_length': [], 114 'tag_length': [], 115 'min_tag_length': [], 116 } #type: Dict[str, List[str]] 117 # Whether to include intermediate macros in enumerations. Intermediate 118 # macros serve as category headers and are not valid values of their 119 # type. See `is_internal_name`. 120 # Always false in this class, may be set to true in derived classes. 121 self.include_intermediate = False 122 123 def is_internal_name(self, name: str) -> bool: 124 """Whether this is an internal macro. Internal macros will be skipped.""" 125 if not self.include_intermediate: 126 if name.endswith('_BASE') or name.endswith('_NONE'): 127 return True 128 if '_CATEGORY_' in name: 129 return True 130 return name.endswith('_FLAG') or name.endswith('_MASK') 131 132 def gather_arguments(self) -> None: 133 """Populate the list of values for macro arguments. 134 135 Call this after parsing all the inputs. 136 """ 137 self.arguments_for['hash_alg'] = sorted(self.hash_algorithms) 138 self.arguments_for['mac_alg'] = sorted(self.mac_algorithms) 139 self.arguments_for['ka_alg'] = sorted(self.ka_algorithms) 140 self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms) 141 self.arguments_for['aead_alg'] = sorted(self.aead_algorithms) 142 self.arguments_for['sign_alg'] = sorted(self.sign_algorithms) 143 self.arguments_for['curve'] = sorted(self.ecc_curves) 144 self.arguments_for['group'] = sorted(self.dh_groups) 145 self.arguments_for['persistence'] = sorted(self.persistence_levels) 146 self.arguments_for['location'] = sorted(self.locations) 147 self.arguments_for['lifetime'] = sorted(self.lifetimes) 148 149 @staticmethod 150 def _format_arguments(name: str, arguments: Iterable[str]) -> str: 151 """Format a macro call with arguments. 152 153 The resulting format is consistent with 154 `InputsForTest.normalize_argument`. 155 """ 156 return name + '(' + ', '.join(arguments) + ')' 157 158 _argument_split_re = re.compile(r' *, *') 159 @classmethod 160 def _argument_split(cls, arguments: str) -> List[str]: 161 return re.split(cls._argument_split_re, arguments) 162 163 def distribute_arguments(self, name: str) -> Iterator[str]: 164 """Generate macro calls with each tested argument set. 165 166 If name is a macro without arguments, just yield "name". 167 If name is a macro with arguments, yield a series of 168 "name(arg1,...,argN)" where each argument takes each possible 169 value at least once. 
170 """ 171 try: 172 if name not in self.argspecs: 173 yield name 174 return 175 argspec = self.argspecs[name] 176 if argspec == []: 177 yield name + '()' 178 return 179 argument_lists = [self.arguments_for[arg] for arg in argspec] 180 arguments = [values[0] for values in argument_lists] 181 yield self._format_arguments(name, arguments) 182 # Dear Pylint, enumerate won't work here since we're modifying 183 # the array. 184 # pylint: disable=consider-using-enumerate 185 for i in range(len(arguments)): 186 for value in argument_lists[i][1:]: 187 arguments[i] = value 188 yield self._format_arguments(name, arguments) 189 arguments[i] = argument_lists[i][0] 190 except BaseException as e: 191 raise Exception('distribute_arguments({})'.format(name)) from e 192 193 def distribute_arguments_without_duplicates( 194 self, seen: Set[str], name: str 195 ) -> Iterator[str]: 196 """Same as `distribute_arguments`, but don't repeat seen results.""" 197 for result in self.distribute_arguments(name): 198 if result not in seen: 199 seen.add(result) 200 yield result 201 202 def generate_expressions(self, names: Iterable[str]) -> Iterator[str]: 203 """Generate expressions covering values constructed from the given names. 204 205 `names` can be any iterable collection of macro names. 206 207 For example: 208 * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])`` 209 generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for 210 every known hash algorithm ``h``. 211 * ``macros.generate_expressions(macros.key_types)`` generates all 212 key types. 213 """ 214 seen = set() #type: Set[str] 215 return itertools.chain(*( 216 self.distribute_arguments_without_duplicates(seen, name) 217 for name in names 218 )) 219 220 221class PSAMacroCollector(PSAMacroEnumerator): 222 """Collect PSA crypto macro definitions from C header files. 223 """ 224 225 def __init__(self, include_intermediate: bool = False) -> None: 226 """Set up an object to collect PSA macro definitions. 227 228 Call the read_file method of the constructed object on each header file. 229 230 * include_intermediate: if true, include intermediate macros such as 231 PSA_XXX_BASE that do not designate semantic values. 232 """ 233 super().__init__() 234 self.include_intermediate = include_intermediate 235 self.key_types_from_curve = {} #type: Dict[str, str] 236 self.key_types_from_group = {} #type: Dict[str, str] 237 self.algorithms_from_hash = {} #type: Dict[str, str] 238 239 @staticmethod 240 def algorithm_tester(name: str) -> str: 241 """The predicate for whether an algorithm is built from the given constructor. 242 243 The given name must be the name of an algorithm constructor of the 244 form ``PSA_ALG_xxx`` which is used as ``PSA_ALG_xxx(yyy)`` to build 245 an algorithm value. Return the corresponding predicate macro which 246 is used as ``predicate(alg)`` to test whether ``alg`` can be built 247 as ``PSA_ALG_xxx(yyy)``. The predicate is usually called 248 ``PSA_ALG_IS_xxx``. 249 """ 250 prefix = 'PSA_ALG_' 251 assert name.startswith(prefix) 252 midfix = 'IS_' 253 suffix = name[len(prefix):] 254 if suffix in ['DSA', 'ECDSA']: 255 midfix += 'RANDOMIZED_' 256 elif suffix == 'RSA_PSS': 257 suffix += '_STANDARD_SALT' 258 return prefix + midfix + suffix 259 260 def record_algorithm_subtype(self, name: str, expansion: str) -> None: 261 """Record the subtype of an algorithm constructor. 262 263 Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm 264 is of a subtype that is tracked in its own set, add it to the relevant 265 set. 
266 """ 267 # This code is very ad hoc and fragile. It should be replaced by 268 # something more robust. 269 if re.match(r'MAC(?:_|\Z)', name): 270 self.mac_algorithms.add(name) 271 elif re.match(r'KDF(?:_|\Z)', name): 272 self.kdf_algorithms.add(name) 273 elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion): 274 self.hash_algorithms.add(name) 275 elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion): 276 self.mac_algorithms.add(name) 277 elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion): 278 self.aead_algorithms.add(name) 279 elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion): 280 self.ka_algorithms.add(name) 281 elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion): 282 self.kdf_algorithms.add(name) 283 284 # "#define" followed by a macro name with either no parameters 285 # or a single parameter and a non-empty expansion. 286 # Grab the macro name in group 1, the parameter name if any in group 2 287 # and the expansion in group 3. 288 _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' + 289 r'(?:\s+|\((\w+)\)\s*)' + 290 r'(.+)') 291 _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED') 292 293 def read_line(self, line): 294 """Parse a C header line and record the PSA identifier it defines if any. 295 This function analyzes lines that start with "#define PSA_" 296 (up to non-significant whitespace) and skips all non-matching lines. 297 """ 298 # pylint: disable=too-many-branches 299 m = re.match(self._define_directive_re, line) 300 if not m: 301 return 302 name, parameter, expansion = m.groups() 303 expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion) 304 if parameter: 305 self.argspecs[name] = [parameter] 306 if re.match(self._deprecated_definition_re, expansion): 307 # Skip deprecated values, which are assumed to be 308 # backward compatibility aliases that share 309 # numerical values with non-deprecated values. 

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file):
        for line in header_file:
            m = re.search(self._continued_line_re, line)
            while m:
                cont = next(header_file)
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)
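

# Illustrative sketch only (not used by this module): a typical way to drive
# PSAMacroCollector. The header path is a hypothetical placeholder; read_file()
# expects a file opened in binary mode.
def _example_collect_algorithms(header_path: str) -> List[str]:
    """Collect macros from one header and return expressions for all algorithms."""
    collector = PSAMacroCollector()
    with open(header_path, 'rb') as header_file:
        collector.read_file(header_file)
    collector.gather_arguments()
    return sorted(collector.generate_expressions(collector.algorithms))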


class InputsForTest(PSAMacroEnumerator):
    # pylint: disable=too-many-instance-attributes
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self) -> None:
        super().__init__()
        self.all_declared = set() #type: Set[str]
        # Identifier prefixes
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'ECC_CURVE': self.ecc_curves,
            'DH_GROUP': self.dh_groups,
            'KEY_LIFETIME': self.lifetimes,
            'KEY_LOCATION': self.locations,
            'KEY_PERSISTENCE': self.persistence_levels,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        } #type: Dict[str, Set[str]]
        # Test functions
        self.table_by_test_function = {
            # Any function ending in _algorithm also gets added to
            # self.algorithms.
            'key_type': [self.key_types],
            'block_cipher_key_type': [self.key_types],
            'stream_cipher_key_type': [self.key_types],
            'ecc_key_family': [self.ecc_curves],
            'ecc_key_types': [self.ecc_curves],
            'dh_key_family': [self.dh_groups],
            'dh_key_types': [self.dh_groups],
            'hash_algorithm': [self.hash_algorithms],
            'mac_algorithm': [self.mac_algorithms],
            'cipher_algorithm': [],
            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
            'aead_algorithm': [self.aead_algorithms],
            'key_derivation_algorithm': [self.kdf_algorithms],
            'key_agreement_algorithm': [self.ka_algorithms],
            'asymmetric_signature_algorithm': [self.sign_algorithms],
            'asymmetric_signature_wildcard': [self.algorithms],
            'asymmetric_encryption_algorithm': [],
            'pake_algorithm': [self.pake_algorithms],
            'other_algorithm': [],
            'lifetime': [self.lifetimes],
        } #type: Dict[str, List[Set[str]]]
        mac_lengths = [str(n) for n in [
            1,  # minimum expressible
            4,  # minimum allowed by policy
            13, # an odd size in a plausible range
            14, # an even non-power-of-two size in a plausible range
            16, # same as full size for at least one algorithm
            63, # maximum expressible
        ]]
        self.arguments_for['mac_length'] += mac_lengths
        self.arguments_for['min_mac_length'] += mac_lengths
        aead_lengths = [str(n) for n in [
            1,  # minimum expressible
            4,  # minimum allowed by policy
            13, # an odd size in a plausible range
            14, # an even non-power-of-two size in a plausible range
            16, # same as full size for at least one algorithm
            63, # maximum expressible
        ]]
        self.arguments_for['tag_length'] += aead_lengths
        self.arguments_for['min_tag_length'] += aead_lengths

    def add_numerical_values(self) -> None:
        """Add numerical values that are not supported to the known identifiers."""
        # Sets of names per type
        self.algorithms.add('0xffffffff')
        self.ecc_curves.add('0xff')
        self.dh_groups.add('0xff')
        self.key_types.add('0xffff')
        self.key_usage_flags.add('0x80000000')

        # Hard-coded values for unknown algorithms
        #
        # These have to have values that are correct for their respective
        # PSA_ALG_IS_xxx macros, but are also not currently assigned and are
        # not likely to be assigned in the near future.
        self.hash_algorithms.add('0x020000fe') # 0x020000ff is PSA_ALG_ANY_HASH
        self.mac_algorithms.add('0x03007fff')
        self.ka_algorithms.add('0x09fc0000')
        self.kdf_algorithms.add('0x080000ff')
        self.pake_algorithms.add('0x0a0000ff')
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.

    def get_names(self, type_word: str) -> Set[str]:
        """Return the set of known names of values of the given type."""
        return {
            'status': self.statuses,
            'algorithm': self.algorithms,
            'ecc_curve': self.ecc_curves,
            'dh_group': self.dh_groups,
            'key_type': self.key_types,
            'key_usage': self.key_usage_flags,
        }[type_word]

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:(?:DH|ECC|KEY)_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = set([
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG_CASE',
    ])
    def parse_header_line(self, line: str) -> None:
        """Parse a C header line, looking for "#define PSA_xxx"."""
        m = re.match(self._header_line_re, line)
        if not m:
            return
        name = m.group(1)
        self.all_declared.add(name)
        if re.search(self._excluded_name_re, name) or \
           name in self._excluded_names or \
           self.is_internal_name(name):
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+') #type: Pattern
    def parse_header(self, filename: str) -> None:
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename, binary=True) as lines:
            for line in lines:
                line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
                self.parse_header_line(line)

    _macro_identifier_re = re.compile(r'[A-Z]\w+')
    def generate_undeclared_names(self, expr: str) -> Iterable[str]:
        for name in re.findall(self._macro_identifier_re, expr):
            if name not in self.all_declared:
                yield name

    def accept_test_case_line(self, function: str, argument: str) -> bool:
        #pylint: disable=unused-argument
        undeclared = list(self.generate_undeclared_names(argument))
        if undeclared:
            raise Exception('Undeclared names in test case', undeclared)
        return True

    @staticmethod
    def normalize_argument(argument: str) -> str:
        """Normalize whitespace in the given C expression.

        The result uses the same whitespace as
        `PSAMacroEnumerator.distribute_arguments`.
        """
        return re.sub(r',', r', ', re.sub(r' +', r'', argument))

    def add_test_case_line(self, function: str, argument: str) -> None:
        """Parse a test case data line, looking for algorithm metadata tests."""
        sets = []
        if function.endswith('_algorithm'):
            sets.append(self.algorithms)
            if function == 'key_agreement_algorithm' and \
               argument.startswith('PSA_ALG_KEY_AGREEMENT('):
                # We only want *raw* key agreement algorithms as such, so
                # exclude ones that are already chained with a KDF.
                # Keep the expression as one to test as an algorithm.
                function = 'other_algorithm'
        sets += self.table_by_test_function[function]
        if self.accept_test_case_line(function, argument):
            for s in sets:
                s.add(self.normalize_argument(argument))

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')
    def parse_test_cases(self, filename: str) -> None:
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = re.match(self._test_case_line_re, line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))
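

# Illustrative sketch only (not used by this module): how a test generation
# script might drive InputsForTest. The file paths are hypothetical
# placeholders.
def _example_key_type_expressions(header_path: str, test_data_path: str) -> List[str]:
    """Gather known macros and return test expressions for every key type."""
    inputs = InputsForTest()
    inputs.parse_header(header_path)
    inputs.parse_test_cases(test_data_path)
    inputs.add_numerical_values()
    inputs.gather_arguments()
    return sorted(inputs.generate_expressions(inputs.get_names('key_type')))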