1#!/usr/bin/env python3
2
3"""Generate library/ssl_debug_helpers_generated.c
4
The code generated by this module includes debug helper functions that cannot
be implemented as fixed (hand-written) code.
7
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
12import sys
13import re
14import os
15import textwrap
16import argparse
17
18import framework_scripts_path # pylint: disable=unused-import
19from mbedtls_framework import build_tree
20
21
def remove_c_comments(string):
    """
        Remove C style comments from input string

        Both block comments (slash-star ... star-slash) and line comments
        (double slash up to the end of the line) are replaced by the empty
        string. String and character literals are matched first and copied
        through unchanged, so comment markers inside a literal are kept.

        A backslash-escaped character inside a literal (for example an
        escaped quote) does not terminate the literal, and a literal never
        extends over an unescaped line break.

        Return the input text with the comments removed.
    """
    # A literal is a quote, then any mix of escape sequences (backslash plus
    # one character) and ordinary characters, then the matching quote.
    string_pattern = (r'(?P<string>"(?:\\.|[^"\\\n])*"' +
                      r"|'(?:\\.|[^'\\\n])*')")
    comment_pattern = r"(?P<comment>/\*.*?\*/|//[^\r\n]*$)"
    # DOTALL lets a block comment span several lines; MULTILINE lets '$'
    # end a line comment at every line break.
    pattern = re.compile(string_pattern + r'|' + comment_pattern,
                         re.MULTILINE | re.DOTALL)

    def replacer(match):
        # Only comments are dropped; literals are returned untouched.
        if match.lastgroup == 'comment':
            return ""
        return match.group()

    return pattern.sub(replacer, string)
36
37
class CondDirectiveNotMatch(Exception):
    """Raised when a conditional directive has no matching partner."""
40
41
def preprocess_c_source_code(source, *classes):
    """
        Simple preprocessor for C source code.

        Only the conditional directives #if, #ifdef, #ifndef, #else and
        #endif are tracked; their conditions are never evaluated or expanded.
        Each time a conditional block is closed (by #else or #endif), every
        class in ``classes`` is asked, in the order given, to extract its
        objects from the text between the opening and the closing directive.
        Each class must provide ``extract(source, start, end)`` returning an
        iterable of objects that have a ``span()`` method.

        Yield ``(position, item)`` pairs, where position is an offset into
        ``source``. For every block in which at least one object is found:

        - ``(offset just past the opening directive, opening line)``
        - ``(object start offset, object)`` for each extracted object
        - ``(offset of the closing directive, closing "#endif" line)``

        The opening and closing lines are empty strings for a block that is
        not nested inside another conditional. The text of an #else branch
        is reported under the negation of the condition that opened it.
        Objects located in a nested block are yielded again for every
        enclosing block, so consumers may see the same object offset twice.

        The opening line is keyed by the end of its directive (not its
        start) so that the opening line of an #else branch never shares an
        offset with the closing line of the branch before it.

        If the directive pair does not match (an #else/#endif without an
        open conditional, or a conditional that is never closed), raise
        CondDirectiveNotMatch.

        Assume source code does not include comments and compile pass.

    """

    pattern = re.compile(r"^[ \t]*#[ \t]*" +
                         r"(?P<directive>(if[ \t]|ifndef[ \t]|ifdef[ \t]|else|endif))" +
                         r"[ \t]*(?P<param>(.*\\\n)*.*$)",
                         re.MULTILINE)
    # Conditionals that are currently open, innermost last. Each entry is
    # (directive, param, start offset, end offset) of the opening directive.
    stack = []

    def _yield_objects(d, p, body_start, body_end):
        """
            Output matched source piece

            ``d`` and ``p`` describe the directive that opened the block;
            ``body_start``/``body_end`` delimit the text between the opening
            and the closing directive.
        """
        start_line, end_line = '', ''
        # The block was already popped, so a non-empty stack means the block
        # is nested in another conditional and must keep its own guard.
        if stack:
            start_line = '#{} {}'.format(d, p)
            if d == 'if':
                end_line = '#endif /* {} */'.format(p)
            elif d == 'ifdef':
                end_line = '#endif /* defined({}) */'.format(p)
            else:
                end_line = '#endif /* !defined({}) */'.format(p)
        has_instance = False
        for cls in classes:
            for instance in cls.extract(source, body_start, body_end):
                if not has_instance:
                    # Emit the opening line once, before the first object.
                    has_instance = True
                    yield body_start, start_line
                yield instance.span()[0], instance
        if has_instance:
            yield body_end, end_line

    for match in pattern.finditer(source):

        directive = match.group('directive').strip()
        param = match.group('param')
        start, end = match.span()

        if directive in ('if', 'ifndef', 'ifdef'):
            stack.append((directive, param, start, end))
            continue

        # '#else' or '#endif' needs an open conditional to close.
        if not stack:
            raise CondDirectiveNotMatch()

        pair_directive, pair_param, _, pair_end = stack.pop()
        yield from _yield_objects(pair_directive, pair_param, pair_end, start)

        if directive == 'endif':
            continue

        # '#else': reopen the block under the negated condition.
        if pair_directive == 'if':
            directive = 'if'
            param = "!( {} )".format(pair_param)
        elif pair_directive == 'ifdef':
            directive = 'ifndef'
            param = pair_param
        else:
            directive = 'ifdef'
            param = pair_param

        stack.append((directive, param, start, end))

    # A conditional left open at the end of the source is a mismatch too.
    if stack:
        raise CondDirectiveNotMatch()
120
121
class EnumDefinition:
    """
        Generate helper functions around enumeration.

        Currently, it generates a translation function from enum value to
        string. Enum definition looks like:
        [typedef] enum [prefix name] { [body] } [suffix name];

        Known limitation:
        - the '}' and ';' SHOULD NOT exist in different macro blocks. Like
        ```
        enum test {
            ....
        #if defined(A)
            ....
        };
        #else
            ....
        };
        #endif
        ```
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield an EnumDefinition for every enum definition found in
            source_code between the offsets start and end.

            A negative (or None) ``end``, which is the default, means
            "search up to the end of source_code".
        """
        if end is None or end < 0:
            # re treats a negative endpos as 0, which would match nothing.
            end = len(source_code)

        enum_pattern = re.compile(r'enum\s*(?P<prefix_name>\w*)\s*' +
                                  r'{\s*(?P<body>[^}]*)}' +
                                  r'\s*(?P<suffix_name>\w*)\s*;',
                                  re.MULTILINE | re.DOTALL)

        for match in enum_pattern.finditer(source_code, start, end):
            yield EnumDefinition(source_code,
                                 span=match.span(),
                                 group=match.groupdict())

    def __init__(self, source_code, span=None, group=None):
        """
            Record one matched enum definition.

            ``span`` is the (start, end) offset pair of the match in
            source_code; ``group`` is the dict of named regex groups
            ('prefix_name', 'body', 'suffix_name'). The assertions reject a
            match that has no name at all or an empty body.
        """
        assert isinstance(group, dict)
        prefix_name = group.get('prefix_name', None)
        suffix_name = group.get('suffix_name', None)
        body = group.get('body', None)
        assert prefix_name or suffix_name
        assert body
        assert span
        # If suffix_name exists, it is a typedef
        self._prototype = suffix_name if suffix_name else 'enum ' + prefix_name
        self._name = suffix_name if suffix_name else prefix_name
        self._body = body
        self._source = source_code
        self._span = span

    def __repr__(self):
        return 'Enum({},{})'.format(self._name, self._span)

    def __str__(self):
        return repr(self)

    def span(self):
        """Return the (start, end) offsets of the definition in the source."""
        return self._span

    def generate_translation_function(self):
        """
            Generate function for translating value to string

            Return the C source of ``const char *<name>_str( <type> in )``,
            a switch with one case per enumerator that returns the
            enumerator's name. Preprocessor lines found inside the enum body
            are copied into the switch so conditional enumerators stay
            conditional.
        """
        translation_table = []

        for line in self._body.splitlines():
            stripped = line.strip()

            if stripped.startswith('#'):
                # Preprocess directive, keep it in table
                translation_table.append(stripped)
                continue

            if not stripped:
                continue

            for field in stripped.split(','):
                # Drop any '= value' initializer first, so that 'A=0' and
                # 'A = 0' both give 'A'; then keep the first token.
                tokens = field.split('=', 1)[0].split()
                if not tokens:
                    continue
                member = tokens[0]
                translation_table.append(
                    '{space}case {member}:\n{space}    return "{member}";'
                    .format(member=member, space=' '*8)
                )

        body = textwrap.dedent('''\
            const char *{name}_str( {prototype} in )
            {{
                switch (in) {{
            {translation_table}
                    default:
                        return "UNKNOWN_VALUE";
                }}
            }}
                    ''')
        body = body.format(translation_table='\n'.join(translation_table),
                           name=self._name,
                           prototype=self._prototype)
        return body
220
221
class SignatureAlgorithmDefinition:
    """
        Generate helper functions for signature algorithms.

        It generates translation function from signature algorithm define to string.
        Signature algorithm definition looks like:
        #define MBEDTLS_TLS1_3_SIG_[ upper case signature algorithm ] [ value(hex) ]

        Known limitation:
        - the definitions SHOULD exist in same macro blocks.
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield a single SignatureAlgorithmDefinition that groups every
            matching #define found in source_code between the offsets start
            and end; yield nothing when there is no match.

            A negative (or None) ``end``, which is the default, means
            "search up to the end of source_code".
        """
        if end is None or end < 0:
            # re treats a negative endpos as 0, which would match nothing.
            end = len(source_code)

        sig_alg_pattern = re.compile(r'#define\s+(?P<name>MBEDTLS_TLS1_3_SIG_\w+)\s+' +
                                     r'(?P<value>0[xX][0-9a-fA-F]+)$',
                                     re.MULTILINE | re.DOTALL)
        matches = list(sig_alg_pattern.finditer(source_code, start, end))
        if matches:
            yield SignatureAlgorithmDefinition(source_code, definitions=matches)

    def __init__(self, source_code, definitions=None):
        """
            ``definitions`` must be a non-empty list of regex match objects,
            one per #define, each with a 'name' group.
        """
        if definitions is None:
            definitions = []
        assert isinstance(definitions, list) and definitions
        self._definitions = definitions
        self._source = source_code

    def __repr__(self):
        return 'SigAlgs({})'.format(self._definitions[0].span())

    def span(self):
        """Return the (start, end) offsets of the first grouped #define."""
        return self._definitions[0].span()

    def __str__(self):
        """
            Generate function for translating value to string

            The returned C function maps each macro to its name with the
            'MBEDTLS_TLS1_3_SIG_' prefix removed and lower-cased.
        """
        prefix_len = len('MBEDTLS_TLS1_3_SIG_')
        translation_table = []
        for m in self._definitions:
            name = m.group('name')
            return_val = name[prefix_len:].lower()
            translation_table.append(
                '    case {}:\n        return "{}";'.format(name, return_val))

        body = textwrap.dedent('''\
            const char *mbedtls_ssl_sig_alg_to_str( uint16_t in )
            {{
                switch( in )
                {{
            {translation_table}
                }};

                return "UNKNOWN";
            }}''')
        body = body.format(translation_table='\n'.join(translation_table))
        return body
279
280
class NamedGroupDefinition:
    """
        Generate helper functions for named group

        It generates translation function from named group define to string.
        Named group definition looks like:
        #define MBEDTLS_SSL_IANA_TLS_GROUP_[ upper case named group ] [ value(hex) ]

        Known limitation:
        - the definitions SHOULD exist in same macro blocks.
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield a single NamedGroupDefinition that groups every matching
            #define found in source_code between the offsets start and end;
            yield nothing when there is no match.

            A negative (or None) ``end``, which is the default, means
            "search up to the end of source_code".
        """
        if end is None or end < 0:
            # re treats a negative endpos as 0, which would match nothing.
            end = len(source_code)

        named_group_pattern = re.compile(r'#define\s+(?P<name>MBEDTLS_SSL_IANA_TLS_GROUP_\w+)\s+' +
                                         r'(?P<value>0[xX][0-9a-fA-F]+)$',
                                         re.MULTILINE | re.DOTALL)
        matches = list(named_group_pattern.finditer(source_code, start, end))
        if matches:
            yield NamedGroupDefinition(source_code, definitions=matches)

    def __init__(self, source_code, definitions=None):
        """
            ``definitions`` must be a non-empty list of regex match objects,
            one per #define, each with a 'name' group.
        """
        if definitions is None:
            definitions = []
        assert isinstance(definitions, list) and definitions
        self._definitions = definitions
        self._source = source_code

    def __repr__(self):
        return 'NamedGroup({})'.format(self._definitions[0].span())

    def span(self):
        """Return the (start, end) offsets of the first grouped #define."""
        return self._definitions[0].span()

    def __str__(self):
        """
            Generate function for translating value to string

            The returned C function maps each macro to its name with the
            'MBEDTLS_SSL_IANA_TLS_GROUP_' prefix removed and lower-cased.
        """
        prefix_len = len('MBEDTLS_SSL_IANA_TLS_GROUP_')
        translation_table = []
        for m in self._definitions:
            name = m.group('name')
            iana_name = name[prefix_len:].lower()
            translation_table.append('    case {}:\n        return "{}";'.format(name, iana_name))

        body = textwrap.dedent('''\
            const char *mbedtls_ssl_named_group_to_str( uint16_t in )
            {{
                switch( in )
                {{
            {translation_table}
                }};

                return "UNKNOWN";
            }}''')
        body = body.format(translation_table='\n'.join(translation_table))
        return body
337
338
# Skeleton of the generated C file. It is rendered with str.format(), which
# replaces the '{functions}' placeholder with the generated helper functions;
# everything is wrapped in an MBEDTLS_DEBUG_C guard.
OUTPUT_C_TEMPLATE = '''\
/* Automatically generated by generate_ssl_debug_helpers.py. DO NOT EDIT. */

/**
 * \\file ssl_debug_helpers_generated.c
 *
 * \\brief Automatically generated helper functions for debugging
 */
/*
 *  Copyright The Mbed TLS Contributors
 *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
 *
 */

#include "common.h"

#if defined(MBEDTLS_DEBUG_C)

#include "ssl_debug_helpers.h"

{functions}

#endif /* MBEDTLS_DEBUG_C */
/* End of automatically generated file. */

'''
365
366
def generate_ssl_debug_helpers(output_directory, mbedtls_root):
    """
        Generate the debug helper functions and write the resulting C file.

        The helpers are derived from include/mbedtls/ssl.h below
        mbedtls_root; when mbedtls_root is empty the root is guessed through
        build_tree. The rendered file is written to
        ssl_debug_helpers_generated.c inside output_directory, or to
        standard output when output_directory is sys.stdout.
    """
    root = mbedtls_root or build_tree.guess_mbedtls_root()
    header_path = os.path.join(os.path.abspath(root), 'include/mbedtls/ssl.h')
    with open(header_path) as header:
        source_code = remove_c_comments(header.read())

    # Keep only the first item reported for each source offset.
    pieces = {}
    extracted = preprocess_c_source_code(source_code,
                                         EnumDefinition,
                                         SignatureAlgorithmDefinition,
                                         NamedGroupDefinition)
    for offset, item in extracted:
        if offset not in pieces:
            if isinstance(item, EnumDefinition):
                item = item.generate_translation_function()
            pieces[offset] = item

    # Emit the pieces in the order in which they appear in the header.
    rendered = OUTPUT_C_TEMPLATE.format(
        functions='\n'.join(str(pieces[offset]) for offset in sorted(pieces)))

    if output_directory == sys.stdout:
        sys.stdout.write(rendered)
        return

    target = os.path.join(output_directory, 'ssl_debug_helpers_generated.c')
    with open(target, 'w') as out:
        out.write(rendered)
397
398
def main():
    """
    Command line entry

    Parse the command line, run the generator and return 0 as exit status.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--mbedtls-root',
                     nargs='?',
                     default=None,
                     help='root directory of mbedtls source code')
    cli.add_argument('output_directory',
                     nargs='?',
                     default='library',
                     help='source/header files location')
    options = cli.parse_args()

    generate_ssl_debug_helpers(options.output_directory, options.mbedtls_root)
    return 0
413
414
# Run the generator only when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == '__main__':
    sys.exit(main())
417