1#!/usr/bin/env python3
2#
3# Preprocessor that makes asserts easier to debug.
4#
5# Example:
6# ./scripts/prettyasserts.py -p LFS_ASSERT lfs.c -o lfs.a.c
7#
8# Copyright (c) 2022, The littlefs authors.
9# Copyright (c) 2020, Arm Limited. All rights reserved.
10# SPDX-License-Identifier: BSD-3-Clause
11#
12
import os
import re
import sys
15
16# NOTE the use of macros here helps keep a consistent stack depth which
17# tools may rely on.
18#
19# If compilation errors are noisy consider using -ftrack-macro-expansion=0.
20#
21
# default cap on how many bytes of a mem/str operand get printed
LIMIT = 16

# C comparison operator -> suffix used in the generated macro names
#
# NOTE the two-character operators must stay ahead of '<' and '>', these
# keys are joined into a regex alternation and the first match wins.
CMP = {
    '==': 'eq',
    '!=': 'ne',
    '<=': 'le',
    '>=': 'ge',
    '<':  'lt',
    '>':  'gt',
}

# token type -> list of regex alternatives
#
# These are regex fragments, so anything containing a backslash is a raw
# string; non-raw '\(' style escapes warn on newer Pythons.
LEXEMES = {
    # whitespace also swallows preprocessor lines and comments
    'ws':       [r'(?:\s|\n|#.*?\n|//.*?\n|/\*.*?\*/)+'],
    'assert':   ['assert'],
    'arrow':    ['=>'],
    'string':   [r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])\'"],
    'paren':    [r'\(', r'\)'],
    'cmp':      list(CMP),
    'logic':    [r'\&\&', r'\|\|'],
    'sep':      [':', ';', r'\{', r'\}', ','],
    'op':       ['->'], # specifically ops that conflict with cmp
}
44
45
def openio(path, mode='r', buffering=-1):
    """Open path like open(), treating '-' as stdin/stdout.

    path is a filesystem path, or '-' to get stdin (when mode contains
    'r') or stdout (any other mode). mode and buffering are passed
    through to open()/os.fdopen().

    Returns a file object that is always safe to close.
    """
    if path != '-':
        return open(path, mode, buffering)

    # check for 'r' anywhere in mode so 'rb'/'rt' still select stdin
    stream = sys.stdin if 'r' in mode else sys.stdout
    # dup the descriptor so closing the returned file (say, when leaving
    # a with block) doesn't close the process's real stdin/stdout
    return os.fdopen(os.dup(stream.fileno()), mode, buffering)
55
def write_header(f, limit=LIMIT):
    """Write the C prelude that the rewritten asserts depend on.

    f only needs a writeln(s='') method, every line goes through it.
    limit is the number of bytes of a mem/str operand printed before the
    output is cut short with "...".

    The prelude is, in order: a provenance banner, the required
    includes, the per-type print callbacks, __pretty_assert_fail, and
    one __PRETTY_ASSERT_<TYPE>_<CMP> macro per type/comparison pair.
    """
    def emit(*lines):
        # None stands for a blank line, written as a bare f.writeln()
        for text in lines:
            if text is None:
                f.writeln()
            else:
                f.writeln(text)

    # provenance banner
    emit(
        "// Generated by %s:" % sys.argv[0],
        "//",
        "// %s" % ' '.join(sys.argv),
        "//",
        None)

    for header in (
            'stdbool', 'stdint', 'inttypes', 'stdio', 'string', 'signal'):
        f.writeln("#include <%s.h>" % header)
    # give source a chance to define feature macros
    emit("#undef _FEATURES_H", None)

    # print callbacks, these all share one signature so that
    # __pretty_assert_fail can take any of them
    printers = [
        ('bool', [
            "    (void)size;",
            "    printf(\"%s\", *(const bool*)v ? \"true\" : \"false\");",
        ]),
        ('int', [
            "    (void)size;",
            "    printf(\"%\"PRIiMAX, *(const intmax_t*)v);",
        ]),
        ('ptr', [
            "    (void)size;",
            "    printf(\"%p\", v);",
        ]),
        ('mem', [
            "    const uint8_t *v_ = v;",
            "    printf(\"\\\"\");",
            "    for (size_t i = 0; i < size && i < %d; i++) {" % limit,
            "        if (v_[i] >= ' ' && v_[i] <= '~') {",
            "            printf(\"%c\", v_[i]);",
            "        } else {",
            "            printf(\"\\\\x%02x\", v_[i]);",
            "        }",
            "    }",
            "    if (size > %d) {" % limit,
            "        printf(\"...\");",
            "    }",
            "    printf(\"\\\"\");",
        ]),
        ('str', [
            "    __pretty_assert_print_mem(v, size);",
        ]),
    ]
    for name, body in printers:
        emit(
            "__attribute__((unused))",
            "static void __pretty_assert_print_%s(" % name,
            "        const void *v, size_t size) {",
            *body,
            "}",
            None)

    # common failure path
    emit(
        "__attribute__((unused, noinline))",
        "static void __pretty_assert_fail(",
        "        const char *file, int line,",
        "        void (*type_print_cb)(const void*, size_t),",
        "        const char *cmp,",
        "        const void *lh, size_t lsize,",
        "        const void *rh, size_t rsize) {",
        "    printf(\"%s:%d:assert: assert failed with \", file, line);",
        "    type_print_cb(lh, lsize);",
        "    printf(\", expected %s \", cmp);",
        "    type_print_cb(rh, rsize);",
        "    printf(\"\\n\");",
        "    fflush(NULL);",
        "    raise(SIGABRT);",
        "}",
        None)

    # assert macros, one family per operand type
    def emit_family(kind, params, lh_decl, rh_decl, test, lh_arg, rh_arg,
            only=None):
        for op, name in sorted(CMP.items()):
            if only is not None and name not in only:
                continue
            emit(
                "#define __PRETTY_ASSERT_%s_%s(%s) do { \\"
                    % (kind.upper(), name.upper(), params),
                "    %s \\" % lh_decl,
                "    %s \\" % rh_decl,
                "    if (!(%s)) { \\" % (test % op),
                "        __pretty_assert_fail( \\",
                "                __FILE__, __LINE__, \\",
                "                __pretty_assert_print_%s, \"%s\", \\"
                    % (kind, name),
                "                %s, \\" % lh_arg,
                "                %s); \\" % rh_arg,
                "    } \\",
                "} while (0)")

    emit_family('bool', "lh, rh",
        "bool _lh = !!(lh);",
        "bool _rh = !!(rh);",
        "_lh %s _rh",
        "&_lh, 0",
        "&_rh, 0")
    # both sides are declared with lh's type
    emit_family('int', "lh, rh",
        "__typeof__(lh) _lh = lh;",
        "__typeof__(lh) _rh = rh;",
        "_lh %s _rh",
        "&(intmax_t){_lh}, 0",
        "&(intmax_t){_rh}, 0")
    emit_family('mem', "lh, rh, size",
        "const void *_lh = lh;",
        "const void *_rh = rh;",
        "memcmp(_lh, _rh, size) %s 0",
        "_lh, size",
        "_rh, size")
    emit_family('str', "lh, rh",
        "const char *_lh = lh;",
        "const char *_rh = rh;",
        "strcmp(_lh, _rh) %s 0",
        "_lh, strlen(_lh)",
        "_rh, strlen(_rh)")
    # Only EQ and NE are supported when compared to NULL.
    emit_family('ptr', "lh, rh",
        "const void *_lh = (const void*)(uintptr_t)lh;",
        "const void *_rh = (const void*)(uintptr_t)rh;",
        "_lh %s _rh",
        "(const void*){_lh}, 0",
        "(const void*){_rh}, 0",
        only=('eq', 'ne'))

    emit(None, None)
212
def mkassert(type, cmp, lh, rh, size=None):
    """Format a call to the __PRETTY_ASSERT_<TYPE>_<CMP> macro.

    type and cmp are upper-cased into the macro name. lh and rh are the
    operand source text; size, when not None, is appended as a third
    argument.
    """
    macro = "__PRETTY_ASSERT_%s_%s" % (type.upper(), cmp.upper())
    operands = [lh, rh] if size is None else [lh, rh, size]
    return "%s(%s)" % (macro, ', '.join(str(arg) for arg in operands))
220
221
222# simple recursive descent parser
class ParseFailure(Exception):
    """Raised when the parser finds something other than what it expected.

    expected is whatever the caller was looking for (a tuple of patterns
    when raised by Parser.expect), found is what was there instead.
    """
    def __init__(self, expected, found):
        # hand the payload to Exception as well so args/repr() carry it
        super().__init__(expected, found)
        self.expected = expected
        self.found = found

    def __str__(self):
        # found can be arbitrarily large, only show the first 70
        # characters of its repr
        return "expected %r, found %s..." % (
            self.expected, repr(self.found)[:70])
231
class Parser:
    """Tokenizer plus a rewindable cursor over the resulting tokens.

    The whole input is tokenized up front into (type, text, line, col)
    tuples. type is the name of the lexeme that matched, or None for the
    raw text between lexemes. line is 1-based and col is 0-based, both
    referring to the first character of the token. Concatenating every
    token's text reproduces the input.

    Attributes:
        tokens: the list of token tuples
        off:    index of the next token to consume
        m:      text matched by the most recent lookahead/accept/expect,
                or None if it did not match
    """
    # How many upcoming tokens a ParseFailure carries. ParseFailure only
    # renders the first 70 characters, which a handful of tokens already
    # exceeds, and failures are raised on every backtrack, so copying the
    # whole remaining token list each time is wasted work.
    FAILURE_CONTEXT = 8

    def __init__(self, in_f, lexemes=LEXEMES):
        """Read all of in_f and tokenize it.

        lexemes maps a token type to a list of regex alternatives; all
        of them are combined into one alternation of named groups, so
        the order of both the dict and the lists decides which wins.
        """
        p = '|'.join('(?P<%s>%s)' % (n, '|'.join(l))
            for n, l in lexemes.items())
        p = re.compile(p, re.DOTALL)
        data = in_f.read()

        def spans():
            # walk the matches in place rather than re-slicing data
            # after every token, which is quadratic in the input size
            pos = 0
            for m in p.finditer(data):
                if m.start() > pos:
                    yield None, data[pos:m.start()]
                yield m.lastgroup, m.group()
                pos = m.end()
            # trailing raw text, always emitted even if empty
            yield None, data[pos:]

        tokens = []
        line = 1
        col = 0
        for kind, text in spans():
            tokens.append((kind, text, line, col))
            # advance the position past this token
            newlines = text.count('\n')
            if newlines:
                line += newlines
                col = len(text) - text.rfind('\n') - 1
            else:
                col += len(text)
        self.tokens = tokens
        self.off = 0

    def lookahead(self, *pattern):
        """Return the next token's text if it matches, without consuming.

        A token matches if either its type or its exact text is one of
        pattern; passing None matches raw-text tokens. Returns None, and
        sets self.m to None, on no match or at end of input.
        """
        if self.off < len(self.tokens):
            token = self.tokens[self.off]
            if token[0] in pattern or token[1] in pattern:
                self.m = token[1]
                return self.m
        self.m = None
        return self.m

    def accept(self, *patterns):
        """Like lookahead, but also consume the token if it matches."""
        m = self.lookahead(*patterns)
        if m is not None:
            self.off += 1
        return m

    def expect(self, *patterns):
        """Like accept, but raise ParseFailure instead of returning None.

        NOTE an empty-text token counts as a failure here as well.
        """
        m = self.accept(*patterns)
        if not m:
            raise ParseFailure(patterns,
                self.tokens[self.off:self.off+self.FAILURE_CONTEXT])
        return m

    def push(self):
        """Return the current position, to hand back to pop later."""
        return self.off

    def pop(self, state):
        """Rewind (or skip) to a position previously returned by push."""
        self.off = state
280
def p_assert(p):
    """Parse one assert call and return its __PRETTY_ASSERT_* rewrite.

    The forms are tried most-specific first, rewinding the parser
    between attempts:

        assert(memcmp(a, b, size) cmp 0)  -> mem assert
        assert(strcmp(a, b) cmp 0)        -> str assert
        assert(a cmp b)                   -> ptr assert if either side
                                             is NULL, else int assert
        assert(a)                         -> bool assert against true

    Raises ParseFailure if none of the forms match. The last form is not
    rewound on failure, so callers that want to recover need to save and
    restore the parser position themselves.
    """
    state = p.push()

    # assert(memcmp(a,b,size) cmp 0)?
    try:
        p.expect('assert') ; p.accept('ws')
        p.expect('(') ; p.accept('ws')
        p.expect('memcmp') ; p.accept('ws')
        p.expect('(') ; p.accept('ws')
        lh = p_expr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        rh = p_expr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        size = p_expr(p) ; p.accept('ws')
        p.expect(')') ; p.accept('ws')
        cmp = p.expect('cmp') ; p.accept('ws')
        p.expect('0') ; p.accept('ws')
        p.expect(')')
        return mkassert('mem', CMP[cmp], lh, rh, size)
    except ParseFailure:
        p.pop(state)

    # assert(strcmp(a,b) cmp 0)?
    try:
        p.expect('assert') ; p.accept('ws')
        p.expect('(') ; p.accept('ws')
        p.expect('strcmp') ; p.accept('ws')
        p.expect('(') ; p.accept('ws')
        lh = p_expr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        rh = p_expr(p) ; p.accept('ws')
        p.expect(')') ; p.accept('ws')
        cmp = p.expect('cmp') ; p.accept('ws')
        p.expect('0') ; p.accept('ws')
        p.expect(')')
        return mkassert('str', CMP[cmp], lh, rh)
    except ParseFailure:
        p.pop(state)

    # assert(a cmp b)?
    try:
        p.expect('assert') ; p.accept('ws')
        p.expect('(') ; p.accept('ws')
        lh = p_expr(p) ; p.accept('ws')
        cmp = p.expect('cmp') ; p.accept('ws')
        rh = p_expr(p) ; p.accept('ws')
        p.expect(')')
        # p_expr keeps any whitespace it runs into, so "NULL == x" shows
        # up here as lh='NULL ', compare the stripped text
        if rh.strip() == 'NULL' or lh.strip() == 'NULL':
            return mkassert('ptr', CMP[cmp], lh, rh)
        return mkassert('int', CMP[cmp], lh, rh)
    except ParseFailure:
        p.pop(state)

    # assert(a)?
    p.expect('assert') ; p.accept('ws')
    p.expect('(') ; p.accept('ws')
    lh = p_exprs(p) ; p.accept('ws')
    p.expect(')')
    return mkassert('bool', 'eq', lh, 'true')
340
def p_expr(p):
    """Consume one operand-level expression and return its text.

    Keeps going through parenthesized groups, nested asserts (which are
    rewritten in place), strings, ops, whitespace and raw text. Stops in
    front of the first token that is none of these, so cmp, logic, sep,
    arrow and an unmatched ')' are left for the caller.

    Raises ParseFailure if an opened '(' is never closed.
    """
    chunks = []
    while True:
        if p.accept('('):
            # parenthesized group, sep-delimited exprs up to the ')'
            chunks.append(p.m)
            chunks.append(p_exprs(p))
            while p.accept('sep'):
                chunks.append(p.m)
                chunks.append(p_exprs(p))
            chunks.append(p.expect(')'))
            continue

        if p.lookahead('assert'):
            # nested assert, if it doesn't parse keep the bare keyword
            # and carry on with whatever follows
            mark = p.push()
            try:
                rewritten = p_assert(p)
            except ParseFailure:
                p.pop(mark)
                rewritten = p.expect('assert')
            chunks.append(rewritten)
            continue

        if p.accept('string', 'op', 'ws', None):
            chunks.append(p.m)
            continue

        return ''.join(chunks)
364
def p_exprs(p):
    """Consume a chain of p_expr joined by cmp, logic or ',' tokens.

    Returns the consumed text unchanged, separators included.
    """
    chunks = [p_expr(p)]
    while p.accept('cmp', 'logic', ','):
        chunks.append(p.m)
        chunks.append(p_expr(p))
    return ''.join(chunks)
373
def p_stmt(p):
    """Consume one statement and return its, possibly rewritten, text.

    Leading whitespace is passed through as is. After that:

        memcmp(lh, rh, size) => 0  -> mem eq assert
        strcmp(lh, rh) => 0        -> str eq assert
        lh => rh                   -> int eq assert
        anything else              -> returned unchanged
    """
    lead = p.accept('ws') or ''

    # memcmp(lh,rh,size) => 0? / strcmp(lh,rh) => 0?
    for func, kind, argc in (('memcmp', 'mem', 3), ('strcmp', 'str', 2)):
        if not p.lookahead(func):
            continue
        mark = p.push()
        try:
            p.expect(func) ; p.accept('ws')
            p.expect('(') ; p.accept('ws')
            args = []
            for n in range(argc):
                if n > 0:
                    p.expect(',') ; p.accept('ws')
                args.append(p_expr(p)) ; p.accept('ws')
            p.expect(')') ; p.accept('ws')
            p.expect('=>') ; p.accept('ws')
            p.expect('0') ; p.accept('ws')
            return lead + mkassert(kind, 'eq', *args)
        except ParseFailure:
            # not the arrow form after all, reparse as a plain statement
            p.pop(mark)

    # lh => rh?
    lh = p_exprs(p)
    if not p.accept('=>'):
        return lead + lh
    rh = p_exprs(p)
    return lead + mkassert('int', 'eq', lh, rh)
417
def main(input=None, output=None, pattern=None, limit=LIMIT):
    """Rewrite the asserts in input and write the result to output.

    input and output are paths, None or '-' mean stdin/stdout. pattern
    is an optional list of extra regexes that start an assert, on top
    of the built-in "assert". limit is passed on to write_header.

    The output is the generated header, a #line directive pointing back
    at input (only when input is given), then the rewritten source. If
    parsing fails a warning goes to stderr and everything from the
    statement that failed onwards is copied through unmodified.
    """
    with openio(input or '-', 'r') as in_f:
        # create parser
        #
        # LEXEMES.copy() is shallow, so build a new list here rather
        # than extending the shared one in place
        lexemes = LEXEMES.copy()
        lexemes['assert'] = list(lexemes['assert']) + list(pattern or [])
        p = Parser(in_f, lexemes)

        with openio(output or '-', 'w') as f:
            def writeln(s=''):
                f.write(s)
                f.write('\n')
            f.writeln = writeln

            # write extra verbose asserts
            write_header(f, limit=limit)
            if input is not None:
                f.writeln("#line %d \"%s\"" % (1, input))

            # parse and write out stmt at a time
            state = p.push()
            try:
                while True:
                    # remember where this stmt starts in case it fails
                    state = p.push()
                    f.write(p_stmt(p))
                    if p.accept('sep'):
                        f.write(p.m)
                    else:
                        break
            except ParseFailure as e:
                # the failed stmt has consumed tokens without writing
                # them, rewind so they aren't lost from the output
                p.pop(state)
                # stdout may be the generated C, so warn on stderr
                print('warning: %s' % e, file=sys.stderr)

            # copy whatever is left through untouched
            for token in p.tokens[p.off:]:
                f.write(token[1])
450
451
if __name__ == "__main__":
    import argparse
    import sys

    cli = argparse.ArgumentParser(
        description="Preprocessor that makes asserts easier to debug.",
        allow_abbrev=False)
    cli.add_argument('input', help="Input C file.")
    cli.add_argument('-o', '--output', required=True, help="Output C file.")
    cli.add_argument('-p', '--pattern', action='append',
        help="Regex patterns to search for starting an assert statement. This"
            " implicitly includes \"assert\" and \"=>\".")
    cli.add_argument('-l', '--limit',
        type=lambda x: int(x, 0),
        default=LIMIT,
        help="Maximum number of characters to display in strcmp and memcmp. "
            "Defaults to %r." % LIMIT)

    # only forward the options that were actually set, so that main's
    # own defaults apply to the rest
    options = {}
    for name, value in vars(cli.parse_intermixed_args()).items():
        if value is not None:
            options[name] = value
    sys.exit(main(**options))
479