1#!/usr/bin/env python3
2
3"""Generate library/ssl_debug_helpers_generated.c
4
The code generated by this module includes debug helper functions that cannot be
implemented as fixed code, because they depend on definitions (enums and macros)
found in include/mbedtls/ssl.h.
7
8"""
9
10# Copyright The Mbed TLS Contributors
11# SPDX-License-Identifier: Apache-2.0
12#
13# Licensed under the Apache License, Version 2.0 (the "License"); you may
14# not use this file except in compliance with the License.
15# You may obtain a copy of the License at
16#
17# http://www.apache.org/licenses/LICENSE-2.0
18#
19# Unless required by applicable law or agreed to in writing, software
20# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
21# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22# See the License for the specific language governing permissions and
23# limitations under the License.
24import sys
25import re
26import os
27import textwrap
28import argparse
29from mbedtls_dev import build_tree
30
31
def remove_c_comments(string):
    """
        Remove C style comments from input string

        String and character literals are matched first, so that comment
        markers appearing inside them (such as "/*" or "//") are preserved.
        Backslash escapes inside literals are honoured, so an escaped quote
        does not terminate the literal early.

        :param string: C source text.
        :return: ``string`` with every ``/* ... */`` and ``// ...`` comment
                 replaced by the empty string.
    """
    # A literal is an opening quote, then any run of backslash escapes or
    # characters other than the quote/backslash, then the closing quote.
    string_pattern = r"(?P<string>\"(?:\\.|[^\"\\])*\"|\'(?:\\.|[^\'\\])*\')"
    comment_pattern = r"(?P<comment>/\*.*?\*/|//[^\r\n]*$)"
    pattern = re.compile(string_pattern + r'|' + comment_pattern,
                         re.MULTILINE | re.DOTALL)

    def replacer(match):
        # Drop comments; keep literals exactly as they were.
        if match.lastgroup == 'comment':
            return ""
        return match.group()
    return pattern.sub(replacer, string)
46
47
class CondDirectiveNotMatch(Exception):
    """Raised when conditional preprocessor directives in C source are not
    properly paired (e.g. an #else or #endif with no open #if)."""
50
51
def preprocess_c_source_code(source, *classes):
    """
        Simple preprocessor for C source code.

        Only processes conditional directives (#if, #ifdef, #ifndef, #else,
        #endif) without evaluating them; #elif is not recognised.

        For every pair of matching directives, each class in ``classes`` is
        asked for the objects between them through
        ``cls.extract(source, start, end)``. Yields ``(position, item)``
        tuples where ``item`` is either an extracted object (positioned at
        ``obj.span()[0]``) or a string holding a directive line that
        re-creates the enclosing condition: an ``#if...`` line before the
        objects and an ``#endif...`` line after them. The outermost
        conditional (typically an include guard) yields empty strings instead
        of directive lines. Inner blocks are yielded before the blocks that
        enclose them, so the same object may be yielded more than once;
        callers should de-duplicate by position.

        If the directive pairs do not match, raise CondDirectiveNotMatch.

        Assume source code does not include comments and compiles.
    """

    pattern = re.compile(r"^[ \t]*#[ \t]*" +
                         r"(?P<directive>(if[ \t]|ifndef[ \t]|ifdef[ \t]|else|endif))" +
                         r"[ \t]*(?P<param>(.*\\\n)*.*$)",
                         re.MULTILINE)
    # Open conditionals: (directive, param, open_position, body_start).
    stack = []

    def _yield_objects(directive, param, open_pos, body_start, body_end):
        """
            Yield the objects found in source[body_start:body_end], preceded
            by the opening directive line (at open_pos) and followed by the
            closing directive line (at body_end). Nothing is yielded when no
            object is found.
        """
        start_line, end_line = '', ''
        # The outermost conditional (stack now empty) is not reproduced.
        if stack:
            start_line = '#{} {}'.format(directive, param)
            if directive == 'if':
                end_line = '#endif /* {} */'.format(param)
            elif directive == 'ifdef':
                end_line = '#endif /* defined({}) */'.format(param)
            else:
                end_line = '#endif /* !defined({}) */'.format(param)
        has_instance = False
        for cls in classes:
            for instance in cls.extract(source, body_start, body_end):
                if not has_instance:
                    has_instance = True
                    yield open_pos, start_line
                yield instance.span()[0], instance
        if has_instance:
            yield body_end, end_line

    for match in pattern.finditer(source):

        directive = match.group('directive').strip()
        param = match.group('param')
        start, end = match.span()

        if directive in ('if', 'ifndef', 'ifdef'):
            stack.append((directive, param, start, end))
            continue

        if not stack:
            raise CondDirectiveNotMatch()

        pair_directive, pair_param, pair_start, pair_end = stack.pop()
        yield from _yield_objects(pair_directive, pair_param,
                                  pair_start, pair_end, start)

        if directive == 'endif':
            continue

        # '#else': open the inverted condition. Its opening line is keyed at
        # the END of the '#else' line, because the start of that line is
        # already the key of the previous branch's '#endif' line; sharing the
        # key would make a first-wins de-duplication drop one of them.
        if pair_directive == 'if':
            directive, param = 'if', "!( {} )".format(pair_param)
        elif pair_directive == 'ifdef':
            directive, param = 'ifndef', pair_param
        else:
            directive, param = 'ifdef', pair_param

        stack.append((directive, param, end, end))

    # An #if/#ifdef/#ifndef was never closed.
    if stack:
        raise CondDirectiveNotMatch()
130
131
class EnumDefinition:
    """
        Generate helper functions around enumeration.

        Currently, it generates a translation function from enum value to
        string. Enum definition looks like:
        [typedef] enum [prefix name] { [body] } [suffix name];

        Known limitation:
        - the '}' and ';' SHOULD NOT exist in different macro blocks. Like
        ```
        enum test {
            ....
        #if defined(A)
            ....
        };
        #else
            ....
        };
        #endif
        ```
        - enumerator values that contain ',' (e.g. ``A = F(1, 2)``) are not
          supported, because the body is split on every comma.
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield an EnumDefinition for each enum found in
            source_code[start:end].

            A negative ``end`` (the default) or None means "up to the end of
            source_code". It must be resolved here: a negative ``endpos``
            makes ``re`` find no match at all.
        """
        if end is None or end < 0:
            end = len(source_code)
        enum_pattern = re.compile(r'enum\s*(?P<prefix_name>\w*)\s*' +
                                  r'{\s*(?P<body>[^}]*)}' +
                                  r'\s*(?P<suffix_name>\w*)\s*;',
                                  re.MULTILINE | re.DOTALL)

        for match in enum_pattern.finditer(source_code, start, end):
            yield cls(source_code,
                      span=match.span(),
                      group=match.groupdict())

    def __init__(self, source_code, span=None, group=None):
        """
            :param source_code: the source text the enum was found in.
            :param span: (start, end) of the enum definition in source_code.
            :param group: regex groups 'prefix_name', 'body' and
                          'suffix_name' of the enum definition.
        """
        assert isinstance(group, dict)
        prefix_name = group.get('prefix_name', None)
        suffix_name = group.get('suffix_name', None)
        body = group.get('body', None)
        assert prefix_name or suffix_name
        assert body
        assert span
        # If suffix_name exists, it is treated as a typedef name
        self._prototype = suffix_name if suffix_name else 'enum ' + prefix_name
        self._name = suffix_name if suffix_name else prefix_name
        self._body = body
        self._source = source_code
        self._span = span

    def __repr__(self):
        return 'Enum({},{})'.format(self._name, self._span)

    def __str__(self):
        return repr(self)

    def span(self):
        return self._span

    def generate_translation_function(self):
        """
            Generate a C function translating an enum value to its name.

            The function is named ``<name>_str``. It indexes a table built
            with designated initializers (one entry per enumerator) and
            returns "UNKNOWN_VALUE" for values past the end of the table or
            without an entry. Preprocessor lines found in the enum body are
            copied into the table unchanged, so the table is compiled under
            the same conditions as the enumerators.
        """
        translation_table = []

        for line in self._body.splitlines():
            line = line.strip()

            if line.startswith('#'):
                # Preprocess directive, keep it in table
                translation_table.append(line)
                continue

            for field in line.split(','):
                # Ignore any explicit value: "NAME = 1" and "NAME=1" both
                # give "NAME". Empty fields (blank lines, trailing commas)
                # are skipped.
                tokens = field.split('=', 1)[0].split()
                if not tokens:
                    continue
                member = tokens[0]
                translation_table.append(
                    '{space}[{member}] = "{member}",'.format(member=member,
                                                             space=' '*8)
                )

        body = textwrap.dedent('''\
            const char *{name}_str( {prototype} in )
            {{
                const char * in_to_str[]=
                {{
            {translation_table}
                }};

                if( in > ( sizeof( in_to_str )/sizeof( in_to_str[0]) - 1 ) ||
                    in_to_str[ in ] == NULL )
                {{
                    return "UNKNOWN_VALUE";
                }}
                return in_to_str[ in ];
            }}
                    ''')
        body = body.format(translation_table='\n'.join(translation_table),
                           name=self._name,
                           prototype=self._prototype)
        return body
236
237
class SignatureAlgorithmDefinition:
    """
        Generate helper functions for signature algorithms.

        It generates a translation function from signature algorithm define
        to string. Signature algorithm definition looks like:
        #define MBEDTLS_TLS1_3_SIG_[ upper case signature algorithm ] [ value(hex) ]

        Known limitation:
        - the definitions SHOULD exist in same macro blocks.
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield one SignatureAlgorithmDefinition collecting every
            MBEDTLS_TLS1_3_SIG_* hex definition in source_code[start:end],
            or nothing if there is none.

            A negative ``end`` (the default) or None means "up to the end of
            source_code"; a negative ``endpos`` would make ``re`` find no
            match at all.
        """
        if end is None or end < 0:
            end = len(source_code)
        # Trailing blanks are accepted: stripping a trailing comment from a
        # '#define' line leaves them behind.
        sig_alg_pattern = re.compile(r'#define\s+(?P<name>MBEDTLS_TLS1_3_SIG_\w+)\s+' +
                                     r'(?P<value>0[xX][0-9a-fA-F]+)[ \t]*$',
                                     re.MULTILINE | re.DOTALL)
        matches = list(sig_alg_pattern.finditer(source_code, start, end))
        if matches:
            yield cls(source_code, definitions=matches)

    def __init__(self, source_code, definitions=None):
        """
            :param source_code: the source text the definitions come from.
            :param definitions: non-empty list of regex matches, each with
                                'name' and 'value' groups.
        """
        if definitions is None:
            definitions = []
        assert isinstance(definitions, list) and definitions
        self._definitions = definitions
        self._source = source_code

    def __repr__(self):
        return 'SigAlgs({})'.format(self._definitions[0].span())

    def span(self):
        return self._definitions[0].span()

    def __str__(self):
        """
            Generate function for translating value to string

            Each definition becomes a ``case`` returning the lower-case macro
            name without the MBEDTLS_TLS1_3_SIG_ prefix; other values give
            "UNKNOWN".
        """
        translation_table = []
        for m in self._definitions:
            name = m.group('name')
            return_val = name[len('MBEDTLS_TLS1_3_SIG_'):].lower()
            translation_table.append(
                '    case {}:\n        return "{}";'.format(name, return_val))

        # No ';' after the switch block: it would be a stray empty statement.
        body = textwrap.dedent('''\
            const char *mbedtls_ssl_sig_alg_to_str( uint16_t in )
            {{
                switch( in )
                {{
            {translation_table}
                }}

                return "UNKNOWN";
            }}''')
        body = body.format(translation_table='\n'.join(translation_table))
        return body
295
296
class NamedGroupDefinition:
    """
        Generate helper functions for named group

        It generates a translation function from named group define to
        string. Named group definition looks like:
        #define MBEDTLS_SSL_IANA_TLS_GROUP_[ upper case named group ] [ value(hex) ]

        Known limitation:
        - the definitions SHOULD exist in same macro blocks.
    """

    @classmethod
    def extract(cls, source_code, start=0, end=-1):
        """
            Yield one NamedGroupDefinition collecting every
            MBEDTLS_SSL_IANA_TLS_GROUP_* hex definition in
            source_code[start:end], or nothing if there is none.

            A negative ``end`` (the default) or None means "up to the end of
            source_code"; a negative ``endpos`` would make ``re`` find no
            match at all.
        """
        if end is None or end < 0:
            end = len(source_code)
        # Trailing blanks are accepted: stripping a trailing comment from a
        # '#define' line leaves them behind.
        named_group_pattern = re.compile(r'#define\s+(?P<name>MBEDTLS_SSL_IANA_TLS_GROUP_\w+)\s+' +
                                         r'(?P<value>0[xX][0-9a-fA-F]+)[ \t]*$',
                                         re.MULTILINE | re.DOTALL)
        matches = list(named_group_pattern.finditer(source_code, start, end))
        if matches:
            yield cls(source_code, definitions=matches)

    def __init__(self, source_code, definitions=None):
        """
            :param source_code: the source text the definitions come from.
            :param definitions: non-empty list of regex matches, each with
                                'name' and 'value' groups.
        """
        if definitions is None:
            definitions = []
        assert isinstance(definitions, list) and definitions
        self._definitions = definitions
        self._source = source_code

    def __repr__(self):
        return 'NamedGroup({})'.format(self._definitions[0].span())

    def span(self):
        return self._definitions[0].span()

    def __str__(self):
        """
            Generate function for translating value to string

            Each definition becomes a ``case`` returning the lower-case macro
            name without the MBEDTLS_SSL_IANA_TLS_GROUP_ prefix; other values
            give "UNKNOWN".
        """
        translation_table = []
        for m in self._definitions:
            name = m.group('name')
            iana_name = name[len('MBEDTLS_SSL_IANA_TLS_GROUP_'):].lower()
            translation_table.append('    case {}:\n        return "{}";'.format(name, iana_name))

        # No ';' after the switch block: it would be a stray empty statement.
        body = textwrap.dedent('''\
            const char *mbedtls_ssl_named_group_to_str( uint16_t in )
            {{
                switch( in )
                {{
            {translation_table}
                }}

                return "UNKNOWN";
            }}''')
        body = body.format(translation_table='\n'.join(translation_table))
        return body
353
354
# Skeleton of the generated C file. The generated helper functions are
# substituted for the {functions} placeholder with str.format(), so any
# literal brace added to this template must be doubled ('{{' / '}}').
# '\\file' and '\\brief' emit the Doxygen commands \file and \brief.
OUTPUT_C_TEMPLATE = '''\
/* Automatically generated by generate_ssl_debug_helpers.py. DO NOT EDIT. */

/**
 * \\file ssl_debug_helpers_generated.c
 *
 * \\brief Automatically generated helper functions for debugging
 */
/*
 *  Copyright The Mbed TLS Contributors
 *  SPDX-License-Identifier: Apache-2.0
 *
 *  Licensed under the Apache License, Version 2.0 (the "License"); you may
 *  not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include "common.h"

#if defined(MBEDTLS_DEBUG_C)

#include "ssl_debug_helpers.h"

{functions}

#endif /* MBEDTLS_DEBUG_C */
/* End of automatically generated file. */

'''
392
393
def generate_ssl_debug_helpers(output_directory, mbedtls_root):
    """
        Generate the debug helper functions from include/mbedtls/ssl.h.

        :param output_directory: directory in which
            ssl_debug_helpers_generated.c is written, or ``sys.stdout`` to
            write the generated code to standard output instead.
        :param mbedtls_root: root of the Mbed TLS source tree; when falsy it
            is guessed with build_tree.guess_mbedtls_root().
    """
    mbedtls_root = os.path.abspath(
        mbedtls_root or build_tree.guess_mbedtls_root())
    # Explicit encoding: do not depend on the locale's default codec.
    with open(os.path.join(mbedtls_root, 'include/mbedtls/ssl.h'),
              encoding='utf-8') as f:
        source_code = remove_c_comments(f.read())

    # Keyed by source position. Inner conditional blocks are yielded before
    # the blocks that enclose them, so only the first item seen at a given
    # position is kept.
    definitions = {}
    for start, instance in preprocess_c_source_code(source_code,
                                                    EnumDefinition,
                                                    SignatureAlgorithmDefinition,
                                                    NamedGroupDefinition):
        if start in definitions:
            continue
        if isinstance(instance, EnumDefinition):
            instance = instance.generate_translation_function()
        definitions[start] = instance

    function_definitions = [str(v) for _, v in sorted(definitions.items())]
    output = OUTPUT_C_TEMPLATE.format(
        functions='\n'.join(function_definitions))
    if output_directory == sys.stdout:
        sys.stdout.write(output)
    else:
        with open(os.path.join(output_directory,
                               'ssl_debug_helpers_generated.c'),
                  'w', encoding='utf-8') as f:
            f.write(output)
424
425
def main():
    """
    Command line entry point.

    Accepts an optional --mbedtls-root option and an optional output
    directory (default "library"), generates the debug helpers, and
    returns 0 as the exit status.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--mbedtls-root',
                     nargs='?',
                     default=None,
                     help='root directory of mbedtls source code')
    cli.add_argument('output_directory',
                     nargs='?',
                     default='library',
                     help='source/header files location')
    options = cli.parse_args()
    generate_ssl_debug_helpers(options.output_directory,
                               mbedtls_root=options.mbedtls_root)
    return 0
440
441
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
444