1"""Generate C wrapper functions.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6
7### WARNING: the code in this file has not been extensively reviewed yet.
8### We do not think it is harmful, but it may be below our normal standards
9### for robustness and maintainability.
10
11import os
12import re
13import sys
14import typing
15from typing import Dict, List, Optional, Tuple
16
17from .c_parsing_helper import ArgumentInfo, FunctionInfo
18from . import typing_util
19
20
def c_declare(prefix: str, name: str, suffix: str) -> str:
    """Build the C declaration of `name` from its type prefix and suffix.

    `prefix` is the part of the type that goes before the name (e.g. `int`
    or `char *`) and `suffix` is the part that goes after it (e.g. `[16]`).
    A space separates the prefix from the name unless the prefix ends
    with `*`.
    """
    separator = '' if prefix.endswith('*') else ' '
    return prefix + separator + name + suffix
26
27
# Information about the wrapper generated for one function.
WrapperInfo = typing.NamedTuple('WrapperInfo', [
    # Names of the wrapper's arguments, one per argument of the wrapped
    # function, in the same order.
    ('argument_names', List[str]),
    # Preprocessor condition guarding the wrapper, or None for no guard.
    ('guard', Optional[str]),
    # Name of the wrapper function.
    ('wrapper_name', str),
])
33
34
class Base:
    """Generate a C source file containing wrapper functions.

    `self.functions` is meant to be populated by a derived class. Once that
    is done, `write_c_file` outputs the wrapper definitions and
    `write_h_file` outputs their declarations together with function-like
    macros that redirect calls to the wrappers.

    Most underscore-prefixed methods are hooks that a derived class may
    override to customize the generated code.
    """

    # This class is designed to have many methods potentially overloaded.
    # Tell pylint not to complain about methods that have unused arguments:
    # child classes are likely to override those methods and need the
    # arguments in question.
    #pylint: disable=no-self-use,unused-argument

    # Prefix prepended to the function's name to form the wrapper name.
    _WRAPPER_NAME_PREFIX = ''
    # Suffix appended to the function's name to form the wrapper name.
    _WRAPPER_NAME_SUFFIX = '_wrap'

    # Functions with one of these qualifiers are skipped.
    _SKIP_FUNCTION_WITH_QUALIFIERS = frozenset(['inline', 'static'])

    def __init__(self):
        """Construct a wrapper generator object.

        The collection of functions to wrap starts out empty.
        """
        # Name of the running program, quoted in the "automatically
        # generated" comment at the top of the output.
        self.program_name = os.path.basename(sys.argv[0])
        # Functions to wrap. Wrappers are output in sorted key order.
        # To be populated in a derived class
        self.functions = {} #type: Dict[str, FunctionInfo]
        # Preprocessor symbol used as a guard against multiple inclusion in the
        # header. Must be set before writing output to a header.
        # Not used when writing .c output.
        self.header_guard = None #type: Optional[str]

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the prologue of a C file.

        This includes a description comment and some include directives.
        If header is true, this also opens the multiple-inclusion guard
        (using `self.header_guard`) and the `extern "C"` block.
        """
        out.write("""/* Automatically generated by {}, do not edit! */

/* Copyright The Mbed TLS Contributors
 * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
 */
"""
                  .format(self.program_name))
        if header:
            out.write("""
#ifndef {guard}
#define {guard}

#ifdef __cplusplus
extern "C" {{
#endif
"""
                      .format(guard=self.header_guard))
        out.write("""
#include <mbedtls/build_info.h>
""")

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the epilogue of a C file.

        If header is true, this closes the `extern "C"` block and the
        multiple-inclusion guard opened by the prologue.
        """
        if header:
            out.write("""
#ifdef __cplusplus
}}
#endif

#endif /* {guard} */
"""
                      .format(guard=self.header_guard))
        out.write("""
/* End of automatically generated file. */
""")

    def _wrapper_function_name(self, original_name: str) -> str:
        """The name of the wrapper function.

        By default, this surrounds the original name with
        `_WRAPPER_NAME_PREFIX` (empty by default) and `_WRAPPER_NAME_SUFFIX`.
        """
        return (self._WRAPPER_NAME_PREFIX +
                original_name +
                self._WRAPPER_NAME_SUFFIX)

    def _wrapper_declaration_start(self,
                                   function: FunctionInfo,
                                   wrapper_name: str) -> str:
        """The beginning of the wrapper function declaration.

        This ends just before the opening parenthesis of the argument list.

        This is a string containing at least the return type and the
        function name. It may start with additional qualifiers or attributes
        such as `static`, `__attribute__((...))`, etc.
        """
        return c_declare(function.return_type, wrapper_name, '')

    def _argument_name(self,
                       function_name: str,
                       num: int,
                       arg: ArgumentInfo) -> str:
        """Name to use for the given argument in the wrapper function.

        Argument numbers count from 0.

        The name is `arg<num>`, followed by an underscore and the original
        argument name if there is one.
        """
        name = 'arg' + str(num)
        if arg.name:
            name += '_' + arg.name
        return name

    def _wrapper_declaration_argument(self,
                                      function_name: str,
                                      num: int, name: str,
                                      arg: ArgumentInfo) -> str:
        """One argument definition in the wrapper function declaration.

        Argument numbers count from 0. `name` is the name that the argument
        has in the wrapper.
        """
        return c_declare(arg.type, name, arg.suffix)

    def _underlying_function_name(self, function: FunctionInfo) -> str:
        """The name of the underlying function.

        By default, this is the name of the wrapped function.
        """
        return function.name

    def _return_variable_name(self, function: FunctionInfo) -> str:
        """The name of the variable that will contain the return value."""
        return 'retval'

    def _write_function_call(self, out: typing_util.Writable,
                             function: FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the underlying function.

        Unless the function returns void, the call statement also declares
        the return variable and initializes it with the call's result.
        """
        # Note that the function name is in parentheses, to avoid calling
        # a function-like macro with the same name, since in typical usage
        # there is a function-like macro with the same name which is the
        # wrapper.
        call = '({})({})'.format(self._underlying_function_name(function),
                                 ', '.join(argument_names))
        if function.returns_void():
            out.write('    {};\n'.format(call))
        else:
            ret_name = self._return_variable_name(function)
            ret_decl = c_declare(function.return_type, ret_name, '')
            out.write('    {} = {};\n'.format(ret_decl, call))

    def _write_function_return(self, out: typing_util.Writable,
                               function: FunctionInfo,
                               if_void: bool = False) -> None:
        """Write a return statement.

        If the function returns void, only write a statement if if_void is true.
        Otherwise return the content of the return variable.
        """
        if function.returns_void():
            if if_void:
                out.write('    return;\n')
        else:
            ret_name = self._return_variable_name(function)
            out.write('    return {};\n'.format(ret_name))

    def _write_function_body(self, out: typing_util.Writable,
                             function: FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the body of the wrapper code for the specified function.

        By default, this is a call to the underlying function followed
        by a return statement if the function returns a value.
        """
        self._write_function_call(out, function, argument_names)
        self._write_function_return(out, function)

    def _skip_function(self, function: FunctionInfo) -> bool:
        """Whether to skip this function.

        By default, static or inline functions are skipped.
        """
        # Skip if the function has at least one of the listed qualifiers.
        return not self._SKIP_FUNCTION_WITH_QUALIFIERS.isdisjoint(
            function.qualifiers)

    # Map from function name to the preprocessor condition that guards
    # the function's wrapper. Functions not listed here are unguarded.
    _FUNCTION_GUARDS = {} #type: Dict[str, str]

    def _function_guard(self, function: FunctionInfo) -> Optional[str]:
        """A preprocessor condition for this function.

        The wrapper will be guarded with `#if` on this condition, if not None.
        By default, the condition is looked up in `_FUNCTION_GUARDS`.
        """
        return self._FUNCTION_GUARDS.get(function.name)

    def _wrapper_info(self, function: FunctionInfo) -> Optional[WrapperInfo]:
        """Information about the wrapper for one function.

        Return None if the function should be skipped.
        """
        if self._skip_function(function):
            return None
        argument_names = [self._argument_name(function.name, num, arg)
                          for num, arg in enumerate(function.arguments)]
        return WrapperInfo(
            argument_names=argument_names,
            guard=self._function_guard(function),
            wrapper_name=self._wrapper_function_name(function.name),
        )

    def _write_function_prototype(self, out: typing_util.Writable,
                                  function: FunctionInfo,
                                  wrapper: WrapperInfo,
                                  header: bool) -> None:
        """Write the prototype of a wrapper function.

        If header is true, write a function declaration, with a semicolon at
        the end. Otherwise just write the prototype, intended to be followed
        by the function's body.

        A function without arguments is declared with `(void)`. Otherwise
        each argument is written on a line of its own.
        """
        declaration_start = self._wrapper_declaration_start(function,
                                                            wrapper.wrapper_name)
        terminator = ';\n' if header else '\n'
        if not function.arguments:
            out.write(declaration_start + '(void)' + terminator)
            return
        arg_indent = '    '
        last = len(function.arguments) - 1
        out.write(declaration_start + '(\n')
        for num, arg in enumerate(function.arguments):
            arg_def = self._wrapper_declaration_argument(
                function.name,
                num, wrapper.argument_names[num], arg)
            # The last argument closes the parameter list; the others are
            # followed by a comma.
            arg_terminator = ')' + terminator if num == last else ',\n'
            out.write(arg_indent + arg_def + arg_terminator)

    def _write_c_function(self, out: typing_util.Writable,
                          function: FunctionInfo) -> None:
        """Write wrapper code for one function.

        This is a comment naming the wrapped function, then the definition
        of the wrapper, enclosed in `#if`...`#endif` if the wrapper has
        a guard.

        Do nothing if the function is skipped.
        """
        wrapper = self._wrapper_info(function)
        if wrapper is None:
            return
        out.write('\n/* Wrapper for {} */\n'.format(function.name))
        if wrapper.guard is not None:
            out.write('#if {}\n'.format(wrapper.guard))
        self._write_function_prototype(out, function, wrapper, False)
        out.write('{\n')
        self._write_function_body(out, function, wrapper.argument_names)
        out.write('}\n')
        if wrapper.guard is not None:
            out.write('#endif /* {} */\n'.format(wrapper.guard))

    def _write_h_function_declaration(self, out: typing_util.Writable,
                                      function: FunctionInfo,
                                      wrapper: WrapperInfo) -> None:
        """Write the declaration of one wrapper function.
        """
        self._write_function_prototype(out, function, wrapper, True)

    def _write_h_macro_definition(self, out: typing_util.Writable,
                                  function: FunctionInfo,
                                  wrapper: WrapperInfo) -> None:
        """Write the macro definition for one wrapper.

        The macro has the name of the wrapped function and expands to
        a call to the wrapper with the same arguments.
        """
        arg_list = ', '.join(wrapper.argument_names)
        out.write('#define {function_name}({args}) \\\n    {wrapper_name}({args})\n'
                  .format(function_name=function.name,
                          wrapper_name=wrapper.wrapper_name,
                          args=arg_list))

    def _write_h_function(self, out: typing_util.Writable,
                          function: FunctionInfo) -> None:
        """Write the complete header content for one wrapper.

        This is the declaration of the wrapper function, and the
        definition of a function-like macro that calls the wrapper function.
        Both are enclosed in `#if`...`#endif` if the wrapper has a guard.

        Do nothing if the function is skipped.
        """
        wrapper = self._wrapper_info(function)
        if wrapper is None:
            return
        out.write('\n')
        if wrapper.guard is not None:
            out.write('#if {}\n'.format(wrapper.guard))
        self._write_h_function_declaration(out, function, wrapper)
        self._write_h_macro_definition(out, function, wrapper)
        if wrapper.guard is not None:
            out.write('#endif /* {} */\n'.format(wrapper.guard))

    def write_c_file(self, filename: str) -> None:
        """Output a whole C file containing function wrapper definitions.

        The wrappers are written in sorted order of the keys of
        `self.functions`.
        """
        with open(filename, 'w', encoding='utf-8') as out:
            self._write_prologue(out, False)
            for name in sorted(self.functions):
                self._write_c_function(out, self.functions[name])
            self._write_epilogue(out, False)

    def _header_guard_from_file_name(self, filename: str) -> str:
        """Preprocessor symbol used as a guard against multiple inclusion.

        The symbol is the file name with everything up to the last
        `include/` (or `include\\`) directory removed, every character
        other than an ASCII letter or digit replaced by an underscore,
        and letters converted to upper case.
        """
        # Heuristic to strip irrelevant leading directories
        filename = re.sub(r'.*include[\\/]', r'', filename)
        # Pass the flag by keyword: the fourth positional parameter of
        # re.sub is the maximum number of replacements, not the flags.
        return re.sub(r'[^0-9A-Za-z]', r'_', filename, flags=re.A).upper()

    def write_h_file(self, filename: str) -> None:
        """Output a header file with function wrapper declarations and macro definitions.

        This sets `self.header_guard` based on the file name.
        """
        self.header_guard = self._header_guard_from_file_name(filename)
        with open(filename, 'w', encoding='utf-8') as out:
            self._write_prologue(out, True)
            for name in sorted(self.functions):
                self._write_h_function(out, self.functions[name])
            self._write_epilogue(out, True)
344
345
class UnknownTypeForPrintf(Exception):
    """Error signaling that no printf format is known for a C type.

    This is raised when attempting to generate code that logs a value
    of such a type.
    """

    def __init__(self, typ: str) -> None:
        """Build the exception for the C type named `typ`."""
        message = "Unknown type for printf format generation: " + typ
        super().__init__(message)
351
352
class Logging(Base):
    """Wrapper generator whose wrappers log each call's inputs and outputs."""

    def __init__(self) -> None:
        """Set up a wrapper generator that logs inputs and outputs.

        The generated code logs to stdout unless `set_stream` is called.
        """
        super().__init__()
        self.stream = 'stdout'

    def set_stream(self, stream: str) -> None:
        """Select the stdio stream that the generated code logs to.

        Call this method before generating the output with `write_c_file`
        or `write_h_file`.
        """
        self.stream = stream

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the prologue of a C file.

        In addition to the base prologue, a .c file includes the headers
        needed by the logging code.
        """
        super()._write_prologue(out, header)
        if header:
            return
        out.write("""
#if defined(MBEDTLS_FS_IO) && defined(MBEDTLS_TEST_HOOKS)
#include <stdio.h>
#include <inttypes.h>
#include <mbedtls/debug.h> // for MBEDTLS_PRINTF_SIZET
#include <mbedtls/platform.h> // for mbedtls_fprintf
#endif /* defined(MBEDTLS_FS_IO) && defined(MBEDTLS_TEST_HOOKS) */
""")

    # printf conversion specification for each directly printable C type.
    _PRINTF_SIMPLE_FORMAT = {
        'int': '%d',
        'long': '%ld',
        'long long': '%lld',
        'size_t': '%"MBEDTLS_PRINTF_SIZET"',
        'unsigned': '0x%08x',
        'unsigned int': '0x%08x',
        'unsigned long': '0x%08lx',
        'unsigned long long': '0x%016llx',
    }

    def _printf_simple_format(self, typ: str) -> Optional[str]:
        """The printf format for a value of typ.

        The result is None for a type that needs more complex handling.
        """
        return self._PRINTF_SIMPLE_FORMAT.get(typ)

    # C types that are converted to another type before printing.
    _PRINTF_TYPE_CAST = {
        'int32_t': 'int',
        'uint32_t': 'unsigned',
        'uint64_t': 'unsigned long long',
    } #type: Dict[str, str]

    def _printf_type_cast(self, typ: str) -> Optional[str]:
        """The type to cast values of typ to before passing them to printf.

        The result is None for a type that does not need a cast.
        """
        return self._PRINTF_TYPE_CAST.get(typ)

    # Matches a trailing `*` in a type, with any whitespace before it.
    _POINTER_TYPE_RE = re.compile(r'\s*\*\Z')

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """The printf format and arguments for a value of type typ stored in var.

        The format has the form `var=...`. Raise UnknownTypeForPrintf if
        there is no known way to print a value of this type.
        """
        value_type = typ
        value_expr = var
        # For outputs via a pointer, get the value that has been written.
        # Note: we don't support pointers to pointers here.
        star = self._POINTER_TYPE_RE.search(value_type)
        if star:
            value_type = value_type[:star.start(0)]
            value_expr = '*({})'.format(value_expr)
        # Maybe cast the value to a standard type.
        printf_type = self._printf_type_cast(value_type)
        if printf_type is not None:
            value_expr = '({}) {}'.format(printf_type, value_expr)
            value_type = printf_type
        # Try standard types.
        fmt = self._printf_simple_format(value_type)
        if fmt is None:
            raise UnknownTypeForPrintf(typ)
        return '{}={}'.format(var, fmt), [value_expr]

    def _write_function_logging(self, out: typing_util.Writable,
                                function: FunctionInfo,
                                argument_names: List[str]) -> None:
        """Write code to log the function's inputs and outputs.

        The log line contains the function's name, then each argument,
        then the return value unless the function returns void.
        """
        format_parts = ['%s']
        values = ['"' + function.name + '"']

        def log_value(typ: str, var: str) -> None:
            """Add the variable var of type typ to the logged values."""
            fmt, vals = self._printf_parameters(typ, var)
            if fmt:
                format_parts.append(fmt)
                values.extend(vals)

        for arg_info, arg_name in zip(function.arguments, argument_names):
            log_value(arg_info.type, arg_name)
        if not function.returns_void():
            log_value(function.return_type,
                      self._return_variable_name(function))
        out.write("""\
#if defined(MBEDTLS_FS_IO) && defined(MBEDTLS_TEST_HOOKS)
    if ({stream}) {{
        mbedtls_fprintf({stream}, "{formats}\\n",
                        {values});
    }}
#endif /* defined(MBEDTLS_FS_IO) && defined(MBEDTLS_TEST_HOOKS) */
"""
                  .format(stream=self.stream,
                          formats=' '.join(format_parts),
                          values=', '.join(values)))

    def _write_function_body(self, out: typing_util.Writable,
                             function: FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the body of the wrapper code for the specified function.

        The logging code goes between the call to the underlying function
        and the return statement.
        """
        self._write_function_call(out, function, argument_names)
        self._write_function_logging(out, function, argument_names)
        self._write_function_return(out, function)
474