#!/usr/bin/env python3
#
# Script to find data size at the function level. Basically just a big wrapper
# around nm with some extra conveniences for comparing builds. Heavily inspired
# by Linux's Bloat-O-Meter.
#
# Example:
# ./scripts/data.py lfs.o lfs_util.o -Ssize
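# or, to compare against a previously saved CSV (illustrative filename):
# ./scripts/data.py lfs.o lfs_util.o -d previous.csv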
#
# Copyright (c) 2022, The littlefs authors.
# Copyright (c) 2020, Arm Limited. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#

import collections as co
import csv
import difflib
import itertools as it
import math as m
import os
import re
import shlex
import subprocess as sp
import sys


NM_PATH = ['nm']
NM_TYPES = 'dDbB'
OBJDUMP_PATH = ['objdump']


# integer fields
class Int(co.namedtuple('Int', 'x')):
    __slots__ = ()
    def __new__(cls, x=0):
        if isinstance(x, Int):
            return x
        if isinstance(x, str):
            try:
                x = int(x, 0)
            except ValueError:
                # also accept +-∞ and +-inf
                if re.match(r'^\s*\+?\s*(?:∞|inf)\s*$', x):
                    x = m.inf
                elif re.match(r'^\s*-\s*(?:∞|inf)\s*$', x):
                    x = -m.inf
                else:
                    raise
        assert isinstance(x, int) or m.isinf(x), x
        return super().__new__(cls, x)

    def __str__(self):
        if self.x == m.inf:
            return '∞'
        elif self.x == -m.inf:
            return '-∞'
        else:
            return str(self.x)

    def __int__(self):
        assert not m.isinf(self.x)
        return self.x

    def __float__(self):
        return float(self.x)

    none = '%7s' % '-'
    def table(self):
        return '%7s' % (self,)

    diff_none = '%7s' % '-'
    diff_table = table

    def diff_diff(self, other):
        new = self.x if self else 0
        old = other.x if other else 0
        diff = new - old
        if diff == +m.inf:
            return '%7s' % '+∞'
        elif diff == -m.inf:
            return '%7s' % '-∞'
        else:
            return '%+7d' % diff

    def ratio(self, other):
        new = self.x if self else 0
        old = other.x if other else 0
        if m.isinf(new) and m.isinf(old):
            return 0.0
        elif m.isinf(new):
            return +m.inf
        elif m.isinf(old):
            return -m.inf
        elif not old and not new:
            return 0.0
        elif not old:
            return 1.0
        else:
            return (new-old) / old

    def __add__(self, other):
        return self.__class__(self.x + other.x)

    def __sub__(self, other):
        return self.__class__(self.x - other.x)

    def __mul__(self, other):
        return self.__class__(self.x * other.x)
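
# a rough illustration of how Int behaves (illustrative, not used by the
# script itself):
#   Int('0x10') + Int('2') == Int(18)
#   Int('∞').diff_diff(Int(0)) renders as '+∞' padded to the table width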

# data size results
class DataResult(co.namedtuple('DataResult', [
        'file', 'function',
        'size'])):
    _by = ['file', 'function']
    _fields = ['size']
    _sort = ['size']
    _types = {'size': Int}

    __slots__ = ()
    def __new__(cls, file='', function='', size=0):
        return super().__new__(cls, file, function,
            Int(size))

    def __add__(self, other):
        return DataResult(self.file, self.function,
            self.size + other.size)


def openio(path, mode='r', buffering=-1):
    # allow '-' for stdin/stdout
    if path == '-':
        if mode == 'r':
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)

def collect(obj_paths, *,
        nm_path=NM_PATH,
        nm_types=NM_TYPES,
        objdump_path=OBJDUMP_PATH,
        sources=None,
        everything=False,
        **args):
    size_pattern = re.compile(
        '^(?P<size>[0-9a-fA-F]+)' +
        ' (?P<type>[%s])' % re.escape(nm_types) +
        ' (?P<func>.+?)$')
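    # a line of `nm --size-sort` output looks roughly like (illustrative):
    #   00000004 b lfs_crc_table
    # i.e. a hex size, a single-character symbol type, and the symbol name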
    line_pattern = re.compile(
        r'^\s+(?P<no>[0-9]+)'
            r'(?:\s+(?P<dir>[0-9]+))?'
            r'\s+.*'
            r'\s+(?P<path>[^\s]+)$')
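    # this matches directory/file entries in `objdump --dwarf=rawline` output,
    # which look roughly like (illustrative):
    #   1     /usr/include
    #   2     1     0     0     lfs.c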
    info_pattern = re.compile(
        r'^(?:.*(?P<tag>DW_TAG_[a-z_]+).*'
            r'|.*DW_AT_name.*:\s*(?P<name>[^:\s]+)\s*'
            r'|.*DW_AT_decl_file.*:\s*(?P<file>[0-9]+)\s*)$')
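    # this matches the tag/attribute lines in `objdump --dwarf=info` output,
    # which look roughly like (illustrative):
    #   <1><2a>: Abbrev Number: 5 (DW_TAG_subprogram)
    #      <2b>   DW_AT_name        : lfs_bd_read
    #      <30>   DW_AT_decl_file   : 1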

    results = []
    for path in obj_paths:
        # guess the source file; if we have debug-info we'll replace this later
        file = re.sub(r'(\.o)?$', '.c', path, count=1)

        # find symbol sizes
        results_ = []
        # note nm-path may contain extra args
        cmd = nm_path + ['--size-sort', path]
        if args.get('verbose'):
            print(' '.join(shlex.quote(c) for c in cmd))
        proc = sp.Popen(cmd,
            stdout=sp.PIPE,
            stderr=sp.PIPE if not args.get('verbose') else None,
            universal_newlines=True,
            errors='replace',
            close_fds=False)
        for line in proc.stdout:
            m = size_pattern.match(line)
            if m:
                func = m.group('func')
                # discard internal functions
                if not everything and func.startswith('__'):
                    continue
                results_.append(DataResult(
                    file, func,
                    int(m.group('size'), 16)))
        proc.wait()
        if proc.returncode != 0:
            if not args.get('verbose'):
                for line in proc.stderr:
                    sys.stdout.write(line)
            sys.exit(-1)


        # try to figure out the source file if we have debug-info
        dirs = {}
        files = {}
        # note objdump-path may contain extra args
        cmd = objdump_path + ['--dwarf=rawline', path]
        if args.get('verbose'):
            print(' '.join(shlex.quote(c) for c in cmd))
        proc = sp.Popen(cmd,
            stdout=sp.PIPE,
            stderr=sp.PIPE if not args.get('verbose') else None,
            universal_newlines=True,
            errors='replace',
            close_fds=False)
        for line in proc.stdout:
            # note that file entries reference dir entries, which we
            # dereference as soon as we see them, since each file table
            # follows its dir table
            m = line_pattern.match(line)
            if m:
                if not m.group('dir'):
                    # found a directory entry
                    dirs[int(m.group('no'))] = m.group('path')
                else:
                    # found a file entry
                    dir = int(m.group('dir'))
                    if dir in dirs:
                        files[int(m.group('no'))] = os.path.join(
                            dirs[dir],
                            m.group('path'))
                    else:
                        files[int(m.group('no'))] = m.group('path')
        proc.wait()
        if proc.returncode != 0:
            if not args.get('verbose'):
                for line in proc.stderr:
                    sys.stdout.write(line)
            # do nothing on error, we don't need objdump to work, source files
            # may just be inaccurate
            pass

        defs = {}
        is_func = False
        f_name = None
        f_file = None
        # note objdump-path may contain extra args
        cmd = objdump_path + ['--dwarf=info', path]
        if args.get('verbose'):
            print(' '.join(shlex.quote(c) for c in cmd))
        proc = sp.Popen(cmd,
            stdout=sp.PIPE,
            stderr=sp.PIPE if not args.get('verbose') else None,
            universal_newlines=True,
            errors='replace',
            close_fds=False)
        for line in proc.stdout:
            # state machine here to find definitions
            m = info_pattern.match(line)
            if m:
                if m.group('tag'):
                    if is_func:
                        defs[f_name] = files.get(f_file, '?')
                    is_func = (m.group('tag') == 'DW_TAG_subprogram')
                elif m.group('name'):
                    f_name = m.group('name')
                elif m.group('file'):
                    f_file = int(m.group('file'))
        if is_func:
            defs[f_name] = files.get(f_file, '?')
        proc.wait()
        if proc.returncode != 0:
            if not args.get('verbose'):
                for line in proc.stderr:
                    sys.stdout.write(line)
            # do nothing on error, we don't need objdump to work, source files
            # may just be inaccurate
            pass

        for r in results_:
            # find the best matching debug symbol; the symbol name may differ
            # slightly from the debug-info name due to optimizations
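            # (e.g. an optimized clone named "lfs_bd_read.constprop.0" would
            # still fuzzy-match the "lfs_bd_read" debug entry; illustrative)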
            if defs:
                # exact match? avoid difflib if we can for speed
                if r.function in defs:
                    file = defs[r.function]
                else:
                    _, file = max(
                        defs.items(),
                        key=lambda d: difflib.SequenceMatcher(None,
                            d[0],
                            r.function, False).ratio())
            else:
                file = r.file

            # ignore filtered sources
            if sources is not None:
                if not any(
                        os.path.abspath(file) == os.path.abspath(s)
                        for s in sources):
                    continue
            else:
                # default to only cwd
                if not everything and not os.path.commonpath([
                        os.getcwd(),
                        os.path.abspath(file)]) == os.getcwd():
                    continue

            # simplify path
            if os.path.commonpath([
                    os.getcwd(),
                    os.path.abspath(file)]) == os.getcwd():
                file = os.path.relpath(file)
            else:
                file = os.path.abspath(file)

            results.append(r._replace(file=file))

    return results


def fold(Result, results, *,
        by=None,
        defines=None,
        **_):
    if by is None:
        by = Result._by

    for k in it.chain(by or [], (k for k, _ in defines or [])):
        if k not in Result._by and k not in Result._fields:
            print("error: could not find field %r?" % k)
            sys.exit(-1)

    # filter by matching defines
    if defines is not None:
        results_ = []
        for r in results:
            if all(getattr(r, k) in vs for k, vs in defines):
                results_.append(r)
        results = results_

    # organize results into conflicts
    folding = co.OrderedDict()
    for r in results:
        name = tuple(getattr(r, k) for k in by)
        if name not in folding:
            folding[name] = []
        folding[name].append(r)

    # merge conflicts
    folded = []
    for name, rs in folding.items():
        folded.append(sum(rs[1:], start=rs[0]))

    return folded
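
# for example (illustrative), per-function results can be folded into
# per-file totals with:
#   fold(DataResult, results, by=['file'])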

def table(Result, results, diff_results=None, *,
        by=None,
        fields=None,
        sort=None,
        summary=False,
        all=False,
        percent=False,
        **_):
    all_, all = all, __builtins__.all

    if by is None:
        by = Result._by
    if fields is None:
        fields = Result._fields
    types = Result._types

    # fold again
    results = fold(Result, results, by=by)
    if diff_results is not None:
        diff_results = fold(Result, diff_results, by=by)

    # organize by name
    table = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in results}
    diff_table = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in diff_results or []}
    names = list(table.keys() | diff_table.keys())

    # sort again, now with diff info, note that python's sort is stable
    names.sort()
    if diff_results is not None:
        names.sort(key=lambda n: tuple(
            types[k].ratio(
                getattr(table.get(n), k, None),
                getattr(diff_table.get(n), k, None))
            for k in fields),
            reverse=True)
    if sort:
        for k, reverse in reversed(sort):
            names.sort(
                key=lambda n: tuple(
                    (getattr(table[n], k),)
                    if getattr(table.get(n), k, None) is not None else ()
                    for k in ([k] if k else [
                        k for k in Result._sort if k in fields])),
                reverse=reverse ^ (not k or k in Result._fields))


    # build up our lines
    lines = []

    # header
    header = []
    header.append('%s%s' % (
        ','.join(by),
        ' (%d added, %d removed)' % (
            sum(1 for n in table if n not in diff_table),
            sum(1 for n in diff_table if n not in table))
            if diff_results is not None and not percent else '')
        if not summary else '')
    if diff_results is None:
        for k in fields:
            header.append(k)
    elif percent:
        for k in fields:
            header.append(k)
    else:
        for k in fields:
            header.append('o'+k)
        for k in fields:
            header.append('n'+k)
        for k in fields:
            header.append('d'+k)
    header.append('')
    lines.append(header)

    def table_entry(name, r, diff_r=None, ratios=[]):
        entry = []
        entry.append(name)
        if diff_results is None:
            for k in fields:
                entry.append(getattr(r, k).table()
                    if getattr(r, k, None) is not None
                    else types[k].none)
        elif percent:
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
        else:
            for k in fields:
                entry.append(getattr(diff_r, k).diff_table()
                    if getattr(diff_r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(types[k].diff_diff(
                        getattr(r, k, None),
                        getattr(diff_r, k, None)))
        if diff_results is None:
            entry.append('')
        elif percent:
            entry.append(' (%s)' % ', '.join(
                '+∞%' if t == +m.inf
                else '-∞%' if t == -m.inf
                else '%+.1f%%' % (100*t)
                for t in ratios))
        else:
            entry.append(' (%s)' % ', '.join(
                    '+∞%' if t == +m.inf
                    else '-∞%' if t == -m.inf
                    else '%+.1f%%' % (100*t)
                    for t in ratios
                    if t)
                if any(ratios) else '')
        return entry

    # entries
    if not summary:
        for name in names:
            r = table.get(name)
            if diff_results is None:
                diff_r = None
                ratios = None
            else:
                diff_r = diff_table.get(name)
                ratios = [
                    types[k].ratio(
                        getattr(r, k, None),
                        getattr(diff_r, k, None))
                    for k in fields]
                if not all_ and not any(ratios):
                    continue
            lines.append(table_entry(name, r, diff_r, ratios))

    # total
    r = next(iter(fold(Result, results, by=[])), None)
    if diff_results is None:
        diff_r = None
        ratios = None
    else:
        diff_r = next(iter(fold(Result, diff_results, by=[])), None)
        ratios = [
            types[k].ratio(
                getattr(r, k, None),
                getattr(diff_r, k, None))
            for k in fields]
    lines.append(table_entry('TOTAL', r, diff_r, ratios))

    # find the best widths, note that column 0 contains the names and column -1
    # the ratios, so those are handled a bit differently
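    # (each column is padded so that its width+1 is a multiple of 4, with
    # minimum widths of 23 chars for the name column and 7 chars for the
    # field columns)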
    widths = [
        ((max(it.chain([w], (len(l[i]) for l in lines)))+1+4-1)//4)*4-1
        for w, i in zip(
            it.chain([23], it.repeat(7)),
            range(len(lines[0])-1))]

    # print our table
    for line in lines:
        print('%-*s  %s%s' % (
            widths[0], line[0],
            ' '.join('%*s' % (w, x)
                for w, x in zip(widths[1:], line[1:-1])),
            line[-1]))


def main(obj_paths, *,
        by=None,
        fields=None,
        defines=None,
        sort=None,
        **args):
    # find sizes
    if not args.get('use', None):
        results = collect(obj_paths, **args)
    else:
        results = []
        with openio(args['use']) as f:
            reader = csv.DictReader(f, restval='')
            for r in reader:
                try:
                    results.append(DataResult(
                        **{k: r[k] for k in DataResult._by
                            if k in r and r[k].strip()},
                        **{k: r['data_'+k] for k in DataResult._fields
                            if 'data_'+k in r and r['data_'+k].strip()}))
                except TypeError:
                    pass

    # fold
    results = fold(DataResult, results, by=by, defines=defines)

    # sort, note that python's sort is stable
    results.sort()
    if sort:
        for k, reverse in reversed(sort):
            results.sort(
                key=lambda r: tuple(
                    (getattr(r, k),) if getattr(r, k) is not None else ()
                    for k in ([k] if k else DataResult._sort)),
                reverse=reverse ^ (not k or k in DataResult._fields))

    # write results to CSV
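    # (the columns are the `by` fields followed by the data_-prefixed result
    # fields, e.g. a header of "file,function,data_size")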
    if args.get('output'):
        with openio(args['output'], 'w') as f:
            writer = csv.DictWriter(f,
                (by if by is not None else DataResult._by)
                + ['data_'+k for k in (
                    fields if fields is not None else DataResult._fields)])
            writer.writeheader()
            for r in results:
                writer.writerow(
                    {k: getattr(r, k) for k in (
                        by if by is not None else DataResult._by)}
                    | {'data_'+k: getattr(r, k) for k in (
                        fields if fields is not None else DataResult._fields)})

    # find previous results?
    if args.get('diff'):
        diff_results = []
        try:
            with openio(args['diff']) as f:
                reader = csv.DictReader(f, restval='')
                for r in reader:
                    if not any('data_'+k in r and r['data_'+k].strip()
                            for k in DataResult._fields):
                        continue
                    try:
                        diff_results.append(DataResult(
                            **{k: r[k] for k in DataResult._by
                                if k in r and r[k].strip()},
                            **{k: r['data_'+k] for k in DataResult._fields
                                if 'data_'+k in r and r['data_'+k].strip()}))
                    except TypeError:
                        pass
        except FileNotFoundError:
            pass

        # fold
        diff_results = fold(DataResult, diff_results, by=by, defines=defines)

    # print table
    if not args.get('quiet'):
        table(DataResult, results,
            diff_results if args.get('diff') else None,
            by=by if by is not None else ['function'],
            fields=fields,
            sort=sort,
            **args)


if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(
        description="Find data size at the function level.",
        allow_abbrev=False)
    parser.add_argument(
        'obj_paths',
        nargs='*',
        help="Input *.o files.")
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="Output commands that run behind the scenes.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't show anything, useful with -o.")
    parser.add_argument(
        '-o', '--output',
        help="Specify CSV file to store results.")
    parser.add_argument(
        '-u', '--use',
        help="Don't parse anything, use this CSV file.")
    parser.add_argument(
        '-d', '--diff',
        help="Specify CSV file to diff against.")
    parser.add_argument(
        '-a', '--all',
        action='store_true',
        help="Show all, not just the ones that changed.")
    parser.add_argument(
        '-p', '--percent',
        action='store_true',
        help="Only show percentage change, not a full diff.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        choices=DataResult._by,
        help="Group by this field.")
    parser.add_argument(
        '-f', '--field',
        dest='fields',
        action='append',
        choices=DataResult._fields,
        help="Show this field.")
    parser.add_argument(
        '-D', '--define',
        dest='defines',
        action='append',
        type=lambda x: (lambda k,v: (k, set(v.split(','))))(*x.split('=', 1)),
        help="Only include results where this field is this value.")
    class AppendSort(argparse.Action):
        def __call__(self, parser, namespace, value, option):
            if namespace.sort is None:
                namespace.sort = []
            namespace.sort.append((value, True if option == '-S' else False))
    parser.add_argument(
        '-s', '--sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field.")
    parser.add_argument(
        '-S', '--reverse-sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field, but backwards.")
    parser.add_argument(
        '-Y', '--summary',
        action='store_true',
        help="Only show the total.")
    parser.add_argument(
        '-F', '--source',
        dest='sources',
        action='append',
        help="Only consider definitions in this file. Defaults to anything "
            "in the current directory.")
    parser.add_argument(
        '--everything',
        action='store_true',
        help="Include builtin and libc specific symbols.")
    parser.add_argument(
        '--nm-types',
        default=NM_TYPES,
        help="Type of symbols to report, this uses the same single-character "
            "type-names emitted by nm. Defaults to %r." % NM_TYPES)
    parser.add_argument(
        '--nm-path',
        type=lambda x: x.split(),
        default=NM_PATH,
        help="Path to the nm executable, may include flags. "
            "Defaults to %r." % NM_PATH)
    parser.add_argument(
        '--objdump-path',
        type=lambda x: x.split(),
        default=OBJDUMP_PATH,
        help="Path to the objdump executable, may include flags. "
            "Defaults to %r." % OBJDUMP_PATH)
    sys.exit(main(**{k: v
        for k, v in vars(parser.parse_intermixed_args()).items()
        if v is not None}))