#!/usr/bin/env python3

"""Generate psa_constant_names_generated.c
which is included by programs/psa/psa_constant_names.c.
The code generated by this module is only meant to be used in the context
of that program.

An argument passed to this script will modify the output directory where the
file is written:
* by default (no arguments passed): writes to programs/psa/
* OUTPUT_FILE_DIR passed: writes to OUTPUT_FILE_DIR/
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sys

# Skeleton of the generated C file. The %(name)s placeholders are filled in
# by MacroCollector.write_file() with "OUTPUT_TEMPLATE % data", so every
# literal percent sign in the C code below is written as %%.
OUTPUT_TEMPLATE = '''\
/* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */

static const char *psa_strerror(psa_status_t status)
{
    switch (status) {
    %(status_cases)s
    default: return NULL;
    }
}

static const char *psa_ecc_family_name(psa_ecc_family_t curve)
{
    switch (curve) {
    %(ecc_curve_cases)s
    default: return NULL;
    }
}

static const char *psa_dh_family_name(psa_dh_family_t group)
{
    switch (group) {
    %(dh_group_cases)s
    default: return NULL;
    }
}

static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg)
{
    switch (hash_alg) {
    %(hash_algorithm_cases)s
    default: return NULL;
    }
}

static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg)
{
    switch (ka_alg) {
    %(ka_algorithm_cases)s
    default: return NULL;
    }
}

static int psa_snprint_key_type(char *buffer, size_t buffer_size,
                                psa_key_type_t type)
{
    size_t required_size = 0;
    switch (type) {
    %(key_type_cases)s
    default:
        %(key_type_code)s{
            return snprintf(buffer, buffer_size,
                            "0x%%04x", (unsigned) type);
        }
        break;
    }
    buffer[0] = 0;
    return (int) required_size;
}

#define NO_LENGTH_MODIFIER 0xfffffffflu
static int psa_snprint_algorithm(char *buffer, size_t buffer_size,
                                 psa_algorithm_t alg)
{
    size_t required_size = 0;
    psa_algorithm_t core_alg = alg;
    unsigned long length_modifier = NO_LENGTH_MODIFIER;
    if (PSA_ALG_IS_MAC(alg)) {
        core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0);
        if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_TRUNCATED_MAC(", 22);
            length_modifier = PSA_MAC_TRUNCATED_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_AEAD(alg)) {
        core_alg = PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH(alg);
        if (core_alg == 0) {
            /* For unknown AEAD algorithms, there is no "default tag length". */
            core_alg = alg;
        } else if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AEAD_WITH_TAG_LENGTH(", 29);
            length_modifier = PSA_AEAD_TAG_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
               !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
        core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg);
        append(&buffer, buffer_size, &required_size,
               "PSA_ALG_KEY_AGREEMENT(", 22);
        append_with_alg(&buffer, buffer_size, &required_size,
                        psa_ka_algorithm_name,
                        PSA_ALG_KEY_AGREEMENT_GET_BASE(alg));
        append(&buffer, buffer_size, &required_size, ", ", 2);
    }
    switch (core_alg) {
    %(algorithm_cases)s
    default:
        %(algorithm_code)s{
            append_integer(&buffer, buffer_size, &required_size,
                           "0x%%08lx", (unsigned long) core_alg);
        }
        break;
    }
    if (core_alg != alg) {
        if (length_modifier != NO_LENGTH_MODIFIER) {
            append(&buffer, buffer_size, &required_size, ", ", 2);
            append_integer(&buffer, buffer_size, &required_size,
                           "%%lu", length_modifier);
        }
        append(&buffer, buffer_size, &required_size, ")", 1);
    }
    buffer[0] = 0;
    return (int) required_size;
}

static int psa_snprint_key_usage(char *buffer, size_t buffer_size,
                                 psa_key_usage_t usage)
{
    size_t required_size = 0;
    if (usage == 0) {
        if (buffer_size > 1) {
            buffer[0] = '0';
            buffer[1] = 0;
        } else if (buffer_size == 1) {
            buffer[0] = 0;
        }
        return 1;
    }
%(key_usage_code)s
    if (usage != 0) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append_integer(&buffer, buffer_size, &required_size,
                       "0x%%08lx", (unsigned long) usage);
    } else {
        buffer[0] = 0;
    }
    return (int) required_size;
}

/* End of automatically generated file. */
'''

# The three fragment templates below are concatenated into the
# %(key_type_code)s / %(algorithm_code)s slots of OUTPUT_TEMPLATE. Each one
# ends with "} else " so that consecutive fragments chain into an
# if / else-if ladder; OUTPUT_TEMPLATE supplies the final "{ ... }" fallback.

# Key types built from an ECC family, e.g. PSA_KEY_TYPE_ECC_KEY_PAIR(curve).
KEY_TYPE_FROM_CURVE_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_curve(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_ECC_GET_FAMILY(type));
        } else '''

# Key types built from a DH family, e.g. PSA_KEY_TYPE_DH_KEY_PAIR(group).
KEY_TYPE_FROM_GROUP_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_group(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_DH_GET_FAMILY(type));
        } else '''

# Algorithms parametrized by a hash, e.g. PSA_ALG_HMAC(hash_alg).
ALGORITHM_FROM_HASH_TEMPLATE = '''if (%(tester)s(core_alg)) {
            append(&buffer, buffer_size, &required_size,
                   "%(builder)s(", %(builder_length)s + 1);
            append_with_alg(&buffer, buffer_size, &required_size,
                            psa_hash_algorithm_name,
                            PSA_ALG_GET_HASH(core_alg));
            append(&buffer, buffer_size, &required_size, ")", 1);
        } else '''

# One bit test per key usage flag, emitted into %(key_usage_code)s. Flags are
# separated by " | " and each recognised bit is cleared from %(var)s so that
# any leftover unknown bits can be printed numerically afterwards.
BIT_TEST_TEMPLATE = '''\
    if (%(var)s & %(flag)s) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append(&buffer, buffer_size, &required_size, "%(flag)s", %(length)d);
        %(var)s ^= %(flag)s;
    }\
'''

class MacroCollector:
    """Collect PSA crypto macro definitions from C header files.

    1. Call `read_file` on the input header file(s).
    2. Call `write_file` to write ``psa_constant_names_generated.c``.
    """

    def __init__(self):
        # Names of constant macros, grouped by the kind of value they denote.
        self.statuses = set()
        self.key_types = set()
        self.ecc_curves = set()
        self.dh_groups = set()
        self.algorithms = set()
        self.hash_algorithms = set()
        self.ka_algorithms = set()
        self.key_usages = set()
        # Parametrized constructor macros: maps the constructor name to the
        # name of the matching PSA_..._IS_... tester macro.
        self.key_types_from_curve = {}
        self.key_types_from_group = {}
        self.algorithms_from_hash = {}

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line):
        """Parse one (already joined) C header line and record what it defines.

        Any ``#define`` with zero or one parameter and a non-empty expansion
        is examined; everything else is ignored. Only names with one of the
        PSA prefixes tested below are recorded, and comments are stripped
        from the expansion before it is inspected. Definitions whose
        expansion starts with ``MBEDTLS_DEPRECATED`` and names ending in
        ``_FLAG`` or ``MASK`` are skipped.
        """
        # pylint: disable=too-many-branches
        m = self._define_directive_re.match(line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if self._deprecated_definition_re.match(expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if name.endswith(('_FLAG', 'MASK')):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
             and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # PSA_KEY_TYPE_xxx -> PSA_KEY_TYPE_IS_xxx
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            # PSA_KEY_TYPE_xxx -> PSA_KEY_TYPE_IS_xxx
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ('PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE'):
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            # Ad hoc detection of hash algorithms (values 0x020000xx)
            if re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
                self.hash_algorithms.add(name)
            # Ad hoc detection of key agreement algorithms (values 0x09xx0000)
            if re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
                self.ka_algorithms.add(name)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            if name in ('PSA_ALG_DSA', 'PSA_ALG_ECDSA'):
                # A naming irregularity
                tester = name[:8] + 'IS_RANDOMIZED_' + name[8:]
            else:
                # PSA_ALG_xxx -> PSA_ALG_IS_xxx
                tester = name[:8] + 'IS_' + name[8:]
            self.algorithms_from_hash[name] = tester
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usages.add(name)
        else:
            # Any other macro (non-PSA names, unexpected parameter names)
            # is not something this generator knows how to print.
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')

    def read_file(self, header_file):
        """Parse every line of a C header given as ``bytes`` lines.

        :param header_file: an iterable of ``bytes`` lines, typically a file
            opened in ``'rb'`` mode.

        Backslash-continued lines are joined before parsing. Non-ASCII bytes
        are dropped so that the remaining text decodes as ASCII.
        """
        # iter() on a file object returns the file itself, so this only
        # matters for plain iterables such as a list of lines.
        lines = iter(header_file)
        for line in lines:
            m = self._continued_line_re.search(line)
            while m:
                # A trailing backslash on the very last line has nothing to
                # join with: treat end of input as an empty continuation
                # instead of letting StopIteration escape from here.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = self._continued_line_re.search(line)
            line = self._nonascii_re.sub(b'', line).decode('ascii')
            self.read_line(line)

    @staticmethod
    def _make_return_case(name):
        """Return a C ``case`` that returns the constant's name as a string."""
        return 'case %(name)s: return "%(name)s";' % {'name': name}

    @staticmethod
    def _make_append_case(name):
        """Return a C ``case`` that appends the constant's name to the buffer."""
        template = ('case %(name)s: '
                    'append(&buffer, buffer_size, &required_size, "%(name)s", %(length)d); '
                    'break;')
        return template % {'name': name, 'length': len(name)}

    @staticmethod
    def _make_bit_test(var, flag):
        """Instantiate BIT_TEST_TEMPLATE for one flag of the C variable var."""
        return BIT_TEST_TEMPLATE % {'var': var,
                                    'flag': flag,
                                    'length': len(flag)}

    # The "\n    " separators below match the 4-space indentation of the
    # corresponding placeholders inside the switch statements of
    # OUTPUT_TEMPLATE, so every generated case line is aligned.

    def _make_status_cases(self):
        return '\n    '.join(map(self._make_return_case,
                                 sorted(self.statuses)))

    def _make_ecc_curve_cases(self):
        return '\n    '.join(map(self._make_return_case,
                                 sorted(self.ecc_curves)))

    def _make_dh_group_cases(self):
        return '\n    '.join(map(self._make_return_case,
                                 sorted(self.dh_groups)))

    def _make_key_type_cases(self):
        return '\n    '.join(map(self._make_append_case,
                                 sorted(self.key_types)))

    @staticmethod
    def _make_key_type_from_curve_code(builder, tester):
        return KEY_TYPE_FROM_CURVE_TEMPLATE % {'builder': builder,
                                               'builder_length': len(builder),
                                               'tester': tester}

    @staticmethod
    def _make_key_type_from_group_code(builder, tester):
        return KEY_TYPE_FROM_GROUP_TEMPLATE % {'builder': builder,
                                               'builder_length': len(builder),
                                               'tester': tester}

    def _make_ecc_key_type_code(self):
        d = self.key_types_from_curve
        make = self._make_key_type_from_curve_code
        return ''.join(make(k, d[k]) for k in sorted(d))

    def _make_dh_key_type_code(self):
        d = self.key_types_from_group
        make = self._make_key_type_from_group_code
        return ''.join(make(k, d[k]) for k in sorted(d))

    def _make_hash_algorithm_cases(self):
        return '\n    '.join(map(self._make_return_case,
                                 sorted(self.hash_algorithms)))

    def _make_ka_algorithm_cases(self):
        return '\n    '.join(map(self._make_return_case,
                                 sorted(self.ka_algorithms)))

    def _make_algorithm_cases(self):
        return '\n    '.join(map(self._make_append_case,
                                 sorted(self.algorithms)))

    @staticmethod
    def _make_algorithm_from_hash_code(builder, tester):
        return ALGORITHM_FROM_HASH_TEMPLATE % {'builder': builder,
                                               'builder_length': len(builder),
                                               'tester': tester}

    def _make_algorithm_code(self):
        d = self.algorithms_from_hash
        make = self._make_algorithm_from_hash_code
        return ''.join(make(k, d[k]) for k in sorted(d))

    def _make_key_usage_code(self):
        return '\n'.join([self._make_bit_test('usage', bit)
                          for bit in sorted(self.key_usages)])

    def write_file(self, output_file):
        """Generate the pretty-printer function code from the gathered
        constant definitions and write it to the text stream output_file.
        """
        data = {}
        data['status_cases'] = self._make_status_cases()
        data['ecc_curve_cases'] = self._make_ecc_curve_cases()
        data['dh_group_cases'] = self._make_dh_group_cases()
        data['key_type_cases'] = self._make_key_type_cases()
        data['key_type_code'] = (self._make_ecc_key_type_code() +
                                 self._make_dh_key_type_code())
        data['hash_algorithm_cases'] = self._make_hash_algorithm_cases()
        data['ka_algorithm_cases'] = self._make_ka_algorithm_cases()
        data['algorithm_cases'] = self._make_algorithm_cases()
        data['algorithm_code'] = self._make_algorithm_code()
        data['key_usage_code'] = self._make_key_usage_code()
        output_file.write(OUTPUT_TEMPLATE % data)

def generate_psa_constants(header_file_names, output_file_name):
    """Parse the given headers and write the generated C file.

    The output is written to ``output_file_name + '.tmp'`` first and then
    moved into place with os.replace(), so a failed run never leaves a
    truncated output_file_name. If anything fails after the temporary file
    was created, the temporary file is removed and the error is re-raised.
    """
    collector = MacroCollector()
    for header_file_name in header_file_names:
        with open(header_file_name, 'rb') as header_file:
            collector.read_file(header_file)
    temp_file_name = output_file_name + '.tmp'
    try:
        with open(temp_file_name, 'w') as output_file:
            collector.write_file(output_file)
        os.replace(temp_file_name, output_file_name)
    except BaseException:
        # Don't leave a stale, possibly partial temporary file behind.
        if os.path.exists(temp_file_name):
            os.remove(temp_file_name)
        raise

if __name__ == '__main__':
    if not os.path.isdir('programs') and os.path.isdir('../programs'):
        os.chdir('..')
    # Optional single argument: the directory where
    # psa_constant_names_generated.c is written.
    OUTPUT_FILE_DIR = sys.argv[1] if len(sys.argv) == 2 else "programs/psa"
    generate_psa_constants(['include/psa/crypto_values.h',
                            'include/psa/crypto_extra.h'],
                           OUTPUT_FILE_DIR + '/psa_constant_names_generated.c')