#!/usr/bin/env python3
"""Generate wrapper functions for PSA function calls.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

### WARNING: the code in this file has not been extensively reviewed yet.
### We do not think it is harmful, but it may be below our normal standards
### for robustness and maintainability.

import argparse
import itertools
import os
from typing import Iterator, List, Optional, Tuple

import scripts_path #pylint: disable=unused-import
from mbedtls_dev import build_tree
from mbedtls_dev import c_parsing_helper
from mbedtls_dev import c_wrapper_generator
from mbedtls_dev import typing_util


class BufferParameter:
    """Description of an input or output buffer parameter sequence to a PSA function."""
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Initialize the parameter information.

        i is the index of the function argument that is the pointer to the buffer.
        The size is argument i+1. For a variable-size output, the actual length
        goes in argument i+2.

        is_output is true when the buffer pointer is not const-qualified,
        i.e. the buffer is treated as an output.

        buffer_name and size_name are the names of arguments i and i+1.
        This class does not yet help with the output length.
        """
        self.index = i
        self.buffer_name = buffer_name
        self.size_name = size_name
        self.is_output = is_output


class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls."""

    # Preprocessor condition around the whole generated file: emitted in the
    # prologue's #if and repeated in the epilogue's #endif comment.
    _CPP_GUARDS = ('defined(MBEDTLS_PSA_CRYPTO_C) && ' +
                   'defined(MBEDTLS_TEST_HOOKS) && \\\n    ' +
                   '!defined(RECORD_PSA_STATUS_COVERAGE_LOG)')
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Collect function declarations from include/psa/crypto.h and crypto_extra.h."""
        root_dir = build_tree.guess_mbedtls_root()
        for header_name in ['crypto.h', 'crypto_extra.h']:
            header_path = os.path.join(root_dir, 'include', 'psa', header_name)
            c_parsing_helper.read_function_declarations(self.functions, header_path)

    # Functions declared in the headers that must not get a wrapper.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Whether no wrapper should be generated for this function.

        Only functions returning psa_status_t are wrapped, except those
        listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE stuff: not implemented yet
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value."""
        if function.return_type == 'psa_status_t':
            return 'status'
        return super()._return_variable_name(function)

    # Start from the base class's per-function guards and add the PSA ones.
    # The pragma pair replaces a backslash continuation onto a comment-only
    # line, which was legal but easy to break when editing.
    #pylint: disable=protected-access
    _FUNCTION_GUARDS = c_wrapper_generator.Base._FUNCTION_GUARDS.copy()
    #pylint: enable=protected-access
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        A buffer is a `uint8_t *` or `const uint8_t *` argument immediately
        followed by a `size_t` argument. Array arguments never match.
        """
        # Array arguments (non-empty suffix) map to '' so they cannot match.
        types = ['' if arg.suffix else arg.type for arg in arguments]
        # Pair each argument's type with the next one's; the last argument
        # is paired with '' because there is no following argument.
        pairs = itertools.zip_longest(types, types[1:], fillvalue='')
        for i, (this_type, next_type) in enumerate(pairs):
            if this_type in ('const uint8_t *', 'uint8_t *') and \
               next_type == 'size_t':
                yield BufferParameter(i, not this_type.startswith('const '),
                                      argument_names[i], argument_names[i + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write poisoning or unpoisoning code for a buffer parameter.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        """
        out.write('    MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            'POISON' if poison else 'UNPOISON',
            param.buffer_name, param.size_name
        ))

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write poisoning or unpoisoning code for the buffer parameters.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        Nothing is written when there are no buffer parameters.
        """
        if not buffer_parameters:
            return
        out.write('#if !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS)\n')
        for param in buffer_parameters:
            self._write_poison_buffer_parameter(out, param, poison)
        out.write('#endif /* !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS) */\n')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.
        """
        # False-positives that do not need buffer copying
        if function_name in ('mbedtls_psa_inject_entropy',
                             'psa_crypto_driver_pake_get_password',
                             'psa_crypto_driver_pake_get_user',
                             'psa_crypto_driver_pake_get_peer'):
            return False

        return True

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the wrapped function.

        The call is surrounded by poisoning/unpoisoning of its buffer
        parameters.
        """
        buffer_parameters = [
            param
            for param in self._detect_buffer_parameters(function.arguments,
                                                        argument_names)
            if self._parameter_should_be_copied(function.name,
                                                function.arguments[param.index].name)
        ]
        self._write_poison_buffer_parameters(out, buffer_parameters, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, buffer_parameters, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base prologue, then open the file-wide #if and add includes."""
        super()._write_prologue(out, header)
        out.write("""
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
                  .format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the file-wide #if opened in the prologue, then write the base epilogue."""
        out.write("""
#endif /* {} */
"""
                  .format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)


class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        super().__init__()
        self.set_stream(stream)

    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return (printf format fragment, printf arguments) for logging var.

        A leading `const ` is ignored. Buffers, operation objects and PAKE
        structures are not logged: they get an empty format and no arguments.
        """
        if typ.startswith('const '):
            typ = typ[len('const '):]
        if typ == 'uint8_t *':
            # Skip buffers
            return '', []
        if typ.endswith('operation_t *'):
            return '', []
        if typ in self._PAKE_STUFF:
            return '', []
        if typ == 'psa_key_attributes_t *':
            # NOTE(review): alg and usage are printed without the 0x prefix
            # used for lifetime and type; kept as-is so the generated log
            # format does not change.
            return (var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=%08x, usage=%08x}',
                    ['(unsigned) psa_get_key_{}({})'.format(field, var)
                     for field in ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']])
        return super()._printf_parameters(typ, var)


DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'

def main() -> None:
    """Parse the command line, gather PSA declarations and write the outputs."""
    # Inside a function body, __doc__ resolves to the module docstring.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--log',
                        help='Stream to log to (default: no logging code)')
    parser.add_argument('--output-c',
                        metavar='FILENAME',
                        default=DEFAULT_C_OUTPUT_FILE_NAME,
                        help=('Output .c file path (default: {}; skip .c output if empty)'
                              .format(DEFAULT_C_OUTPUT_FILE_NAME)))
    parser.add_argument('--output-h',
                        metavar='FILENAME',
                        default=DEFAULT_H_OUTPUT_FILE_NAME,
                        help=('Output .h file path (default: {}; skip .h output if empty)'
                              .format(DEFAULT_H_OUTPUT_FILE_NAME)))
    options = parser.parse_args()
    if options.log:
        generator = PSALoggingWrapperGenerator(options.log) #type: PSAWrapperGenerator
    else:
        generator = PSAWrapperGenerator()
    generator.gather_data()
    # An empty file name disables the corresponding output.
    if options.output_h:
        generator.write_h_file(options.output_h)
    if options.output_c:
        generator.write_c_file(options.output_c)

if __name__ == '__main__':
    main()