1#!/usr/bin/env python3
2#
3# Preprocessor that makes asserts easier to debug.
4#
5# Example:
6# ./scripts/prettyasserts.py -p LFS_ASSERT lfs.c -o lfs.a.c
7#
8# Copyright (c) 2022, The littlefs authors.
9# Copyright (c) 2020, Arm Limited. All rights reserved.
10# SPDX-License-Identifier: BSD-3-Clause
11#
12
import os
import re
import sys
15
16# NOTE the use of macros here helps keep a consistent stack depth which
17# tools may rely on.
18#
19# If compilation errors are noisy consider using -ftrack-macro-expansion=0.
20#
21
# default number of bytes shown for each operand of a failed strcmp/memcmp
LIMIT = 16

# C comparison operators mapped to the suffix used in generated macro names
#
# order matters, these are joined into a regex alternation, so the
# two-character operators must come before their one-character prefixes
CMP = {
    '==': 'eq',
    '!=': 'ne',
    '<=': 'le',
    '>=': 'ge',
    '<':  'lt',
    '>':  'gt',
}

# token classes understood by the lexer, each a list of regex alternatives
#
# NOTE regex fragments are raw strings, sequences such as '\(' in a normal
# string literal are invalid escapes and warn on newer Pythons
LEXEMES = {
    'ws':       [r'(?:\s|\n|#.*?\n|//.*?\n|/\*.*?\*/)+'],
    'assert':   ['assert'],
    'arrow':    ['=>'],
    'string':   [r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])\'"],
    'paren':    [r'\(', r'\)'],
    'cmp':      list(CMP),
    'logic':    [r'\&\&', r'\|\|'],
    'sep':      [':', ';', r'\{', r'\}', ','],
    'op':       ['->'], # specifically ops that conflict with cmp
}
44
45
def openio(path, mode='r', buffering=-1):
    """Open path like the builtin open(), with '-' meaning stdin/stdout.

    For '-', any mode containing 'r' (such as 'r' or 'rb') selects stdin
    and every other mode selects stdout. mode and buffering are passed
    through unchanged to open()/os.fdopen().

    Returns a file object that the caller is responsible for closing.
    """
    if path != '-':
        return open(path, mode, buffering)

    # hand out a duplicate descriptor so closing the returned file, for
    # example by leaving a with block, leaves the real stdin/stdout open
    stream = sys.stdin if 'r' in mode else sys.stdout
    return os.fdopen(os.dup(stream.fileno()), mode, buffering)
55
def write_header(f, limit=LIMIT):
    """Write the C prelude that the rewritten asserts rely on.

    Emits, line by line through f.writeln: a provenance comment built from
    sys.argv, the required #includes, the static print/fail helpers, and
    one __PRETTY_ASSERT_<TYPE>_<CMP> macro for every type (bool, int, mem,
    str) and every comparison in CMP.

    f must provide a writeln(s='') method. limit is the number of bytes the
    generated mem/str printer shows before appending "...".
    """
    out = f.writeln

    def printer(name, *body):
        # one static print callback, all share the same signature
        out('__attribute__((unused))')
        out('static void __pretty_assert_print_%s(' % name)
        out('        const void *v, size_t size) {')
        for line in body:
            out(line)
        out('}')
        out()

    # provenance banner
    out('// Generated by %s:' % sys.argv[0])
    out('//')
    out('// %s' % ' '.join(sys.argv))
    out('//')
    out()

    for header in (
            'stdbool', 'stdint', 'inttypes', 'stdio', 'string', 'signal'):
        out('#include <%s.h>' % header)
    # give source a chance to define feature macros
    out('#undef _FEATURES_H')
    out()

    # write print functions
    printer('bool',
        '    (void)size;',
        '    printf("%s", *(const bool*)v ? "true" : "false");')
    printer('int',
        '    (void)size;',
        '    printf("%"PRIiMAX, *(const intmax_t*)v);')
    printer('mem',
        '    const uint8_t *v_ = v;',
        '    printf("\\"");',
        '    for (size_t i = 0; i < size && i < %d; i++) {' % limit,
        "        if (v_[i] >= ' ' && v_[i] <= '~') {",
        '            printf("%c", v_[i]);',
        '        } else {',
        '            printf("\\\\x%02x", v_[i]);',
        '        }',
        '    }',
        '    if (size > %d) {' % limit,
        '        printf("...");',
        '    }',
        '    printf("\\"");')
    printer('str',
        '    __pretty_assert_print_mem(v, size);')

    # write the shared failure handler
    for line in (
            '__attribute__((unused, noinline))',
            'static void __pretty_assert_fail(',
            '        const char *file, int line,',
            '        void (*type_print_cb)(const void*, size_t),',
            '        const char *cmp,',
            '        const void *lh, size_t lsize,',
            '        const void *rh, size_t rsize) {',
            '    printf("%s:%d:assert: assert failed with ", file, line);',
            '    type_print_cb(lh, lsize);',
            '    printf(", expected %s ", cmp);',
            '    type_print_cb(rh, rsize);',
            '    printf("\\n");',
            '    fflush(NULL);',
            '    raise(SIGABRT);',
            '}'):
        out(line)
    out()

    # write assert macros, the pieces that differ per type are:
    # type, macro params, _lh init, _rh init, test, lh args, rh args
    flavors = (
        ('bool', 'lh, rh',
            'bool _lh = !!(lh);',
            'bool _rh = !!(rh);',
            '_lh %s _rh',
            '&_lh, 0,',
            '&_rh, 0);'),
        ('int', 'lh, rh',
            '__typeof__(lh) _lh = lh;',
            '__typeof__(lh) _rh = rh;',
            '_lh %s _rh',
            '&(intmax_t){_lh}, 0,',
            '&(intmax_t){_rh}, 0);'),
        ('mem', 'lh, rh, size',
            'const void *_lh = lh;',
            'const void *_rh = rh;',
            'memcmp(_lh, _rh, size) %s 0',
            '_lh, size,',
            '_rh, size);'),
        ('str', 'lh, rh',
            'const char *_lh = lh;',
            'const char *_rh = rh;',
            'strcmp(_lh, _rh) %s 0',
            '_lh, strlen(_lh),',
            '_rh, strlen(_rh));'),
    )
    for kind, params, lh_init, rh_init, test, lh_args, rh_args in flavors:
        for op, cmp in sorted(CMP.items()):
            out('#define __PRETTY_ASSERT_%s_%s(%s) do { \\'
                % (kind.upper(), cmp.upper(), params))
            out('    %s \\' % lh_init)
            out('    %s \\' % rh_init)
            out('    if (!(%s)) { \\' % (test % op))
            out('        __pretty_assert_fail( \\')
            out('                __FILE__, __LINE__, \\')
            out('                __pretty_assert_print_%s, "%s", \\'
                % (kind, cmp))
            out('                %s \\' % lh_args)
            out('                %s \\' % rh_args)
            out('    } \\')
            out('} while (0)')
    out()
    out()
188
def mkassert(type, cmp, lh, rh, size=None):
    """Return the C text invoking the pretty assert macro for type and cmp.

    The macro name is __PRETTY_ASSERT_<TYPE>_<CMP>, with both parts
    upper-cased. lh and rh are the operand expressions; size is appended
    as a third argument only when it is not None.
    """
    macro = f"__PRETTY_ASSERT_{type.upper()}_{cmp.upper()}"
    operands = (lh, rh) if size is None else (lh, rh, size)
    return f"{macro}({', '.join(map(str, operands))})"
196
197
198# simple recursive descent parser
class ParseFailure(Exception):
    """Signals that the token stream did not match what a rule required.

    expected holds the patterns that would have been accepted, found holds
    whatever the raiser reports as actually being there.
    """
    def __init__(self, expected, found):
        self.expected, self.found = expected, found

    def __str__(self):
        # only show the start of what was found to keep messages short
        preview = repr(self.found)[:70]
        return "expected {!r}, found {}...".format(self.expected, preview)
207
class Parser:
    """Lexer and backtracking token cursor used by the p_* parse rules.

    The input is read completely and split into (type, text, line, col)
    tuples. type is the name of the lexeme class that matched, or None for
    the raw text between recognized lexemes. line is 1-based and col is
    0-based, both locating the first character of the token.

    Attributes:
        tokens: the full token list, joining every token's text reproduces
            the input.
        off: index of the next token to consume.
        m: text of the token matched by the most recent lookahead/accept/
            expect call, or None if that call did not match.
    """
    def __init__(self, in_f, lexemes=LEXEMES):
        """Tokenize everything in_f.read() returns.

        lexemes maps a token type name to an iterable of regex
        alternatives; the names become named groups of one combined regex.
        """
        pattern = re.compile(
            '|'.join('(?P<%s>%s)' % (n, '|'.join(l))
                for n, l in lexemes.items()),
            re.DOTALL)
        data = in_f.read()
        tokens = []
        line = 1
        col = 0
        pos = 0

        def emit(kind, text):
            # record the token at the current position, then advance the
            # position past its text
            nonlocal line, col
            tokens.append((kind, text, line, col))
            newline = text.rfind('\n')
            if newline < 0:
                col += len(text)
            else:
                line += text.count('\n')
                col = len(text) - newline - 1

        # scan in place instead of re-slicing the remaining input after
        # every token, which copies the rest of the file each time
        for m in pattern.finditer(data):
            if m.end() == m.start():
                # a lexeme matched the empty string, it carries no text
                # and must not stall the scan
                continue
            if m.start() > pos:
                emit(None, data[pos:m.start()])
            emit(m.lastgroup, m.group())
            pos = m.end()
        # trailing text, always emitted, possibly empty
        emit(None, data[pos:])

        self.tokens = tokens
        self.off = 0
        self.m = None

    def lookahead(self, *pattern):
        """Return the next token's text if it matches, without consuming it.

        A token matches when its type or its exact text is one of pattern.
        Returns None, and sets self.m to None, on mismatch or end of input.
        """
        if self.off < len(self.tokens):
            token = self.tokens[self.off]
            if token[0] in pattern or token[1] in pattern:
                self.m = token[1]
                return self.m
        self.m = None
        return self.m

    def accept(self, *patterns):
        """Like lookahead, but consume the token when it matches."""
        m = self.lookahead(*patterns)
        if m is not None:
            self.off += 1
        return m

    def expect(self, *patterns):
        """Like accept, but raise ParseFailure when nothing matches.

        The failure carries the patterns and the remaining tokens.
        """
        m = self.accept(*patterns)
        # test against None like accept does, so a token that was consumed
        # is never also reported as a failure
        if m is None:
            raise ParseFailure(patterns, self.tokens[self.off:])
        return m

    def push(self):
        """Return an opaque snapshot of the cursor for later backtracking."""
        return self.off

    def pop(self, state):
        """Rewind the cursor to a snapshot previously returned by push."""
        self.off = state
256
def p_assert(p):
    """Parse one assert(...) call and return its pretty assert macro text.

    The forms are tried in order, rewinding the parser between attempts:
    assert(memcmp(a,b,size) cmp 0), assert(strcmp(a,b) cmp 0),
    assert(a cmp b), and finally assert(a), which is compared against true.

    Raises ParseFailure if even the last form does not match; the parser
    position is not restored in that case.
    """
    start = p.push()

    def skip(*patterns):
        # require a token, then swallow any whitespace following it
        tok = p.expect(*patterns)
        p.accept('ws')
        return tok

    def operand():
        # a single expression, then any whitespace following it
        expr = p_expr(p)
        p.accept('ws')
        return expr

    def open_assert():
        skip('assert')
        skip('(')

    def compare_call(func, kind, argc):
        # assert(func(arg, ...) cmp 0)
        open_assert()
        skip(func)
        skip('(')
        args = [operand()]
        for _ in range(argc - 1):
            skip(',')
            args.append(operand())
        skip(')')
        cmp = skip('cmp')
        skip('0')
        p.expect(')')
        return mkassert(kind, CMP[cmp], *args)

    def compare_plain():
        # assert(a cmp b)
        open_assert()
        lh = operand()
        cmp = skip('cmp')
        rh = operand()
        p.expect(')')
        return mkassert('int', CMP[cmp], lh, rh)

    attempts = (
        lambda: compare_call('memcmp', 'mem', 3),
        lambda: compare_call('strcmp', 'str', 2),
        compare_plain,
    )
    for attempt in attempts:
        try:
            return attempt()
        except ParseFailure:
            p.pop(start)

    # assert(a)? nothing catches a failure here
    open_assert()
    cond = p_exprs(p)
    p.accept('ws')
    p.expect(')')
    return mkassert('bool', 'eq', cond, 'true')
314
def p_expr(p):
    """Parse a single expression and return its text.

    Consumes parenthesized groups, nested asserts (rewritten through
    p_assert when they parse, otherwise kept verbatim), strings, 'op'
    tokens, whitespace and raw text, and stops at the first token that is
    none of these.

    Raises ParseFailure if a parenthesized group is not closed by ')'.
    """
    parts = []
    while True:
        opener = p.accept('(')
        if opener:
            # group contents: expression lists separated by 'sep' tokens
            parts.append(opener)
            parts.append(p_exprs(p))
            sep = p.accept('sep')
            while sep:
                parts.append(sep)
                parts.append(p_exprs(p))
                sep = p.accept('sep')
            parts.append(p.expect(')'))
            continue

        if p.lookahead('assert'):
            mark = p.push()
            try:
                parts.append(p_assert(p))
            except ParseFailure:
                # not a well-formed assert, keep the token as plain text
                p.pop(mark)
                parts.append(p.expect('assert'))
            continue

        piece = p.accept('string', 'op', 'ws', None)
        if not piece:
            return ''.join(parts)
        parts.append(piece)
338
def p_exprs(p):
    """Parse expressions joined by cmp, logic or ',' tokens.

    Returns the concatenated text of every expression and joining token.
    """
    parts = [p_expr(p)]
    joiner = p.accept('cmp', 'logic', ',')
    while joiner:
        parts.append(joiner)
        parts.append(p_expr(p))
        joiner = p.accept('cmp', 'logic', ',')
    return ''.join(parts)
347
def p_stmt(p):
    """Parse one statement and return its, possibly rewritten, text.

    Leading whitespace is preserved. The arrow forms
    memcmp(lh,rh,size) => 0, strcmp(lh,rh) => 0 and lh => rh become mem,
    str and int equality asserts respectively; any other statement is
    returned as parsed by p_exprs.
    """
    lead = p.accept('ws') or ''

    def skip(*patterns):
        # require a token, then swallow any whitespace following it
        p.expect(*patterns)
        p.accept('ws')

    def operand():
        # a single expression, then any whitespace following it
        expr = p_expr(p)
        p.accept('ws')
        return expr

    # func(arg, ...) => 0?
    for func, kind, argc in (('memcmp', 'mem', 3), ('strcmp', 'str', 2)):
        if not p.lookahead(func):
            continue
        mark = p.push()
        try:
            skip(func)
            skip('(')
            args = [operand()]
            for _ in range(argc - 1):
                skip(',')
                args.append(operand())
            skip(')')
            skip('=>')
            skip('0')
            return lead + mkassert(kind, 'eq', *args)
        except ParseFailure:
            p.pop(mark)

    # lh => rh?
    lh = p_exprs(p)
    if not p.accept('=>'):
        return lead + lh
    rh = p_exprs(p)
    return lead + mkassert('int', 'eq', lh, rh)
391
def main(input=None, output=None, pattern=(), limit=LIMIT):
    """Rewrite the asserts of one C file into their pretty variants.

    Arguments:
        input: path of the C source, None or '-' reads stdin. A #line
            directive naming the path is emitted when input is not None.
        output: path of the generated file, None or '-' writes stdout.
        pattern: extra regexes that are treated like the 'assert' lexeme.
        limit: byte limit forwarded to write_header.

    If a statement fails to parse, a warning is printed to stderr and the
    rest of the input, starting at that statement, is copied through
    unchanged. Returns None.
    """
    with openio(input or '-', 'r') as in_f:
        # create parser
        #
        # build a fresh list for 'assert', LEXEMES.copy() is shallow, so
        # extending in place would leak the patterns into LEXEMES itself
        lexemes = LEXEMES.copy()
        lexemes['assert'] = list(lexemes['assert']) + list(pattern)
        p = Parser(in_f, lexemes)

        with openio(output or '-', 'w') as f:
            def writeln(s=''):
                f.write(s)
                f.write('\n')
            # write_header emits through f.writeln
            f.writeln = writeln

            # write extra verbose asserts
            write_header(f, limit=limit)
            if input is not None:
                f.writeln("#line %d \"%s\"" % (1, input))

            # parse and write out stmt at a time
            stmt_start = p.push()
            try:
                while True:
                    stmt_start = p.push()
                    f.write(p_stmt(p))
                    if p.accept('sep'):
                        f.write(p.m)
                    else:
                        break
            except ParseFailure as e:
                # keep diagnostics out of the generated code, which may
                # itself be going to stdout
                print('warning: %s' % e, file=sys.stderr)
                # nothing of the failed statement has been written yet, so
                # rewind to its start or the consumed text would be lost
                p.pop(stmt_start)

            # pass whatever was not parsed through untouched
            for token in p.tokens[p.off:]:
                f.write(token[1])
424
425
if __name__ == "__main__":
    import argparse
    import sys
    # command line entry point, options mirror main's keyword arguments
    cli = argparse.ArgumentParser(
        description="Preprocessor that makes asserts easier to debug.",
        allow_abbrev=False)
    cli.add_argument('input', help="Input C file.")
    cli.add_argument('-o', '--output', required=True, help="Output C file.")
    cli.add_argument(
        '-p', '--pattern',
        action='append',
        help='Regex patterns to search for starting an assert statement. '
            'This implicitly includes "assert" and "=>".')
    cli.add_argument(
        '-l', '--limit',
        type=lambda x: int(x, 0),
        default=LIMIT,
        help="Maximum number of characters to display in strcmp and memcmp. "
            "Defaults to %r." % LIMIT)
    # forward only the options that were given so main's own defaults
    # apply to everything else
    options = vars(cli.parse_intermixed_args())
    supplied = {k: v for k, v in options.items() if v is not None}
    sys.exit(main(**supplied))
453