1#!/usr/bin/env python3
2
3import re
4import sys
5
# Default lexeme alternatives that start an assert statement; main()
# replaces them when -p/--pattern is given.
PATTERN = ["LFS_ASSERT", "assert"]
# Prefix for generated helper names (lower-cased for functions,
# upper-cased for macros).
PREFIX = "LFS"
# Default --maxwidth: bytes shown by the mem/str printers.
MAXWIDTH = 16
9
# Name template for a generated assert macro (e.g. __LFS_ASSERT_INT_EQ);
# mkassert() appends the argument list and fills it with str.format().
ASSERT = "__{PREFIX}_ASSERT_{TYPE}_{COMP}"
# C template for the per-type failure handler. It prints
# "<file>:<line>:assert: assert failed with <lh>, expected <comp> <rh>"
# using the type's __{prefix}_assert_print_{type} helper, flushes all
# output streams, then raises SIGABRT. Braces are doubled because the
# template goes through str.format().
FAIL = """
__attribute__((unused))
static void __{prefix}_assert_fail_{type}(
        const char *file, int line, const char *comp,
        {ctype} lh, size_t lsize,
        {ctype} rh, size_t rsize) {{
    printf("%s:%d:assert: assert failed with ", file, line);
    __{prefix}_assert_print_{type}(lh, lsize);
    printf(", expected %s ", comp);
    __{prefix}_assert_print_{type}(rh, rsize);
    printf("\\n");
    fflush(NULL);
    raise(SIGABRT);
}}
"""
26
# C comparison operator -> suffix used in the generated macro names
# (e.g. '==' -> 'eq', giving __LFS_ASSERT_INT_EQ).
COMP = dict([
    ('==', 'eq'),
    ('!=', 'ne'),
    ('<=', 'le'),
    ('>=', 'ge'),
    ('<', 'lt'),
    ('>', 'gt'),
])
35
# C templates for each assert type, keyed by the type name used in the
# generated identifiers ('int', 'bool', 'mem', 'str'). Each entry has:
#   ctype  - C type of the operands handed to the print/fail helpers
#   fail   - failure handler template (the shared FAIL)
#   print  - helper that prints one operand; its size argument is the byte
#            count for 'mem'/'str' and ignored for 'int'/'bool'
#   assert - macro template, instantiated once per operator in COMP
# These are str.format() templates (literal braces are doubled). mkdecls()
# collapses each one onto a single line, which is why the multi-line
# #define bodies need no backslash continuations.
TYPE = {
    # Integer-like comparison: rh is converted to lh's type before
    # comparing, and both are widened to intmax_t for printing.
    'int': {
        'ctype': 'intmax_t',
        'fail': FAIL,
        'print': """
        __attribute__((unused))
        static void __{prefix}_assert_print_{type}({ctype} v, size_t size) {{
            (void)size;
            printf("%"PRIiMAX, v);
        }}
        """,
        # rh is parenthesized so the cast converts the whole operand: for a
        # uint8_t lh, "a / b" must become (uint8_t)(a / b), not the
        # ((uint8_t)a) / b that an unparenthesized cast produces.
        'assert': """
        #define __{PREFIX}_ASSERT_{TYPE}_{COMP}(file, line, lh, rh)
        do {{
            __typeof__(lh) _lh = (lh);
            __typeof__(lh) _rh = (__typeof__(lh))(rh);
            if (!(_lh {op} _rh)) {{
                __{prefix}_assert_fail_{type}(file, line, "{comp}",
                        (intmax_t)_lh, 0, (intmax_t)_rh, 0);
            }}
        }} while (0)
        """
    },
    # Truthiness comparison: both operands are normalized to 0/1 with !!.
    # passert() uses this for plain assert(expr), compared against true.
    'bool': {
        'ctype': 'bool',
        'fail': FAIL,
        'print': """
        __attribute__((unused))
        static void __{prefix}_assert_print_{type}({ctype} v, size_t size) {{
            (void)size;
            printf("%s", v ? "true" : "false");
        }}
        """,
        'assert': """
        #define __{PREFIX}_ASSERT_{TYPE}_{COMP}(file, line, lh, rh)
        do {{
            bool _lh = !!(lh);
            bool _rh = !!(rh);
            if (!(_lh {op} _rh)) {{
                __{prefix}_assert_fail_{type}(file, line, "{comp}",
                        _lh, 0, _rh, 0);
            }}
        }} while (0)
        """
    },
    # memcmp() of size bytes. The printer shows at most {maxwidth} bytes,
    # escaping non-printable ones as \xNN and appending "..." if truncated.
    'mem': {
        'ctype': 'const void *',
        'fail': FAIL,
        'print': """
        __attribute__((unused))
        static void __{prefix}_assert_print_{type}({ctype} v, size_t size) {{
            const uint8_t *s = v;
            printf("\\\"");
            for (size_t i = 0; i < size && i < {maxwidth}; i++) {{
                if (s[i] >= ' ' && s[i] <= '~') {{
                    printf("%c", s[i]);
                }} else {{
                    printf("\\\\x%02x", s[i]);
                }}
            }}
            if (size > {maxwidth}) {{
                printf("...");
            }}
            printf("\\\"");
        }}
        """,
        'assert': """
        #define __{PREFIX}_ASSERT_{TYPE}_{COMP}(file, line, lh, rh, size)
        do {{
            const void *_lh = lh;
            const void *_rh = rh;
            if (!(memcmp(_lh, _rh, size) {op} 0)) {{
                __{prefix}_assert_fail_{type}(file, line, "{comp}",
                        _lh, size, _rh, size);
            }}
        }} while (0)
        """
    },
    # strcmp() of NUL-terminated strings. Printing reuses the mem printer
    # with strlen() as the size, so 'mem' must be declared first; mkdecls()
    # emits types in sorted order, which puts 'mem' before 'str'.
    'str': {
        'ctype': 'const char *',
        'fail': FAIL,
        'print': """
        __attribute__((unused))
        static void __{prefix}_assert_print_{type}({ctype} v, size_t size) {{
            __{prefix}_assert_print_mem(v, size);
        }}
        """,
        'assert': """
        #define __{PREFIX}_ASSERT_{TYPE}_{COMP}(file, line, lh, rh)
        do {{
            const char *_lh = lh;
            const char *_rh = rh;
            if (!(strcmp(_lh, _rh) {op} 0)) {{
                __{prefix}_assert_fail_{type}(file, line, "{comp}",
                        _lh, strlen(_lh), _rh, strlen(_rh));
            }}
        }} while (0)
        """
    }
}
136
def mkdecls(outf, maxwidth=MAXWIDTH):
    """Write the C helper declarations used by the rewritten asserts.

    Emits the headers the helpers need, then, for every entry in TYPE (in
    sorted order), its print helper, its failure handler and one assert
    macro per operator in COMP. Each template is collapsed onto a single
    output line, so the multi-line ``#define`` bodies stay valid without
    backslash continuations.

    :param outf: writable text stream that receives the C code.
    :param maxwidth: maximum number of bytes the mem/str printers show
        before eliding the rest with "...".
    """
    for header in ('stdio.h', 'stdbool.h', 'stdint.h', 'inttypes.h',
                   'signal.h'):
        outf.write("#include <%s>\n" % header)

    whitespace = re.compile(r'\s+')

    def emit(template, fields):
        # One output line per template (see docstring).
        outf.write(whitespace.sub(' ', template.strip().format(**fields))
                   + '\n')

    for type_name, desc in sorted(TYPE.items()):
        fields = {
            'type': type_name.lower(), 'TYPE': type_name.upper(),
            'ctype': desc['ctype'],
            'prefix': PREFIX.lower(), 'PREFIX': PREFIX.upper(),
            'maxwidth': maxwidth,
        }
        emit(desc['print'], fields)
        emit(desc['fail'], fields)

        for op, comp in sorted(COMP.items()):
            emit(desc['assert'], dict(fields,
                comp=comp.lower(), COMP=comp.upper(), op=op))
163
def mkassert(type, comp, lh, rh, size=None):
    """Return C source calling the generated assert macro for type/comp.

    The call has the form ``__LFS_ASSERT_<TYPE>_<COMP>(__FILE__, __LINE__,
    lh, rh[, size])``; ``size`` is appended only when it is truthy (the
    'mem' macros take it).
    """
    macro = ASSERT.format(
        PREFIX=PREFIX.upper(), TYPE=type.upper(), COMP=comp.upper())
    # Only spaces are trimmed; newlines inside the operands are kept,
    # presumably so the rewritten text spans the same number of lines.
    arguments = ['__FILE__', '__LINE__', lh.strip(' '), rh.strip(' ')]
    if size:
        arguments.append(str(size))
    return '%s(%s)' % (macro, ', '.join(arguments))
179
180
# Simple recursive descent parser.
#
# Lexeme classes for the tokenizer, each a list of regex alternatives. Parse
# joins them, in this order, into one alternation, so an earlier class wins
# when several could match at the same position. Text matching no class
# becomes an untyped (None) token. Patterns are raw strings: sequences like
# '\(' are not valid Python escapes.
LEX = {
    # whitespace, cpp line markers / directives and comments
    'ws':       [r'(?:\s|\n|#.*?\n|//.*?\n|/\*.*?\*/)+'],
    'assert':   PATTERN,
    'string':   [r'"(?:\\.|[^"])*"', r"'(?:\\.|[^'])\'"],
    'arrow':    ['=>'],
    'paren':    [r'\(', r'\)'],
    # '->' is lexed here so that its '>' is never taken as a comparison
    'op':       ['strcmp', 'memcmp', '->'],
    'comp':     ['==', '!=', '<=', '>=', '<', '>'],
    'logic':    [r'\&\&', r'\|\|'],
    'sep':      [':', ';', r'\{', r'\}', ','],
}
193
class ParseFailure(Exception):
    """Raised when the parser meets a token it did not expect.

    ``expected`` holds the patterns that would have been accepted and
    ``found`` what was there instead (Parse.expect passes the remaining
    tokens); ``str()`` abbreviates ``found`` to 70 characters of its repr.
    """

    def __init__(self, expected, found):
        self.expected, self.found = expected, found

    def __str__(self):
        snippet = repr(self.found)[:70]
        return "expected {!r}, found {}...".format(self.expected, snippet)
202
class Parse:
    """Token stream with backtracking for the recursive descent parser.

    The input is split into ``(kind, text)`` tuples, where ``kind`` is the
    name of the matching lexeme class, or ``None`` for text between matches.
    The last token is always an untyped remainder (possibly ``''``), so
    joining every token's text reproduces the input exactly.

    ``self.m`` holds the text of the most recent lookahead/accept match (or
    ``None``); the parsing functions read it after a successful accept.
    """

    def __init__(self, inf, lexemes):
        pattern = re.compile(
            '|'.join('(?P<%s>%s)' % (n, '|'.join(l))
                for n, l in lexemes.items()),
            re.DOTALL)
        data = inf.read()
        tokens = []
        pos = 0
        while True:
            # Search from an offset rather than re-slicing the remaining
            # input after every token, which was quadratic in file size.
            m = pattern.search(data, pos)
            if not m:
                tokens.append((None, data[pos:]))
                break
            if m.start() == m.end():
                # A zero-width match (e.g. from a user-supplied -p pattern)
                # would never advance; fail loudly instead of hanging.
                raise ValueError(
                    "lexeme %r matched an empty string" % m.lastgroup)
            if m.start() > pos:
                tokens.append((None, data[pos:m.start()]))
            tokens.append((m.lastgroup, m.group()))
            pos = m.end()
        self.tokens = tokens
        self.off = 0
        self.m = None

    def lookahead(self, *pattern):
        """Peek at the current token without consuming it.

        Matches when the token's kind or its exact text is in ``pattern``.
        Returns the token text (also stored in ``self.m``), else ``None``.
        """
        if self.off < len(self.tokens):
            token = self.tokens[self.off]
            if token[0] in pattern or token[1] in pattern:
                self.m = token[1]
                return self.m
        self.m = None
        return self.m

    def accept(self, *patterns):
        """Like lookahead(), but consume the token when it matches."""
        m = self.lookahead(*patterns)
        if m is not None:
            self.off += 1
        return m

    def expect(self, *patterns):
        """Like accept(), but raise ParseFailure when nothing matches."""
        m = self.accept(*patterns)
        # 'is None' (as in accept) so an accepted empty token is not
        # reported as a failure after the offset has already advanced.
        if m is None:
            raise ParseFailure(patterns, self.tokens[self.off:])
        return m

    def push(self):
        """Return a backtracking checkpoint (the current token offset)."""
        return self.off

    def pop(self, state):
        """Rewind to a checkpoint previously returned by push()."""
        self.off = state
249
def passert(p):
    """Parse one assert statement and return its rewritten C text.

    Tries, in order, ``assert(strcmp(a, b) <cmp> 0)``,
    ``assert(memcmp(a, b, n) <cmp> 0)``, ``assert(a <cmp> b)`` and
    ``assert(expr)``, each rewritten to the matching generated macro. If
    none matches, only the assert keyword itself is consumed and returned
    unchanged.

    The rewrite drops the whitespace around fixed syntax (after ``(``,
    after the comparison, ...). Newlines in that whitespace are appended
    after the result, so the output keeps the input's line count and the
    ``#line`` mapping written by main() stays accurate.

    Raises ParseFailure only if the current token is not an assert.
    """
    def pastr(p):
        p.expect('assert') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        p.expect('strcmp') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        lh = pexpr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        rh = pexpr(p) ; p.accept('ws')
        p.expect(')') ; p.accept('ws')
        comp = p.expect('comp') ; p.accept('ws')
        p.expect('0') ; p.accept('ws')
        p.expect(')')
        return mkassert('str', COMP[comp], lh, rh)

    def pamem(p):
        p.expect('assert') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        p.expect('memcmp') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        lh = pexpr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        rh = pexpr(p) ; p.accept('ws')
        p.expect(',') ; p.accept('ws')
        size = pexpr(p) ; p.accept('ws')
        p.expect(')') ; p.accept('ws')
        comp = p.expect('comp') ; p.accept('ws')
        p.expect('0') ; p.accept('ws')
        p.expect(')')
        return mkassert('mem', COMP[comp], lh, rh, size)

    def paint(p):
        p.expect('assert') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        lh = pexpr(p) ; p.accept('ws')
        comp = p.expect('comp') ; p.accept('ws')
        rh = pexpr(p) ; p.accept('ws')
        p.expect(')')
        return mkassert('int', COMP[comp], lh, rh)

    def pabool(p):
        p.expect('assert') ; p.accept('ws') ; p.expect('(') ; p.accept('ws')
        lh = pexprs(p) ; p.accept('ws')
        p.expect(')')
        return mkassert('bool', 'eq', lh, 'true')

    def pa(p):
        # fallback: leave an unrecognized assert untouched
        return p.expect('assert')

    state = p.push()
    failure = None
    for attempt in (pastr, pamem, paint, pabool, pa):
        try:
            result = attempt(p)
        except ParseFailure as f:
            p.pop(state)
            failure = f
            continue
        # Re-add newlines that were in whitespace dropped by the rewrite.
        consumed = ''.join(text for _, text in p.tokens[state:p.off])
        lost = consumed.count('\n') - result.count('\n')
        return result + '\n' * max(lost, 0)
    raise failure
304
def pexpr(p):
    """Consume a run of expression tokens and return their text.

    Parenthesized groups are consumed whole (their contents parsed as
    expression lists split by separator tokens), nested asserts are
    rewritten through passert(), and whitespace, strings, 'op' tokens and
    untyped text are copied through. Stops, without consuming it, at the
    first token that fits none of these (e.g. a comparison, a logic
    operator, a separator or ')').
    """
    parts = []
    while True:
        if p.accept('('):
            parts.append(p.m)
            parts.append(pexprs(p))
            while p.accept('sep'):
                parts.append(p.m)
                parts.append(pexprs(p))
            parts.append(p.expect(')'))
            continue
        if p.lookahead('assert'):
            parts.append(passert(p))
            continue
        if p.accept('assert', 'ws', 'string', 'op', None):
            parts.append(p.m)
            continue
        return ''.join(parts)
323
def pexprs(p):
    """Parse expressions joined by comparison, logic or ',' tokens.

    Returns the concatenated text of everything consumed.
    """
    pieces = [pexpr(p)]
    while p.accept('comp', 'logic', ','):
        pieces.append(p.m)
        pieces.append(pexpr(p))
    return ''.join(pieces)
332
def pstmt(p):
    """Parse one statement and return its (possibly rewritten) text.

    Leading whitespace is kept as-is. A statement of the form ``lh => rh``
    becomes an int equality assert; anything else is returned as parsed,
    with any nested asserts already rewritten by pexpr().
    """
    leading = p.accept('ws') or ''
    lh = pexprs(p)
    if not p.accept('=>'):
        return leading + lh
    return leading + mkassert('int', 'eq', lh, pexprs(p))
341
342
def main(args):
    """Rewrite the asserts in a preprocessed C file to be more verbose.

    Reads ``args.input`` (or stdin) and writes to ``args.output`` (or
    stdout): first the helper declarations from mkdecls(), then the input
    with every recognized assert rewritten. Parsing is best effort; on a
    parse failure the rest of the input is copied through verbatim.
    """
    lexemes = LEX.copy()
    if args.pattern:
        lexemes['assert'] = args.pattern

    # Tokenize the whole input before opening the output, so that using
    # the same path for both cannot truncate the input before it is read.
    inf = open(args.input, 'r') if args.input else sys.stdin
    try:
        p = Parse(inf, lexemes)
    finally:
        if inf is not sys.stdin:
            inf.close()

    outf = open(args.output, 'w') if args.output else sys.stdout
    try:
        # write extra verbose asserts
        mkdecls(outf, maxwidth=args.maxwidth)
        if args.input:
            # the file name is a C string literal: escape '\' and '"'
            name = args.input.replace('\\', '\\\\').replace('"', '\\"')
            outf.write("#line %d \"%s\"\n" % (1, name))

        # parse and write out stmt at a time
        try:
            while True:
                state = p.push()
                outf.write(pstmt(p))
                if p.accept('sep'):
                    outf.write(p.m)
                else:
                    break
        except ParseFailure:
            # Nothing of the failed statement has been written yet; rewind
            # to its start so its tokens are copied below instead of lost.
            p.pop(state)

        for _, text in p.tokens[p.off:]:
            outf.write(text)
    finally:
        if outf is not sys.stdout:
            outf.close()
370
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        description="Cpp step that increases assert verbosity")
    # (flags, keyword arguments) for each command-line option
    options = [
        (('input',), dict(
            nargs='?',
            help="Input C file after cpp.")),
        (('-o', '--output'), dict(
            required=True,
            help="Output C file.")),
        (('-p', '--pattern'), dict(
            action='append',
            help="Patterns to search for starting an assert statement.")),
        (('--maxwidth',), dict(
            default=MAXWIDTH, type=int,
            help="Maximum number of characters to display for strcmp and memcmp.")),
    ]
    for flags, settings in options:
        cli.add_argument(*flags, **settings)
    main(cli.parse_args())
384