1#!/usr/bin/env python3
2#
3# Script to find code size at the function level. Basically just a big wrapper
4# around nm with some extra conveniences for comparing builds. Heavily inspired
5# by Linux's Bloat-O-Meter.
6#
7# Example:
8# ./scripts/code.py lfs.o lfs_util.o -Ssize
9#
10# Copyright (c) 2022, The littlefs authors.
11# Copyright (c) 2020, Arm Limited. All rights reserved.
12# SPDX-License-Identifier: BSD-3-Clause
13#
14
import collections as co
import csv
import difflib
import itertools as it
import math as m
import os
import re
import shlex
import subprocess as sp
import sys
24
25
# default nm command, kept as an argv list so extra flags can ride along
# (overridable with --nm-path)
NM_PATH = ['nm']
# the single-character nm type-names of the symbols that get reported
# (overridable with --nm-types)
NM_TYPES = 'tTrRdD'
# default objdump command, also an argv list (overridable with --objdump-path)
OBJDUMP_PATH = ['objdump']
29
30
31# integer fields
class Int(co.namedtuple('Int', 'x')):
    # an integer-valued result field, stored as an immutable 1-tuple; x is
    # either an int or +-m.inf
    #
    # note several methods below are also called unbound with None standing in
    # for a missing value (Int.ratio(None, old) etc), which is why they test
    # the truthiness of self/other before touching .x
    __slots__ = ()

    # build an Int from an Int (returned as-is), an int, +-m.inf, or a string
    #
    # strings go through int(x, 0), so 0x/0o/0b prefixes are understood, and
    # may also spell infinity as "inf" or "∞" with an optional sign, anything
    # else re-raises int's ValueError
    def __new__(cls, x=0):
        if isinstance(x, Int):
            return x
        if isinstance(x, str):
            try:
                x = int(x, 0)
            except ValueError:
                # also accept +-∞ and +-inf
                if re.match(r'^\s*\+?\s*(?:∞|inf)\s*$', x):
                    x = m.inf
                elif re.match(r'^\s*-\s*(?:∞|inf)\s*$', x):
                    x = -m.inf
                else:
                    raise
        assert isinstance(x, int) or m.isinf(x), x
        return super().__new__(cls, x)

    def __str__(self):
        if self.x == m.inf:
            return '∞'
        elif self.x == -m.inf:
            return '-∞'
        else:
            return str(self.x)

    def __int__(self):
        assert not m.isinf(self.x)
        return self.x

    def __float__(self):
        return float(self.x)

    # placeholder cell for a missing value, same width as table()
    none = '%7s' % '-'
    # render as a right-aligned 7-column table cell
    def table(self):
        return '%7s' % (self,)

    # cells in diff tables are rendered the same way as in plain tables
    diff_none = '%7s' % '-'
    diff_table = table

    # render the signed change from other (old) to self (new) as a 7-column
    # cell, a missing side counts as 0
    def diff_diff(self, other):
        new = self.x if self else 0
        old = other.x if other else 0
        diff = new - old
        if diff != diff:
            # ∞-∞ is nan, which %d can't format; both sides being infinite
            # is reported as no change, same as ratio() does
            return '%+7d' % 0
        elif diff == +m.inf:
            return '%7s' % '+∞'
        elif diff == -m.inf:
            return '%7s' % '-∞'
        else:
            return '%+7d' % diff

    # relative change from other (old) to self (new) as a float, a missing
    # side counts as 0
    #
    # going from 0 to anything non-zero is reported as 1.0 (+100%) since the
    # real ratio is undefined
    def ratio(self, other):
        new = self.x if self else 0
        old = other.x if other else 0
        if m.isinf(new) and m.isinf(old):
            return 0.0
        elif m.isinf(new):
            return +m.inf
        elif m.isinf(old):
            return -m.inf
        elif not old and not new:
            return 0.0
        elif not old:
            return 1.0
        else:
            return (new-old) / old

    # arithmetic returns a new instance of the same class
    def __add__(self, other):
        return self.__class__(self.x + other.x)

    def __sub__(self, other):
        return self.__class__(self.x - other.x)

    def __mul__(self, other):
        return self.__class__(self.x * other.x)
109# code size results
class CodeResult(co.namedtuple('CodeResult', 'file function size')):
    # one measured symbol: where it lives (file, function) and how big it is
    #
    # the class attributes below describe the columns to the generic
    # fold/table helpers, note _fields intentionally replaces namedtuple's
    # own _fields with just the numeric columns

    # columns that identify/group a result
    _by = ['file', 'function']
    # numeric columns
    _fields = ['size']
    # columns used when sorting without an explicit field
    _sort = ['size']
    # type of each numeric column
    _types = {'size': Int}

    __slots__ = ()

    # size may be anything Int accepts and is always stored as an Int
    def __new__(cls, file='', function='', size=0):
        size = Int(size)
        return super().__new__(cls, file, function, size)

    # merging keeps the left-hand file/function and sums the sizes
    def __add__(self, other):
        total = self.size + other.size
        return CodeResult(self.file, self.function, total)
126
127
def openio(path, mode='r', buffering=-1):
    # open path like the builtin open, except '-' means stdin when mode is
    # 'r' and stdout for any other mode
    #
    # the std stream's descriptor is dup'd first, so the returned file owns
    # its own descriptor
    if path != '-':
        return open(path, mode, buffering)

    stream = sys.stdin if mode == 'r' else sys.stdout
    return os.fdopen(os.dup(stream.fileno()), mode, buffering)
137
def _spawn_tool(cmd, *, verbose=False):
    # start cmd with its stdout piped back to us as text, echoing the command
    # line first when verbose
    #
    # stderr is captured unless verbose, in which case it is inherited so the
    # tool's own messages are shown as they happen
    if verbose:
        print(' '.join(shlex.quote(c) for c in cmd))
    return sp.Popen(cmd,
        stdout=sp.PIPE,
        stderr=sp.PIPE if not verbose else None,
        universal_newlines=True,
        errors='replace',
        close_fds=False)

def _reap_tool(proc, *, verbose=False):
    # wait for a process started by _spawn_tool and return its exit status,
    # if it failed, replay whatever it wrote to its captured stderr
    proc.wait()
    if proc.returncode != 0 and not verbose:
        for line in proc.stderr:
            sys.stdout.write(line)
    return proc.returncode

def collect(obj_paths, *,
        nm_path=NM_PATH,
        nm_types=NM_TYPES,
        objdump_path=OBJDUMP_PATH,
        sources=None,
        everything=False,
        **args):
    # measure every symbol in the given object files, returns a list of
    # CodeResults
    #
    # obj_paths     - object files to inspect
    # nm_path       - argv prefix for nm, may contain extra args
    # nm_types      - nm type letters of the symbols to keep
    # objdump_path  - argv prefix for objdump, may contain extra args
    # sources       - if not None, only keep symbols defined in one of these
    #                 files, otherwise only keep symbols defined under the
    #                 cwd (unless everything)
    # everything    - also keep symbols starting with '__' and, when sources
    #                 is None, symbols defined outside the cwd
    # args          - only 'verbose' is looked at
    #
    # sizes come from nm, which must succeed or we exit the process, the
    # defining source file comes from objdump's dwarf dumps on a best-effort
    # basis
    #
    # NOTE(review): stderr is only drained after stdout hits EOF, a tool that
    # writes a lot to stderr while we're still reading stdout could
    # presumably stall on a full pipe
    verbose = args.get('verbose')

    size_pattern = re.compile(
        r'^(?P<size>[0-9a-fA-F]+)' +
        r' (?P<type>[%s])' % re.escape(nm_types) +
        r' (?P<func>.+?)$')
    line_pattern = re.compile(
        r'^\s+(?P<no>[0-9]+)'
            r'(?:\s+(?P<dir>[0-9]+))?'
            r'\s+.*'
            r'\s+(?P<path>[^\s]+)$')
    info_pattern = re.compile(
        r'^(?:.*(?P<tag>DW_TAG_[a-z_]+).*'
            r'|.*DW_AT_name.*:\s*(?P<name>[^:\s]+)\s*'
            r'|.*DW_AT_decl_file.*:\s*(?P<file>[0-9]+)\s*)$')

    # these don't change while we run, so resolve them once up front
    cwd = os.getcwd()
    source_paths = (
        {os.path.abspath(s) for s in sources}
        if sources is not None else None)

    results = []
    for path in obj_paths:
        # guess the source, if we have debug-info we'll replace this later
        guess = re.sub(r'(\.o)?$', '.c', path, count=1)

        # find symbol sizes
        results_ = []
        # note nm-path may contain extra args
        with _spawn_tool(nm_path + ['--size-sort', path],
                verbose=verbose) as proc:
            for line in proc.stdout:
                match = size_pattern.match(line)
                if not match:
                    continue
                func = match.group('func')
                # discard internal functions
                if not everything and func.startswith('__'):
                    continue
                results_.append(CodeResult(
                    guess, func,
                    int(match.group('size'), 16)))
            if _reap_tool(proc, verbose=verbose) != 0:
                sys.exit(-1)

        # try to figure out the source file if we have debug-info
        dirs = {}
        files = {}
        # note objdump-path may contain extra args
        with _spawn_tool(objdump_path + ['--dwarf=rawline', path],
                verbose=verbose) as proc:
            for line in proc.stdout:
                # note that files contain references to dirs, which we
                # dereference as soon as we see them as each file table
                # follows a dir table
                match = line_pattern.match(line)
                if not match:
                    continue
                no = int(match.group('no'))
                if not match.group('dir'):
                    # found a directory entry
                    dirs[no] = match.group('path')
                else:
                    # found a file entry
                    dir_no = int(match.group('dir'))
                    if dir_no in dirs:
                        files[no] = os.path.join(
                            dirs[dir_no],
                            match.group('path'))
                    else:
                        files[no] = match.group('path')
            # do nothing on error, we don't need objdump to work, source files
            # may just be inaccurate
            _reap_tool(proc, verbose=verbose)

        # map function names to the file they were declared in
        defs = {}
        is_func = False
        f_name = None
        f_file = None
        # note objdump-path may contain extra args
        with _spawn_tool(objdump_path + ['--dwarf=info', path],
                verbose=verbose) as proc:
            for line in proc.stdout:
                # state machine here to find definitions, every new tag
                # flushes the name/file seen so far if the previous tag was
                # a function
                match = info_pattern.match(line)
                if not match:
                    continue
                if match.group('tag'):
                    if is_func:
                        defs[f_name] = files.get(f_file, '?')
                    is_func = (match.group('tag') == 'DW_TAG_subprogram')
                elif match.group('name'):
                    f_name = match.group('name')
                elif match.group('file'):
                    f_file = int(match.group('file'))
            # flush the last tag
            if is_func:
                defs[f_name] = files.get(f_file, '?')
            # do nothing on error, we don't need objdump to work, source files
            # may just be inaccurate
            _reap_tool(proc, verbose=verbose)

        for r in results_:
            # find best matching debug symbol, this may be slightly different
            # due to optimizations
            if defs:
                # exact match? avoid difflib if we can for speed
                if r.function in defs:
                    src = defs[r.function]
                else:
                    _, src = max(
                        defs.items(),
                        key=lambda d: difflib.SequenceMatcher(None,
                            d[0],
                            r.function, False).ratio())
            else:
                src = r.file

            # ignore filtered sources
            src_abs = os.path.abspath(src)
            if source_paths is not None and src_abs not in source_paths:
                continue
            in_cwd = os.path.commonpath([cwd, src_abs]) == cwd
            # default to only cwd
            if source_paths is None and not everything and not in_cwd:
                continue

            # simplify path
            src = os.path.relpath(src) if in_cwd else src_abs

            results.append(r._replace(file=src))

    return results
311
312
def fold(Result, results, *,
        by=None,
        defines=None,
        **_):
    # merge results that agree on all of the `by` fields into one result each
    #
    # Result   - result class, provides _by/_fields, instances must support +
    # results  - results to merge
    # by       - fields to group by, defaults to Result._by, an empty list
    #            merges everything into a single result
    # defines  - optional (field, allowed-values) pairs, results with a field
    #            outside its allowed values are dropped first
    #
    # returns the merged results in order of first appearance, an unknown
    # field name prints an error and exits the process
    if by is None:
        by = Result._by

    for k in it.chain(by or [], (k for k, _ in defines or [])):
        if not (k in Result._by or k in Result._fields):
            print("error: could not find field %r?" % k)
            sys.exit(-1)

    # filter by matching defines
    if defines is not None:
        results = [r for r in results
            if all(getattr(r, k) in vs for k, vs in defines)]

    # bucket results by their grouping key, dicts keep insertion order
    groups = {}
    for r in results:
        key = tuple(getattr(r, k) for k in by)
        groups.setdefault(key, []).append(r)

    # sum up each bucket, starting from its first result
    return [sum(rs[1:], start=rs[0]) for rs in groups.values()]
347
def table(Result, results, diff_results=None, *,
        by=None,
        fields=None,
        sort=None,
        summary=False,
        all=False,
        percent=False,
        **_):
    # print results as an aligned text table, optionally as a diff against
    # diff_results
    #
    # Result        - result class, provides _by/_fields/_sort/_types
    # results       - the (new) results
    # diff_results  - old results to diff against, or None for a plain table
    # by            - fields that name each row, defaults to Result._by
    # fields        - numeric fields to show, defaults to Result._fields
    # sort          - list of (field, reverse) pairs, a falsy field means the
    #                 shown fields out of Result._sort
    # summary       - only print the TOTAL row
    # all           - when diffing, also show rows that didn't change
    # percent       - when diffing, show new values and the percent change
    #                 instead of old/new/diff columns

    # the `all` flag shadows the builtin, so stash the flag and rebind the
    # name to the real thing, note we can't go through __builtins__ for this
    # as it's only a module in __main__, in imported modules it's a dict
    import builtins
    all_, all = all, builtins.all

    if by is None:
        by = Result._by
    if fields is None:
        fields = Result._fields
    types = Result._types

    # fold again
    results = fold(Result, results, by=by)
    if diff_results is not None:
        diff_results = fold(Result, diff_results, by=by)

    # organize by name
    rows = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in results}
    diff_rows = {
        ','.join(str(getattr(r, k) or '') for k in by): r
        for r in diff_results or []}
    names = list(rows.keys() | diff_rows.keys())

    # sort again, now with diff info, note that python's sort is stable
    names.sort()
    if diff_results is not None:
        names.sort(key=lambda n: tuple(
            types[k].ratio(
                getattr(rows.get(n), k, None),
                getattr(diff_rows.get(n), k, None))
            for k in fields),
            reverse=True)
    if sort:
        # apply the least significant key first, numeric fields sort
        # descending unless reversed
        for k, reverse in reversed(sort):
            names.sort(
                key=lambda n: tuple(
                    (getattr(rows[n], k),)
                    if getattr(rows.get(n), k, None) is not None else ()
                    for k in ([k] if k else [
                        k for k in Result._sort if k in fields])),
                reverse=reverse ^ (not k or k in Result._fields))


    # build up our lines
    lines = []

    # header
    header = []
    header.append('%s%s' % (
        ','.join(by),
        ' (%d added, %d removed)' % (
            sum(1 for n in rows if n not in diff_rows),
            sum(1 for n in diff_rows if n not in rows))
            if diff_results is not None and not percent else '')
        if not summary else '')
    if diff_results is None:
        for k in fields:
            header.append(k)
    elif percent:
        for k in fields:
            header.append(k)
    else:
        for k in fields:
            header.append('o'+k)
        for k in fields:
            header.append('n'+k)
        for k in fields:
            header.append('d'+k)
    header.append('')
    lines.append(header)

    # render one ratio as a signed percentage
    def ratio_str(t):
        if t == +m.inf:
            return '+∞%'
        elif t == -m.inf:
            return '-∞%'
        else:
            return '%+.1f%%' % (100*t)

    # build one line: the name, the value columns, then the ratio suffix
    def table_entry(name, r, diff_r=None, ratios=()):
        entry = []
        entry.append(name)
        if diff_results is None:
            for k in fields:
                entry.append(getattr(r, k).table()
                    if getattr(r, k, None) is not None
                    else types[k].none)
        elif percent:
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
        else:
            for k in fields:
                entry.append(getattr(diff_r, k).diff_table()
                    if getattr(diff_r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(types[k].diff_diff(
                        getattr(r, k, None),
                        getattr(diff_r, k, None)))
        if diff_results is None:
            entry.append('')
        elif percent:
            entry.append(' (%s)' % ', '.join(
                ratio_str(t) for t in ratios))
        else:
            # only mention the fields that actually changed
            entry.append(' (%s)' % ', '.join(
                    ratio_str(t) for t in ratios if t)
                if any(ratios) else '')
        return entry

    # entries
    if not summary:
        for name in names:
            r = rows.get(name)
            if diff_results is None:
                diff_r = None
                ratios = None
            else:
                diff_r = diff_rows.get(name)
                ratios = [
                    types[k].ratio(
                        getattr(r, k, None),
                        getattr(diff_r, k, None))
                    for k in fields]
                # hide unchanged rows unless asked for everything
                if not all_ and not any(ratios):
                    continue
            lines.append(table_entry(name, r, diff_r, ratios))

    # total
    r = next(iter(fold(Result, results, by=[])), None)
    if diff_results is None:
        diff_r = None
        ratios = None
    else:
        diff_r = next(iter(fold(Result, diff_results, by=[])), None)
        ratios = [
            types[k].ratio(
                getattr(r, k, None),
                getattr(diff_r, k, None))
            for k in fields]
    lines.append(table_entry('TOTAL', r, diff_r, ratios))

    # find the best widths, note that column 0 contains the names and column -1
    # the ratios, so those are handled a bit differently
    widths = [
        ((max(it.chain([w], (len(l[i]) for l in lines)))+1+4-1)//4)*4-1
        for w, i in zip(
            it.chain([23], it.repeat(7)),
            range(len(lines[0])-1))]

    # print our table
    for line in lines:
        print('%-*s  %s%s' % (
            widths[0], line[0],
            ' '.join('%*s' % (w, x)
                for w, x in zip(widths[1:], line[1:-1])),
            line[-1]))
517
518
def _read_results(path):
    # parse CodeResults out of a CSV file, '-' reads stdin
    #
    # rows without any code_* column filled in are skipped, as are rows that
    # CodeResult rejects with a TypeError, a missing file raises
    # FileNotFoundError
    results = []
    with openio(path) as f:
        reader = csv.DictReader(f, restval='')
        for r in reader:
            if not any('code_'+k in r and r['code_'+k].strip()
                    for k in CodeResult._fields):
                continue
            try:
                results.append(CodeResult(
                    **{k: r[k] for k in CodeResult._by
                        if k in r and r[k].strip()},
                    **{k: r['code_'+k] for k in CodeResult._fields
                        if 'code_'+k in r and r['code_'+k].strip()}))
            except TypeError:
                pass
    return results

def main(obj_paths, *,
        by=None,
        fields=None,
        defines=None,
        sort=None,
        **args):
    # script entry point: gather results, optionally write them to a CSV
    # file, then print them as a table, optionally as a diff
    #
    # obj_paths  - object files to measure, unused with args['use']
    # by         - fields to group by
    # fields     - numeric fields to write/show
    # defines    - (field, allowed-values) filters, see fold
    # sort       - list of (field, reverse) pairs
    # args       - remaining options: 'use', 'output', 'diff' and 'quiet' are
    #              handled here, everything is also forwarded to collect and
    #              table

    # find sizes
    if not args.get('use', None):
        results = collect(obj_paths, **args)
    else:
        results = _read_results(args['use'])

    # fold
    results = fold(CodeResult, results, by=by, defines=defines)

    # sort, note that python's sort is stable
    results.sort()
    if sort:
        # apply the least significant key first, numeric fields sort
        # descending unless reversed
        for k, reverse in reversed(sort):
            results.sort(
                key=lambda r: tuple(
                    (getattr(r, k),) if getattr(r, k) is not None else ()
                    for k in ([k] if k else CodeResult._sort)),
                reverse=reverse ^ (not k or k in CodeResult._fields))

    # write results to CSV
    if args.get('output'):
        out_by = by if by is not None else CodeResult._by
        out_fields = fields if fields is not None else CodeResult._fields
        with openio(args['output'], 'w') as f:
            writer = csv.DictWriter(f,
                out_by + ['code_'+k for k in out_fields])
            writer.writeheader()
            for r in results:
                writer.writerow(
                    {k: getattr(r, k) for k in out_by}
                    | {'code_'+k: getattr(r, k) for k in out_fields})

    # find previous results?
    diff_results = None
    if args.get('diff'):
        try:
            diff_results = _read_results(args['diff'])
        except FileNotFoundError:
            # no previous results just means everything shows up as added
            diff_results = []

        # fold
        diff_results = fold(CodeResult, diff_results, by=by, defines=defines)

    # print table
    if not args.get('quiet'):
        table(CodeResult, results,
            diff_results,
            by=by if by is not None else ['function'],
            fields=fields,
            sort=sort,
            **args)
605
606
# command-line entry point, options that weren't given are left out of the
# kwargs so main's own defaults apply
if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(
        description="Find code size at the function level.",
        allow_abbrev=False)
    parser.add_argument(
        'obj_paths',
        nargs='*',
        help="Input *.o files.")
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="Output commands that run behind the scenes.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't show anything, useful with -o.")
    parser.add_argument(
        '-o', '--output',
        help="Specify CSV file to store results.")
    parser.add_argument(
        '-u', '--use',
        help="Don't parse anything, use this CSV file.")
    parser.add_argument(
        '-d', '--diff',
        help="Specify CSV file to diff against.")
    parser.add_argument(
        '-a', '--all',
        action='store_true',
        help="Show all, not just the ones that changed.")
    parser.add_argument(
        '-p', '--percent',
        action='store_true',
        help="Only show percentage change, not a full diff.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        choices=CodeResult._by,
        help="Group by this field.")
    parser.add_argument(
        '-f', '--field',
        dest='fields',
        action='append',
        choices=CodeResult._fields,
        help="Show this field.")
    # -Dfield=a,b becomes ('field', {'a', 'b'})
    parser.add_argument(
        '-D', '--define',
        dest='defines',
        action='append',
        type=lambda x: (lambda k,v: (k, set(v.split(','))))(*x.split('=', 1)),
        help="Only include results where this field is this value.")
    # -s and -S share one list so their relative order on the command line
    # is kept, each entry is (field-or-None, reverse)
    class AppendSort(argparse.Action):
        def __call__(self, parser, namespace, value, option):
            if namespace.sort is None:
                namespace.sort = []
            # option is the flag as typed, so check both spellings
            namespace.sort.append(
                (value, option in ('-S', '--reverse-sort')))
    parser.add_argument(
        '-s', '--sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field.")
    parser.add_argument(
        '-S', '--reverse-sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field, but backwards.")
    parser.add_argument(
        '-Y', '--summary',
        action='store_true',
        help="Only show the total.")
    parser.add_argument(
        '-F', '--source',
        dest='sources',
        action='append',
        help="Only consider definitions in this file. Defaults to anything "
            "in the current directory.")
    parser.add_argument(
        '--everything',
        action='store_true',
        help="Include builtin and libc specific symbols.")
    parser.add_argument(
        '--nm-types',
        default=NM_TYPES,
        help="Type of symbols to report, this uses the same single-character "
            "type-names emitted by nm. Defaults to %r." % NM_TYPES)
    parser.add_argument(
        '--nm-path',
        type=lambda x: x.split(),
        default=NM_PATH,
        help="Path to the nm executable, may include flags. "
            "Defaults to %r." % NM_PATH)
    parser.add_argument(
        '--objdump-path',
        type=lambda x: x.split(),
        default=OBJDUMP_PATH,
        help="Path to the objdump executable, may include flags. "
            "Defaults to %r." % OBJDUMP_PATH)
    sys.exit(main(**{k: v
        for k, v in vars(parser.parse_intermixed_args()).items()
        if v is not None}))
708