1#!/usr/bin/env python3
2#
3# Script to find stack usage at the function level. Will detect recursion and
4# report as infinite stack usage.
5#
6# Example:
7# ./scripts/stack.py lfs.ci lfs_util.ci -Slimit
8#
9# Copyright (c) 2022, The littlefs authors.
10# SPDX-License-Identifier: BSD-3-Clause
11#
12
import builtins
import collections as co
import csv
import itertools as it
import math as m
import os
import re
import sys
19
20
21
# integer fields
class Int(co.namedtuple('Int', 'x')):
    """An integer field value that may also be +/-infinity.

    Wraps a single value x, which is either an int or +/-m.inf. Strings are
    parsed with int(s, 0) (so '0x10' works), falling back to '∞'/'inf' with
    an optional sign. The formatting hooks (none/table for plain output,
    diff_none/diff_table/diff_diff for diff output, and ratio for relative
    change) are what table() calls to render a column.
    """
    __slots__ = ()
    def __new__(cls, x=0):
        if isinstance(x, Int):
            return x
        if isinstance(x, str):
            try:
                x = int(x, 0)
            except ValueError:
                # also accept +-∞ and +-inf; raw strings so \s and \+ reach
                # the regex engine without invalid-escape warnings
                if re.match(r'^\s*\+?\s*(?:∞|inf)\s*$', x):
                    x = m.inf
                elif re.match(r'^\s*-\s*(?:∞|inf)\s*$', x):
                    x = -m.inf
                else:
                    raise
        assert isinstance(x, int) or m.isinf(x), x
        return super().__new__(cls, x)

    def __str__(self):
        """Render as a decimal integer, or '∞'/'-∞' for infinities."""
        if self.x == m.inf:
            return '∞'
        elif self.x == -m.inf:
            return '-∞'
        else:
            return str(self.x)

    def __int__(self):
        """Return the underlying int; infinities cannot be converted."""
        assert not m.isinf(self.x)
        return self.x

    def __float__(self):
        return float(self.x)

    # column text for a missing value, and for this value, right-aligned
    # to 7 characters
    none = '%7s' % '-'
    def table(self):
        return '%7s' % (self,)

    # diff output uses the same formatting
    diff_none = '%7s' % '-'
    diff_table = table

    def diff_diff(self, other):
        """Format self - other as a signed, 7-wide column.

        Called unbound by table(), so either side may be None, in which
        case it counts as 0.
        """
        new = self.x if self else 0
        old = other.x if other else 0
        diff = new - old
        if diff == +m.inf:
            return '%7s' % '+∞'
        elif diff == -m.inf:
            return '%7s' % '-∞'
        else:
            return '%+7d' % diff

    def ratio(self, other):
        """Return the relative change (new-old)/old from other to self.

        Either side may be None (counts as 0). Both infinite gives 0.0, only
        self infinite gives +inf, only other infinite gives -inf, and a
        change from 0 to nonzero gives 1.0.
        """
        new = self.x if self else 0
        old = other.x if other else 0
        if m.isinf(new) and m.isinf(old):
            return 0.0
        elif m.isinf(new):
            return +m.inf
        elif m.isinf(old):
            return -m.inf
        elif not old and not new:
            return 0.0
        elif not old:
            return 1.0
        else:
            return (new-old) / old

    def __add__(self, other):
        return self.__class__(self.x + other.x)

    def __sub__(self, other):
        return self.__class__(self.x - other.x)

    def __mul__(self, other):
        return self.__class__(self.x * other.x)
99
# stack results
class StackResult(co.namedtuple('StackResult', [
        'file', 'function', 'frame', 'limit', 'children'])):
    """Stack usage of one function.

    file, function -- where the function is defined and its name.
    frame          -- stack frame of the function itself, as an Int.
    limit          -- stack limit as an Int; collect() computes it as the
                      frame plus the largest callee limit, or ∞ on recursion.
    children       -- set of (file, function) pairs of callees.
    """
    _by = ['file', 'function']
    # note this intentionally shadows namedtuple's own _fields, fold() and
    # table() use it to mean "the numeric fields"
    _fields = ['frame', 'limit']
    _sort = ['limit', 'frame']
    _types = {'frame': Int, 'limit': Int}

    __slots__ = ()
    def __new__(cls, file='', function='',
            frame=0, limit=0, children=None):
        # use a fresh set per instance instead of a shared mutable default
        if children is None:
            children = set()
        return super().__new__(cls, file, function,
            Int(frame), Int(limit),
            children)

    def __add__(self, other):
        """Merge two results: frames add, limits take the max, children
        are unioned."""
        return StackResult(self.file, self.function,
            self.frame + other.frame,
            max(self.limit, other.limit),
            self.children | other.children)
120
121
def openio(path, mode='r', buffering=-1):
    """Open path like open(), except '-' means a standard stream.

    With path '-', mode 'r' reads from stdin and any other mode writes to
    stdout. The stream's file descriptor is duplicated with os.dup first,
    so closing the returned file leaves the original stream open.
    """
    if path != '-':
        return open(path, mode, buffering)
    stream = sys.stdin if mode == 'r' else sys.stdout
    return os.fdopen(os.dup(stream.fileno()), mode, buffering)
131
def collect(ci_paths, *,
        sources=None,
        everything=False,
        **args):
    """Parse *.ci callgraph files (VCG format) into a list of StackResults.

    A graph node whose label holds the function name, its file, and
    "N bytes (type)" on separate lines becomes a function with an N-byte
    frame; each graph edge records a call from sourcename to targetname.

    ci_paths   -- paths of the *.ci files to parse.
    sources    -- if not None, only keep functions defined in one of these
                  files; otherwise only keep functions defined under the
                  current directory (unless everything is set).
    everything -- also keep functions whose names start with '__' and, when
                  sources is None, functions defined outside the current
                  directory.
    args       -- remaining options; only args['quiet'] is read here, to
                  silence warnings.

    Returns one StackResult per kept function, where limit is the frame
    plus the largest limit of any kept callee, or m.inf if recursion is
    reachable.
    """
    # parse the vcg format
    k_pattern = re.compile(r'([a-z]+)\s*:', re.DOTALL)
    v_pattern = re.compile(r'(?:"(.*?)"|([a-z]+))', re.DOTALL)
    def parse_vcg(rest):
        # parse a run of "key: value" / "key: { ... }" pairs, returning the
        # pairs and whatever text was left unparsed
        def parse_pairs(rest):
            node = []
            while True:
                rest = rest.lstrip()
                m_ = k_pattern.match(rest)
                if not m_:
                    return (node, rest)
                k, rest = m_.group(1), rest[m_.end(0):]

                rest = rest.lstrip()
                if rest.startswith('{'):
                    v, rest = parse_pairs(rest[1:])
                    # startswith so truncated input fails the assert
                    # instead of raising IndexError
                    assert rest.startswith('}'), "unexpected %r" % rest[0:1]
                    rest = rest[1:]
                else:
                    m_ = v_pattern.match(rest)
                    assert m_, "unexpected %r" % rest[0:1]
                    v, rest = m_.group(1) or m_.group(2), rest[m_.end(0):]
                node.append((k, v))

        node, rest = parse_pairs(rest)
        assert rest == '', "unexpected %r" % rest[0:1]
        return node

    # collect into functions, callgraph maps a node title to
    # (file, function, frame, set of target titles)
    callgraph = co.defaultdict(lambda: (None, None, 0, set()))
    f_pattern = re.compile(
        r'([^\\]*)\\n([^:]*)[^\\]*\\n([0-9]+) bytes \((.*)\)')
    for path in ci_paths:
        with open(path) as f:
            vcg = parse_vcg(f.read())
        for k, graph in vcg:
            if k != 'graph':
                continue
            for k, info in graph:
                if k == 'node':
                    info = dict(info)
                    m_ = f_pattern.match(info['label'])
                    if m_:
                        function, file, size, stack_type = m_.groups()
                        if (not args.get('quiet')
                                and 'static' not in stack_type
                                and 'bounded' not in stack_type):
                            print("warning: "
                                "found non-static stack for %s (%s, %s)" % (
                                function, stack_type, size))
                        # keep any edges already recorded for this node
                        _, _, _, targets = callgraph[info['title']]
                        callgraph[info['title']] = (
                            file, function, int(size), targets)
                elif k == 'edge':
                    info = dict(info)
                    _, _, _, targets = callgraph[info['sourcename']]
                    targets.add(info['targetname'])

    # filter functions and simplify their paths
    cwd = os.getcwd()
    source_paths = (
        {os.path.abspath(s) for s in sources}
        if sources is not None else None)
    callgraph_ = {}
    for source, (s_file, s_function, frame, targets) in callgraph.items():
        # an edge can name a source we never saw a stack-annotated node
        # for, without file/frame info we can't report it
        if s_function is None:
            if not args.get('quiet'):
                print("warning: no stack info found for %s" % source)
            continue
        # discard internal functions
        if not everything and s_function.startswith('__'):
            continue
        s_path = os.path.abspath(s_file)
        # ignore filtered sources
        if source_paths is not None and s_path not in source_paths:
            continue
        in_cwd = os.path.commonpath([cwd, s_path]) == cwd
        # default to only cwd
        if source_paths is None and not everything and not in_cwd:
            continue

        # simplify path
        s_file = os.path.relpath(s_file) if in_cwd else s_path
        callgraph_[source] = (s_file, s_function, frame, targets)
    callgraph = callgraph_

    # find maximum stack size recursively, this requires also detecting
    # cycles (in case of recursion), which are reported as m.inf
    #
    # results are memoized so shared callees are only walked once; this is
    # safe because a callee that reaches a function on the current path is
    # itself on a cycle, and so unbounded no matter how we got to it
    limits = {}
    on_path = set()
    def find_limit(source):
        if source not in callgraph:
            return 0
        if source in limits:
            return limits[source]
        _, _, frame, targets = callgraph[source]

        on_path.add(source)
        limit = 0
        for target in targets:
            if target in on_path:
                # found a cycle
                limit = m.inf
                break
            limit = max(limit, find_limit(target))
        on_path.discard(source)

        limits[source] = frame + limit
        return limits[source]

    # (file, function) of each callee we kept
    def find_children(targets):
        return {callgraph[t][:2] for t in targets if t in callgraph}

    # build results
    return [
        StackResult(s_file, s_function, frame,
            find_limit(source),
            find_children(targets))
        for source, (s_file, s_function, frame, targets)
            in callgraph.items()]
274
275
def fold(Result, results, *,
        by=None,
        defines=None,
        **_):
    """Filter results by defines, then merge results sharing the same `by`
    field values.

    Result  -- result class providing _by and _fields, whose instances
               support + for merging.
    results -- iterable of Result.
    by      -- fields to group by, defaults to Result._by; [] merges
               everything into a single result.
    defines -- optional list of (field, set of str) pairs; only results
               whose field value, converted with str(), is in the set are
               kept.

    Prints an error and calls sys.exit(-1) if by or defines name an unknown
    field. Returns the merged results in first-seen order.
    """
    if by is None:
        by = Result._by

    for k in it.chain(by or [], (k for k, _ in defines or [])):
        if k not in Result._by and k not in Result._fields:
            print("error: could not find field %r?" % k)
            sys.exit(-1)

    # filter by matching defines, defines come from the command line as
    # strings while fields may be other types (e.g. Int), so compare as str
    if defines is not None:
        results = [r for r in results
            if all(str(getattr(r, k)) in vs for k, vs in defines)]

    # organize results into conflicts
    folding = co.OrderedDict()
    for r in results:
        name = tuple(getattr(r, k) for k in by)
        folding.setdefault(name, []).append(r)

    # merge conflicts
    return [sum(rs[1:], start=rs[0]) for rs in folding.values()]
310
def table(Result, results, diff_results=None, *,
        by=None,
        fields=None,
        sort=None,
        summary=False,
        all=False,
        percent=False,
        tree=False,
        depth=1,
        **_):
    """Print results, optionally diffed against diff_results, as a table.

    Result       -- result class providing _by, _fields, _sort and _types.
    results      -- list of Result to show.
    diff_results -- if not None, previous results to diff against.
    by           -- fields rows are grouped/named by (default Result._by).
    fields       -- fields to show as columns (default Result._fields).
    sort         -- list of (field, reverse) pairs; a falsy field means the
                    Result._sort fields.
    summary      -- only print the TOTAL line.
    all          -- with a diff, also show rows whose fields did not change.
    percent      -- with a diff, show the new values plus percentages only.
    tree         -- only print the row names, as a call tree.
    depth        -- how many levels of children to print under each row;
                    m.inf has no limit, so recursive call graphs recurse
                    until Python's recursion limit.
    """
    # the `all` parameter shadows the builtin, so restore the builtin from
    # the builtins module; __builtins__ is only a module in __main__, it is
    # a dict when this file is imported
    all_, all = all, builtins.all

    if by is None:
        by = Result._by
    if fields is None:
        fields = Result._fields
    types = Result._types

    # fold again
    results = fold(Result, results, by=by)
    if diff_results is not None:
        diff_results = fold(Result, diff_results, by=by)

    # organize by name
    table = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in results}
    diff_table = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in diff_results or []}
    names = list(table.keys() | diff_table.keys())

    # sort again, now with diff info, note that python's sort is stable
    names.sort()
    if diff_results is not None:
        names.sort(key=lambda n: tuple(
            types[k].ratio(
                getattr(table.get(n), k, None),
                getattr(diff_table.get(n), k, None))
            for k in fields),
            reverse=True)
    if sort:
        for k, reverse in reversed(sort):
            names.sort(
                key=lambda n: tuple(
                    (getattr(table[n], k),)
                    if getattr(table.get(n), k, None) is not None else ()
                    for k in ([k] if k else [
                        k for k in Result._sort if k in fields])),
                reverse=reverse ^ (not k or k in Result._fields))


    # build up our lines
    lines = []

    # header
    header = []
    header.append('%s%s' % (
        ','.join(by),
        ' (%d added, %d removed)' % (
            sum(1 for n in table if n not in diff_table),
            sum(1 for n in diff_table if n not in table))
            if diff_results is not None and not percent else '')
        if not summary else '')
    if diff_results is None:
        for k in fields:
            header.append(k)
    elif percent:
        for k in fields:
            header.append(k)
    else:
        for k in fields:
            header.append('o'+k)
        for k in fields:
            header.append('n'+k)
        for k in fields:
            header.append('d'+k)
    header.append('')
    lines.append(header)

    # build one row: name, field columns, then a trailing ratio column
    def table_entry(name, r, diff_r=None, ratios=()):
        entry = []
        entry.append(name)
        if diff_results is None:
            for k in fields:
                entry.append(getattr(r, k).table()
                    if getattr(r, k, None) is not None
                    else types[k].none)
        elif percent:
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
        else:
            for k in fields:
                entry.append(getattr(diff_r, k).diff_table()
                    if getattr(diff_r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(types[k].diff_diff(
                        getattr(r, k, None),
                        getattr(diff_r, k, None)))
        if diff_results is None:
            entry.append('')
        elif percent:
            entry.append(' (%s)' % ', '.join(
                '+∞%' if t == +m.inf
                else '-∞%' if t == -m.inf
                else '%+.1f%%' % (100*t)
                for t in ratios))
        else:
            entry.append(' (%s)' % ', '.join(
                    '+∞%' if t == +m.inf
                    else '-∞%' if t == -m.inf
                    else '%+.1f%%' % (100*t)
                    for t in ratios
                    if t)
                if any(ratios) else '')
        return entry

    # entries, also indexed by name so the printer below finds the right
    # line for each name even when unchanged names were skipped
    line_table = {}
    if not summary:
        for name in names:
            r = table.get(name)
            if diff_results is None:
                diff_r = None
                ratios = None
            else:
                diff_r = diff_table.get(name)
                ratios = [
                    types[k].ratio(
                        getattr(r, k, None),
                        getattr(diff_r, k, None))
                    for k in fields]
                if not all_ and not any(ratios):
                    continue
            line = table_entry(name, r, diff_r, ratios)
            line_table[name] = line
            lines.append(line)

    # total
    r = next(iter(fold(Result, results, by=[])), None)
    if diff_results is None:
        diff_r = None
        ratios = None
    else:
        diff_r = next(iter(fold(Result, diff_results, by=[])), None)
        ratios = [
            types[k].ratio(
                getattr(r, k, None),
                getattr(diff_r, k, None))
            for k in fields]
    lines.append(table_entry('TOTAL', r, diff_r, ratios))

    # find the best widths, note that column 0 contains the names and column -1
    # the ratios, so those are handled a bit differently
    widths = [
        ((max(it.chain([w], (len(l[i]) for l in lines)))+1+4-1)//4)*4-1
        for w, i in zip(
            it.chain([23], it.repeat(7)),
            range(len(lines[0])-1))]

    # adjust the name width based on the expected call depth, though
    # note this doesn't really work with unbounded recursion
    if not summary and not m.isinf(depth):
        widths[0] += 4*(depth-1)

    # print the tree recursively
    if not tree:
        print('%-*s  %s%s' % (
            widths[0], lines[0][0],
            ' '.join('%*s' % (w, x)
                for w, x in zip(widths[1:], lines[0][1:-1])),
            lines[0][-1]))

    if not summary:
        def recurse(names_, depth_, prefixes=('', '', '', '')):
            # only names with a line get printed, filter them first so the
            # last printed name is the one drawn with the closing prefix
            names_ = [n for n in names_ if n in line_table]
            for i, name in enumerate(names_):
                line = line_table[name]
                is_last = (i == len(names_)-1)

                print('%s%-*s ' % (
                    prefixes[0+is_last],
                    widths[0] - (
                        len(prefixes[0+is_last])
                        if not m.isinf(depth) else 0),
                    line[0]),
                    end='')
                if not tree:
                    print(' %s%s' % (
                        ' '.join('%*s' % (w, x)
                            for w, x in zip(widths[1:], line[1:-1])),
                        line[-1]),
                        end='')
                print()

                # recurse?
                if name in table and depth_ > 1:
                    children = {
                        ','.join(str(getattr(Result(*c), k) or '') for k in by)
                        for c in table[name].children}
                    recurse(
                        # note we're maintaining sort order
                        [n for n in names if n in children],
                        depth_-1,
                        (prefixes[2+is_last] + "|-> ",
                         prefixes[2+is_last] + "'-> ",
                         prefixes[2+is_last] + "|   ",
                         prefixes[2+is_last] + "    "))

        recurse(names, depth)

    if not tree:
        print('%-*s  %s%s' % (
            widths[0], lines[-1][0],
            ' '.join('%*s' % (w, x)
                for w, x in zip(widths[1:], lines[-1][1:-1])),
            lines[-1][-1]))
535
536
def _read_stack_csv(path, *, skip_type_errors=False):
    """Load StackResults from a CSV file in the format main() writes with -o.

    Columns named after StackResult._by fields are read directly, and
    columns named 'stack_' + a StackResult._fields entry are read as that
    field. Empty cells are left to StackResult's defaults, and rows with no
    non-empty stack_* column are skipped. path may be '-' for stdin.

    A TypeError raised while building a row's StackResult is ignored when
    skip_type_errors is set; otherwise it propagates.
    """
    results = []
    with openio(path) as f:
        reader = csv.DictReader(f, restval='')
        for r in reader:
            if not any('stack_'+k in r and r['stack_'+k].strip()
                    for k in StackResult._fields):
                continue
            try:
                results.append(StackResult(
                    **{k: r[k] for k in StackResult._by
                        if k in r and r[k].strip()},
                    **{k: r['stack_'+k] for k in StackResult._fields
                        if 'stack_'+k in r and r['stack_'+k].strip()}))
            except TypeError:
                if not skip_type_errors:
                    raise
    return results

def main(ci_paths,
        by=None,
        fields=None,
        defines=None,
        sort=None,
        **args):
    """Command-line entry point: collect or load results, then report them.

    ci_paths -- *.ci files to parse; not used when args['use'] is set.
    by, fields, defines, sort -- grouping, shown fields, filters and sort
                order; see fold() and table().
    args     -- remaining options (use, output, diff, quiet, tree, depth,
                error_on_recursion, ...), also forwarded to collect() and
                table().

    Calls sys.exit(2) if args['error_on_recursion'] is set and any result
    has an infinite limit.
    """
    # it doesn't really make sense to not have a depth with tree,
    # so assume depth=inf if tree by default
    if args.get('depth') is None:
        args['depth'] = m.inf if args.get('tree') else 1
    elif args.get('depth') == 0:
        args['depth'] = m.inf

    # find sizes
    if not args.get('use', None):
        results = collect(ci_paths, **args)
    else:
        results = _read_stack_csv(args['use'], skip_type_errors=True)

    # fold
    results = fold(StackResult, results, by=by, defines=defines)

    # sort, note that python's sort is stable
    results.sort()
    if sort:
        for k, reverse in reversed(sort):
            results.sort(
                key=lambda r: tuple(
                    (getattr(r, k),) if getattr(r, k) is not None else ()
                    for k in ([k] if k else StackResult._sort)),
                reverse=reverse ^ (not k or k in StackResult._fields))

    # write results to CSV
    if args.get('output'):
        with openio(args['output'], 'w') as f:
            writer = csv.DictWriter(f,
                (by if by is not None else StackResult._by)
                + ['stack_'+k for k in (
                    fields if fields is not None else StackResult._fields)])
            writer.writeheader()
            for r in results:
                writer.writerow(
                    {k: getattr(r, k) for k in (
                        by if by is not None else StackResult._by)}
                    | {'stack_'+k: getattr(r, k) for k in (
                        fields if fields is not None else StackResult._fields)})

    # find previous results?
    if args.get('diff'):
        try:
            diff_results = _read_stack_csv(args['diff'])
        except FileNotFoundError:
            # a missing diff file means there are no previous results
            diff_results = []

        # fold
        diff_results = fold(StackResult, diff_results, by=by, defines=defines)

    # print table
    if not args.get('quiet'):
        table(StackResult, results,
            diff_results if args.get('diff') else None,
            by=by if by is not None else ['function'],
            fields=fields,
            sort=sort,
            **args)

    # error on recursion
    if args.get('error_on_recursion') and any(
            m.isinf(float(r.limit)) for r in results):
        sys.exit(2)
635
636
if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(
        description="Find stack usage at the function level.",
        allow_abbrev=False)
    parser.add_argument(
        'ci_paths',
        nargs='*',
        help="Input *.ci files.")
    # NOTE(review): verbose is accepted but nothing in this file reads it
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="Output commands that run behind the scenes.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't show anything, useful with -o.")
    parser.add_argument(
        '-o', '--output',
        help="Specify CSV file to store results.")
    parser.add_argument(
        '-u', '--use',
        help="Don't parse anything, use this CSV file.")
    parser.add_argument(
        '-d', '--diff',
        help="Specify CSV file to diff against.")
    parser.add_argument(
        '-a', '--all',
        action='store_true',
        help="Show all, not just the ones that changed.")
    parser.add_argument(
        '-p', '--percent',
        action='store_true',
        help="Only show percentage change, not a full diff.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        choices=StackResult._by,
        help="Group by this field.")
    parser.add_argument(
        '-f', '--field',
        dest='fields',
        action='append',
        choices=StackResult._fields,
        help="Show this field.")
    # -Dfield=a,b,c parses to (field, {'a', 'b', 'c'})
    parser.add_argument(
        '-D', '--define',
        dest='defines',
        action='append',
        type=lambda x: (lambda k,v: (k, set(v.split(','))))(*x.split('=', 1)),
        help="Only include results where this field is this value.")
    # -s/--sort and -S/--reverse-sort both append (field, reverse) pairs to
    # the one shared 'sort' list, preserving their command-line order
    class AppendSort(argparse.Action):
        def __call__(self, parser, namespace, value, option):
            if namespace.sort is None:
                namespace.sort = []
            # check which argument this is, not the exact option string,
            # so --reverse-sort reverses just like -S
            namespace.sort.append((value, self.dest == 'reverse_sort'))
    parser.add_argument(
        '-s', '--sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field.")
    parser.add_argument(
        '-S', '--reverse-sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field, but backwards.")
    parser.add_argument(
        '-Y', '--summary',
        action='store_true',
        help="Only show the total.")
    parser.add_argument(
        '-F', '--source',
        dest='sources',
        action='append',
        help="Only consider definitions in this file. Defaults to anything "
            "in the current directory.")
    parser.add_argument(
        '--everything',
        action='store_true',
        help="Include builtin and libc specific symbols.")
    parser.add_argument(
        '--tree',
        action='store_true',
        help="Only show the function call tree.")
    parser.add_argument(
        '-Z', '--depth',
        nargs='?',
        type=lambda x: int(x, 0),
        const=0,
        help="Depth of function calls to show. 0 shows all calls but may not "
            "terminate!")
    parser.add_argument(
        '-e', '--error-on-recursion',
        action='store_true',
        help="Error if any functions are recursive.")
    # options left as None are dropped so main()'s own defaults apply
    sys.exit(main(**{k: v
        for k, v in vars(parser.parse_intermixed_args()).items()
        if v is not None}))
736