1#!/usr/bin/env python3
2"""Generate wrapper functions for PSA function calls.
3"""
4
5# Copyright The Mbed TLS Contributors
6# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7
8### WARNING: the code in this file has not been extensively reviewed yet.
9### We do not think it is harmful, but it may be below our normal standards
10### for robustness and maintainability.
11
12import argparse
13import itertools
14import os
15from typing import Iterator, List, Optional, Tuple
16
17import scripts_path #pylint: disable=unused-import
18from mbedtls_dev import build_tree
19from mbedtls_dev import c_parsing_helper
20from mbedtls_dev import c_wrapper_generator
21from mbedtls_dev import typing_util
22
23
class BufferParameter:
    """A buffer argument of a PSA function: a pointer followed by its size.

    For a variable-size output, the function also takes an output-length
    argument right after the size; that third argument is not described by
    this class.
    """
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Record where a buffer argument pair sits and what it is called.

        i: position of the pointer argument in the argument list. The size
           argument is at position i+1 (and, for a variable-size output, the
           actual length is at position i+2).
        is_output: true for a buffer that the function writes to.
        buffer_name: name of argument i (the pointer).
        size_name: name of argument i+1 (the size).
        """
        self.is_output = is_output
        self.index = i
        self.size_name = size_name
        self.buffer_name = buffer_name
43
44
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls."""

    # Preprocessor condition wrapped around the whole generated file.
    # The '\\\n' sequence produces a backslash-newline, i.e. a C line
    # continuation, in the generated text.
    _CPP_GUARDS = ('defined(MBEDTLS_PSA_CRYPTO_C) && ' +
                   'defined(MBEDTLS_TEST_HOOKS) && \\\n    ' +
                   '!defined(RECORD_PSA_STATUS_COVERAGE_LOG)')
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Collect the function declarations from the public PSA headers.

        The declarations found in include/psa/crypto.h and
        include/psa/crypto_extra.h under the Mbed TLS root directory are
        added to self.functions.
        """
        root_dir = build_tree.guess_mbedtls_root()
        for header_name in ['crypto.h', 'crypto_extra.h']:
            header_path = os.path.join(root_dir, 'include', 'psa', header_name)
            c_parsing_helper.read_function_declarations(self.functions, header_path)

    # Functions returning psa_status_t that must nevertheless not be wrapped.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Return true if no wrapper should be generated for function.

        Only functions returning psa_status_t are wrapped, except those
        listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE-related argument types: support for them is not implemented yet.
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value."""
        if function.return_type == 'psa_status_t':
            return 'status'
        return super()._return_variable_name(function)

    # Extra preprocessor conditions for functions that only exist in some
    # configurations. Reading the base class's table requires access to a
    # protected member.
    #pylint: disable=protected-access
    _FUNCTION_GUARDS = c_wrapper_generator.Base._FUNCTION_GUARDS.copy()
    #pylint: enable=protected-access
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        A buffer is an argument of type 'const uint8_t *' (input) or
        'uint8_t *' (output) immediately followed by an argument of type
        'size_t'. Array-typed arguments never match.
        """
        # Array-typed arguments (those with a suffix) are mapped to '' so
        # that they can never match a pointer or size type.
        types = ['' if arg.suffix else arg.type for arg in arguments]
        # Pair each argument's type with the next argument's type; the last
        # argument is paired with '' because it has no successor.
        pairs = itertools.zip_longest(types, types[1:], fillvalue='')
        for i, (pointer_type, size_type) in enumerate(pairs):
            if pointer_type in ('const uint8_t *', 'uint8_t *') and \
               size_type == 'size_t':
                yield BufferParameter(i, not pointer_type.startswith('const '),
                                      argument_names[i], argument_names[i + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write poisoning or unpoisoning code for a buffer parameter.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        """
        out.write('    MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            'POISON' if poison else 'UNPOISON',
            param.buffer_name, param.size_name
        ))

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write poisoning or unpoisoning code for the buffer parameters.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        Nothing is written if there are no buffer parameters. The generated
        code is disabled when MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS is defined.
        """
        if not buffer_parameters:
            return
        out.write('#if !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS)\n')
        for param in buffer_parameters:
            self._write_poison_buffer_parameter(out, param, poison)
        out.write('#endif /* !defined(MBEDTLS_PSA_ASSUME_EXCLUSIVE_BUFFERS) */\n')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.

        The decision currently depends only on the function name; the buffer
        name is accepted for future per-argument exceptions.
        """
        # False-positives that do not need buffer copying
        if function_name in ('mbedtls_psa_inject_entropy',
                             'psa_crypto_driver_pake_get_password',
                             'psa_crypto_driver_pake_get_user',
                             'psa_crypto_driver_pake_get_peer'):
            return False

        return True

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the wrapped function.

        The buffer parameters that should be copied are poisoned before the
        call and unpoisoned after it.
        """
        buffer_parameters = [
            param
            for param in self._detect_buffer_parameters(function.arguments,
                                                        argument_names)
            if self._parameter_should_be_copied(function.name,
                                                function.arguments[param.index].name)
        ]
        self._write_poison_buffer_parameters(out, buffer_parameters, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, buffer_parameters, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base prologue, then open _CPP_GUARDS and include headers.

        The same text is written for both the .c and the .h output.
        """
        super()._write_prologue(out, header)
        out.write("""
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
                  .format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the _CPP_GUARDS conditional, then write the base epilogue."""
        out.write("""
#endif /* {} */
"""
                  .format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
186
187
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Create a generator whose wrappers log to the given stream.

        stream: the stream to log to, passed to set_stream().
        """
        super().__init__()
        self.set_stream(stream)

    # Extend the base class's table with PSA-specific types, each mapped to
    # the type used when printing it (see c_wrapper_generator.Logging).
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format and arguments to log argument var of type typ.

        The result is a pair (format fragment, list of C argument expressions).
        A leading 'const ' is ignored. Buffers, types ending in
        'operation_t *' and PAKE-related types are not logged: they yield an
        empty format and no arguments. Key attributes are logged field by
        field. All other types are handled by the base class.
        """
        const_prefix = 'const '
        if typ.startswith(const_prefix):
            typ = typ[len(const_prefix):]
        if typ == 'uint8_t *':
            # Skip buffers
            return '', []
        if typ.endswith('operation_t *'):
            return '', []
        if typ in self._PAKE_STUFF:
            return '', []
        if typ == 'psa_key_attributes_t *':
            # Key id and bit-size are printed in decimal; every other field
            # is printed in hexadecimal with an explicit 0x prefix.
            # NOTE(review): the (unsigned) cast assumes every getter returns
            # an integer type (including psa_get_key_id) in all configurations
            # where logging wrappers are built -- confirm.
            return (var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=0x%08x, usage=0x%08x}',
                    ['(unsigned) psa_get_key_{}({})'.format(field, var)
                     for field in ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']])
        return super()._printf_parameters(typ, var)
226
227
# Default output paths for the generated C source and header.
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'

def main() -> None:
    """Command line entry point: generate the PSA wrapper .h and/or .c file.

    With --log, the generated wrappers also log each call; an empty
    --output-h or --output-c value skips the corresponding output.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--log',
                        help='Stream to log to (default: no logging code)')
    c_help = ('Output .c file path (default: {}; skip .c output if empty)'
              .format(DEFAULT_C_OUTPUT_FILE_NAME))
    parser.add_argument('--output-c', metavar='FILENAME',
                        default=DEFAULT_C_OUTPUT_FILE_NAME, help=c_help)
    h_help = ('Output .h file path (default: {}; skip .h output if empty)'
              .format(DEFAULT_H_OUTPUT_FILE_NAME))
    parser.add_argument('--output-h', metavar='FILENAME',
                        default=DEFAULT_H_OUTPUT_FILE_NAME, help=h_help)
    args = parser.parse_args()

    if not args.log:
        generator = PSAWrapperGenerator() #type: PSAWrapperGenerator
    else:
        generator = PSALoggingWrapperGenerator(args.log)
    generator.gather_data()

    # The header is written before the C file; empty paths are skipped.
    for path, write in ((args.output_h, generator.write_h_file),
                        (args.output_c, generator.write_c_file)):
        if path:
            write(path)

if __name__ == '__main__':
    main()
258