#!/usr/bin/env python3
#
# Script to find struct sizes.
#
# Example:
# ./scripts/structs.py lfs.o lfs_util.o -Ssize
#
# Copyright (c) 2022, The littlefs authors.
# SPDX-License-Identifier: BSD-3-Clause
#

import collections as co
import csv
import difflib
import itertools as it
import math as m
import os
import re
import shlex
import subprocess as sp
import sys


# Default objdump command. It is a list so that callers (and --objdump-path)
# can include extra flags ahead of the per-file arguments.
OBJDUMP_PATH = ["objdump"]


# integer fields
class Int(co.namedtuple('Int', 'x')):
    """An integer field value that may also be +-infinity.

    Int accepts an int, another Int (returned as-is), +-math.inf, or a
    string. A string may be anything int(x, 0) parses (e.g. '42', '0x2a'),
    or '∞'/'inf' with an optional '+' or '-' sign and surrounding
    whitespace.
    """
    __slots__ = ()
    def __new__(cls, x=0):
        if isinstance(x, Int):
            return x
        if isinstance(x, str):
            try:
                x = int(x, 0)
            except ValueError:
                # also accept +-∞ and +-inf; raw strings keep \s and \+ from
                # being treated as (invalid) string escapes
                if re.match(r'^\s*\+?\s*(?:∞|inf)\s*$', x):
                    x = m.inf
                elif re.match(r'^\s*-\s*(?:∞|inf)\s*$', x):
                    x = -m.inf
                else:
                    raise
        assert isinstance(x, int) or m.isinf(x), x
        return super().__new__(cls, x)

    def __str__(self):
        if self.x == m.inf:
            return '∞'
        elif self.x == -m.inf:
            return '-∞'
        else:
            return str(self.x)

    def __int__(self):
        assert not m.isinf(self.x)
        return self.x

    def __float__(self):
        return float(self.x)

    # placeholder shown when a result has no value for this field
    none = '%7s' % '-'
    def table(self):
        return '%7s' % (self,)

    diff_none = '%7s' % '-'
    diff_table = table

    def diff_diff(self, other):
        """Format the signed change from other to self; None counts as 0."""
        # Int instances are 1-tuples and therefore always truthy, so only a
        # missing (None) value counts as 0
        new = self.x if self is not None else 0
        old = other.x if other is not None else 0
        diff = new - old
        if isinstance(diff, float) and m.isnan(diff):
            # inf-inf is nan, which %d can't format; matching infinities are
            # no change (ratio() also reports 0 here)
            diff = 0
        if diff == +m.inf:
            return '%7s' % '+∞'
        elif diff == -m.inf:
            return '%7s' % '-∞'
        else:
            return '%+7d' % diff

    def ratio(self, other):
        """Return the relative change from other to self as a fraction.

        None counts as 0. A change from 0 to a nonzero value is 1.0.
        Infinite operands give 0.0 (both infinite), +inf (only self) or
        -inf (only other).
        """
        new = self.x if self is not None else 0
        old = other.x if other is not None else 0
        if m.isinf(new) and m.isinf(old):
            return 0.0
        elif m.isinf(new):
            return +m.inf
        elif m.isinf(old):
            return -m.inf
        elif not old and not new:
            return 0.0
        elif not old:
            return 1.0
        else:
            return (new-old) / old

    def __add__(self, other):
        return self.__class__(self.x + other.x)

    def __sub__(self, other):
        return self.__class__(self.x - other.x)

    def __mul__(self, other):
        return self.__class__(self.x * other.x)

# struct size results
class StructResult(co.namedtuple('StructResult', ['file', 'struct', 'size'])):
    """The size of one struct, identified by its declaring file and name.

    Note that the class attribute _fields is overridden. It lists only the
    value fields ('size'), while _by lists the identifying fields.
    """
    # fields that identify a result
    _by = ['file', 'struct']
    # value fields, merged by __add__ when results are folded
    _fields = ['size']
    # fields sorted on when a sort option names no field
    _sort = ['size']
    # field name -> value type
    _types = {'size': Int}

    __slots__ = ()

    def __new__(cls, file='', struct='', size=0):
        # normalize size, so strings such as '0x10' or 'inf' are accepted
        size = Int(size)
        return super().__new__(cls, file, struct, size)

    def __add__(self, other):
        # merging keeps self's file/struct and sums the sizes
        merged_size = self.size + other.size
        return StructResult(self.file, self.struct, merged_size)


def openio(path, mode='r', buffering=-1):
    """Open path like open(), treating '-' as stdin or stdout.

    For '-', any read mode ('r', 'rb', ...) reads stdin and any other mode
    writes stdout. The standard stream's file descriptor is duplicated, so
    closing the returned file does not close the process's own stream.

    path: file path, or '-' for stdin/stdout.
    mode, buffering: passed through to open()/os.fdopen().
    """
    # allow '-' for stdin/stdout
    if path == '-':
        # check for 'r' rather than mode == 'r', so binary reads such as 'rb'
        # also come from stdin
        if 'r' in mode:
            return os.fdopen(os.dup(sys.stdin.fileno()), mode, buffering)
        else:
            return os.fdopen(os.dup(sys.stdout.fileno()), mode, buffering)
    else:
        return open(path, mode, buffering)

def _objdump_lines(objdump_path, dwarf, path, *, verbose=False):
    """Run `objdump --dwarf=<dwarf> <path>` and yield its stdout lines.

    objdump_path: objdump command as a list, may include extra flags.
    verbose: echo the command and leave objdump's stderr uncaptured.

    Once the output is exhausted, exits the process with -1 if objdump
    failed. Before exiting, objdump's stderr is copied to stdout if it was
    captured.
    """
    # note objdump-path may contain extra args
    cmd = objdump_path + ['--dwarf=' + dwarf, path]
    if verbose:
        print(' '.join(shlex.quote(c) for c in cmd))
    # the with block closes our ends of the pipes once we are done
    with sp.Popen(cmd,
            stdout=sp.PIPE,
            stderr=sp.PIPE if not verbose else None,
            universal_newlines=True,
            errors='replace',
            close_fds=False) as proc:
        yield from proc.stdout
        proc.wait()
        if proc.returncode != 0:
            if not verbose:
                for line in proc.stderr:
                    sys.stdout.write(line)
            sys.exit(-1)


def collect(obj_paths, *,
        objdump_path=OBJDUMP_PATH,
        sources=None,
        everything=False,
        internal=False,
        **args):
    """Collect struct sizes from the DWARF debug info of object files.

    For each object file, `objdump --dwarf=rawline` is parsed into a map
    from DWARF file numbers to paths. Then `objdump --dwarf=info` is scanned
    for DW_TAG_structure_type entries and their DW_AT_name, DW_AT_decl_file
    and DW_AT_byte_size attributes.

    obj_paths: object files to inspect.
    objdump_path: objdump command as a list, may include extra flags.
    sources: if given, only keep structs declared in one of these files.
    everything: without sources, also keep structs declared outside the
        current directory.
    internal: without sources, also keep structs not declared in .h files.
    args: extra options. Only 'verbose' is read here; it echoes the objdump
        commands and lets objdump's stderr through.

    Returns a list of StructResult. Paths inside the current directory are
    made relative, others absolute. Exits the process if objdump fails.
    """
    # raw strings, otherwise \s is an invalid string escape
    line_pattern = re.compile(
        r'^\s+(?P<no>[0-9]+)'
            r'(?:\s+(?P<dir>[0-9]+))?'
            r'\s+.*'
            r'\s+(?P<path>[^\s]+)$')
    info_pattern = re.compile(
        r'^(?:.*(?P<tag>DW_TAG_[a-z_]+).*'
            r'|.*DW_AT_name.*:\s*(?P<name>[^:\s]+)\s*'
            r'|.*DW_AT_decl_file.*:\s*(?P<file>[0-9]+)\s*'
            r'|.*DW_AT_byte_size.*:\s*(?P<size>[0-9]+)\s*)$')
    verbose = args.get('verbose')
    cwd = os.getcwd()
    source_paths = (None if sources is None
        else {os.path.abspath(s) for s in sources})

    results = []
    for path in obj_paths:
        # find files, we want to filter by structs in .h files
        dirs = {}
        files = {}
        for line in _objdump_lines(
                objdump_path, 'rawline', path, verbose=verbose):
            # note that files contain references to dirs, which we
            # dereference as soon as we see them as each file table follows a
            # dir table
            match = line_pattern.match(line)
            if not match:
                continue
            no = int(match.group('no'))
            if not match.group('dir'):
                # found a directory entry
                dirs[no] = match.group('path')
            else:
                # found a file entry
                dir_no = int(match.group('dir'))
                if dir_no in dirs:
                    files[no] = os.path.join(
                        dirs[dir_no],
                        match.group('path'))
                else:
                    files[no] = match.group('path')

        # collect structs as we parse dwarf info, attribute lines apply to
        # the most recent DW_TAG line
        #
        # NOTE(review): s_name/s_file carry over between entries (as
        # before), so a struct without its own DW_AT_name reports the name
        # of an earlier entry
        structs = []
        is_struct = False
        s_name = None
        s_file = None
        s_size = None
        for line in _objdump_lines(
                objdump_path, 'info', path, verbose=verbose):
            match = info_pattern.match(line)
            if not match:
                continue
            if match.group('tag'):
                # a new entry ends the previous one; only structs with a
                # known size are kept (e.g. declarations typically have no
                # DW_AT_byte_size)
                if is_struct and s_size is not None:
                    structs.append(StructResult(
                        files.get(s_file, '?'), s_name, s_size))
                is_struct = (match.group('tag') == 'DW_TAG_structure_type')
                # don't let an entry without DW_AT_byte_size inherit the
                # size of an earlier entry
                s_size = None
            elif match.group('name'):
                s_name = match.group('name')
            elif match.group('file'):
                s_file = int(match.group('file'))
            elif match.group('size'):
                s_size = int(match.group('size'))
        if is_struct and s_size is not None:
            structs.append(StructResult(
                files.get(s_file, '?'), s_name, s_size))

        for r in structs:
            abs_file = os.path.abspath(r.file)
            # ignore filtered sources
            if source_paths is not None:
                if abs_file not in source_paths:
                    continue
            else:
                # default to only cwd
                if (not everything
                        and os.path.commonpath([cwd, abs_file]) != cwd):
                    continue

                # limit to .h files unless --internal
                if not internal and not r.file.endswith('.h'):
                    continue

            # simplify path
            if os.path.commonpath([cwd, abs_file]) == cwd:
                file = os.path.relpath(r.file)
            else:
                file = abs_file

            results.append(r._replace(file=file))

    return results


def fold(Result, results, *,
        by=None,
        defines=None,
        **_):
    """Merge results that share the same values for the fields in by.

    Result: result class providing _by, _fields and an __add__ that merges
        two results.
    results: iterable of Result instances.
    by: fields to group by, defaults to Result._by. An empty list folds
        everything into a single result.
    defines: optional list of (field, set-of-strings) pairs. Only results
        whose field value, converted with str(), is in the set are kept.

    Returns a list of merged results, in the order each group was first
    seen. Exits the process if by or defines names an unknown field.
    """
    if by is None:
        by = Result._by

    for k in it.chain(by or [], (k for k, _ in defines or [])):
        if k not in Result._by and k not in Result._fields:
            print("error: could not find field %r?" % k)
            sys.exit(-1)

    # filter by matching defines. Define values are parsed from the command
    # line as strings, so compare against each field's string form,
    # otherwise numeric fields (e.g. -Dsize=8) could never match
    if defines is not None:
        results = [
            r for r in results
            if all(str(getattr(r, k)) in vs for k, vs in defines)]

    # organize results into conflicts
    folding = co.OrderedDict()
    for r in results:
        name = tuple(getattr(r, k) for k in by)
        folding.setdefault(name, []).append(r)

    # merge conflicts
    return [sum(rs[1:], start=rs[0]) for rs in folding.values()]

def table(Result, results, diff_results=None, *,
        by=None,
        fields=None,
        sort=None,
        summary=False,
        all=False,
        percent=False,
        **_):
    """Print results as a table, optionally diffed against diff_results.

    Result: result class providing _by, _fields, _sort and _types. _types
        maps each field to a type with table/diff_table/diff_diff/ratio
        methods and none/diff_none placeholder strings.
    results: results to print; they are folded again by `by`.
    diff_results: previous results to compare against, or None.
    by: fields that name each row, defaults to Result._by.
    fields: value fields to show, defaults to Result._fields.
    sort: list of (field, reverse) pairs, first pair has highest priority.
        A falsy field means the Result._sort fields that are being shown.
    summary: only print the header and the TOTAL row.
    all: when diffing, also print rows whose values did not change.
    percent: when diffing, print new values plus percentage changes instead
        of old/new/delta columns.
    """
    # The all flag shadows the builtin all(), which this function does not
    # need. Recovering it via __builtins__.all is unsafe, since __builtins__
    # is a dict rather than a module whenever this file is imported.
    show_all = all

    if by is None:
        by = Result._by
    if fields is None:
        fields = Result._fields
    types = Result._types

    # fold again
    results = fold(Result, results, by=by)
    if diff_results is not None:
        diff_results = fold(Result, diff_results, by=by)

    # organize by name, a comma-separated join of the by fields
    def row_name(r):
        return ','.join(str(getattr(r, k) or '') for k in by)
    rows = {row_name(r): r for r in results}
    diff_rows = {row_name(r): r for r in diff_results or []}
    names = list(rows.keys() | diff_rows.keys())

    # sort again, now with diff info, note that python's sort is stable
    names.sort()
    if diff_results is not None:
        names.sort(key=lambda n: tuple(
            types[k].ratio(
                getattr(rows.get(n), k, None),
                getattr(diff_rows.get(n), k, None))
            for k in fields),
            reverse=True)
    if sort:
        for k, reverse in reversed(sort):
            sort_fields = [k] if k else [
                f for f in Result._sort if f in fields]
            names.sort(
                key=lambda n: tuple(
                    (getattr(rows[n], f),)
                    if getattr(rows.get(n), f, None) is not None else ()
                    for f in sort_fields),
                reverse=reverse ^ (not k or k in Result._fields))

    # build up our lines
    lines = []

    # header
    header = []
    if summary:
        header.append('')
    elif diff_results is not None and not percent:
        header.append('%s (%d added, %d removed)' % (
            ','.join(by),
            sum(1 for n in rows if n not in diff_rows),
            sum(1 for n in diff_rows if n not in rows)))
    else:
        header.append(','.join(by))
    if diff_results is None or percent:
        header.extend(fields)
    else:
        header.extend('o'+k for k in fields)
        header.extend('n'+k for k in fields)
        header.extend('d'+k for k in fields)
    header.append('')
    lines.append(header)

    def format_ratio(t):
        # ratios are fractions, shown as signed percentages
        if t == +m.inf:
            return '+∞%'
        elif t == -m.inf:
            return '-∞%'
        else:
            return '%+.1f%%' % (100*t)

    def table_entry(name, r, diff_r=None, ratios=()):
        # one row: name, value columns, then a trailing ratio column
        entry = [name]
        if diff_results is None:
            for k in fields:
                entry.append(getattr(r, k).table()
                    if getattr(r, k, None) is not None
                    else types[k].none)
        elif percent:
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
        else:
            for k in fields:
                entry.append(getattr(diff_r, k).diff_table()
                    if getattr(diff_r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(getattr(r, k).diff_table()
                    if getattr(r, k, None) is not None
                    else types[k].diff_none)
            for k in fields:
                entry.append(types[k].diff_diff(
                    getattr(r, k, None),
                    getattr(diff_r, k, None)))
        if diff_results is None:
            entry.append('')
        elif percent:
            entry.append(' (%s)' % ', '.join(
                format_ratio(t) for t in ratios))
        else:
            # only show nonzero changes, and nothing if nothing changed
            entry.append(' (%s)' % ', '.join(
                    format_ratio(t) for t in ratios if t)
                if any(ratios) else '')
        return entry

    # entries
    if not summary:
        for name in names:
            r = rows.get(name)
            if diff_results is None:
                diff_r = None
                ratios = None
            else:
                diff_r = diff_rows.get(name)
                ratios = [
                    types[k].ratio(
                        getattr(r, k, None),
                        getattr(diff_r, k, None))
                    for k in fields]
                # skip unchanged rows unless --all
                if not show_all and not any(ratios):
                    continue
            lines.append(table_entry(name, r, diff_r, ratios))

    # total
    r = next(iter(fold(Result, results, by=[])), None)
    if diff_results is None:
        diff_r = None
        ratios = None
    else:
        diff_r = next(iter(fold(Result, diff_results, by=[])), None)
        ratios = [
            types[k].ratio(
                getattr(r, k, None),
                getattr(diff_r, k, None))
            for k in fields]
    lines.append(table_entry('TOTAL', r, diff_r, ratios))

    # find the best widths, note that column 0 contains the names and column -1
    # the ratios, so those are handled a bit differently; each width is
    # rounded up so that width+1 is a multiple of 4
    widths = [
        ((max(it.chain([w], (len(l[i]) for l in lines)))+1+4-1)//4)*4-1
        for w, i in zip(
            it.chain([23], it.repeat(7)),
            range(len(lines[0])-1))]

    # print our table
    for line in lines:
        print('%-*s  %s%s' % (
            widths[0], line[0],
            ' '.join('%*s' % (w, x)
                for w, x in zip(widths[1:], line[1:-1])),
            line[-1]))


def _read_results(path):
    """Read StructResults from a CSV file such as one written by -o/--output.

    Rows where no struct_* column has a non-empty value are skipped. Rows
    that StructResult rejects with a TypeError are also skipped.
    """
    results = []
    with openio(path) as f:
        reader = csv.DictReader(f, restval='')
        for r in reader:
            if not any('struct_'+k in r and r['struct_'+k].strip()
                    for k in StructResult._fields):
                continue
            try:
                results.append(StructResult(
                    **{k: r[k] for k in StructResult._by
                        if k in r and r[k].strip()},
                    **{k: r['struct_'+k]
                        for k in StructResult._fields
                        if 'struct_'+k in r
                            and r['struct_'+k].strip()}))
            except TypeError:
                pass
    return results


def main(obj_paths, *,
        by=None,
        fields=None,
        defines=None,
        sort=None,
        **args):
    """Find, fold, sort, save and print struct sizes.

    obj_paths: object files to inspect; not used when args['use'] is set.
    by: fields to group by, defaults to StructResult._by (the printed table
        groups by struct when by is not given).
    fields: value fields to write/show, defaults to StructResult._fields.
    defines: (field, set-of-values) filters passed to fold.
    sort: (field, reverse) pairs collected from -s/-S.
    args: remaining command-line options (use, output, diff, quiet, ...),
        passed on to collect and table.
    """
    # find sizes
    if not args.get('use', None):
        results = collect(obj_paths, **args)
    else:
        results = _read_results(args['use'])

    # fold
    results = fold(StructResult, results, by=by, defines=defines)

    # sort, note that python's sort is stable
    results.sort()
    if sort:
        for k, reverse in reversed(sort):
            sort_fields = [k] if k else StructResult._sort
            results.sort(
                key=lambda r: tuple(
                    (getattr(r, f),) if getattr(r, f) is not None else ()
                    for f in sort_fields),
                reverse=reverse ^ (not k or k in StructResult._fields))

    # write results to CSV
    if args.get('output'):
        by_ = by if by is not None else StructResult._by
        fields_ = fields if fields is not None else StructResult._fields
        with openio(args['output'], 'w') as f:
            writer = csv.DictWriter(f,
                by_ + ['struct_'+k for k in fields_])
            writer.writeheader()
            for r in results:
                writer.writerow(
                    {k: getattr(r, k) for k in by_}
                    | {'struct_'+k: getattr(r, k) for k in fields_})

    # find previous results? a missing diff file counts as no results
    diff_results = None
    if args.get('diff'):
        try:
            diff_results = _read_results(args['diff'])
        except FileNotFoundError:
            diff_results = []

        # fold
        diff_results = fold(StructResult, diff_results, by=by, defines=defines)

    # print table
    if not args.get('quiet'):
        table(StructResult, results,
            diff_results,
            by=by if by is not None else ['struct'],
            fields=fields,
            sort=sort,
            **args)


if __name__ == "__main__":
    import argparse
    import sys
    parser = argparse.ArgumentParser(
        description="Find struct sizes.",
        allow_abbrev=False)
    parser.add_argument(
        'obj_paths',
        nargs='*',
        help="Input *.o files.")
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help="Output commands that run behind the scenes.")
    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help="Don't show anything, useful with -o.")
    parser.add_argument(
        '-o', '--output',
        help="Specify CSV file to store results.")
    parser.add_argument(
        '-u', '--use',
        help="Don't parse anything, use this CSV file.")
    parser.add_argument(
        '-d', '--diff',
        help="Specify CSV file to diff against.")
    parser.add_argument(
        '-a', '--all',
        action='store_true',
        help="Show all, not just the ones that changed.")
    parser.add_argument(
        '-p', '--percent',
        action='store_true',
        help="Only show percentage change, not a full diff.")
    parser.add_argument(
        '-b', '--by',
        action='append',
        choices=StructResult._by,
        help="Group by this field.")
    parser.add_argument(
        '-f', '--field',
        dest='fields',
        action='append',
        choices=StructResult._fields,
        help="Show this field.")
    parser.add_argument(
        '-D', '--define',
        dest='defines',
        action='append',
        type=lambda x: (lambda k,v: (k, set(v.split(','))))(*x.split('=', 1)),
        help="Only include results where this field is this value.")
    class AppendSort(argparse.Action):
        """Collect -s/-S options, in order, as (field, reverse) pairs."""
        def __call__(self, parser, namespace, value, option):
            if namespace.sort is None:
                namespace.sort = []
            # match both spellings of the reverse option; comparing against
            # '-S' alone made --reverse-sort sort forwards
            namespace.sort.append(
                (value, option in {'-S', '--reverse-sort'}))
    parser.add_argument(
        '-s', '--sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field.")
    parser.add_argument(
        '-S', '--reverse-sort',
        nargs='?',
        action=AppendSort,
        help="Sort by this field, but backwards.")
    parser.add_argument(
        '-Y', '--summary',
        action='store_true',
        help="Only show the total.")
    parser.add_argument(
        '-F', '--source',
        dest='sources',
        action='append',
        help="Only consider definitions in this file. Defaults to anything "
            "in the current directory.")
    parser.add_argument(
        '--everything',
        action='store_true',
        help="Include builtin and libc specific symbols.")
    parser.add_argument(
        '--internal',
        action='store_true',
        help="Also show structs in .c files.")
    parser.add_argument(
        '--objdump-path',
        type=lambda x: x.split(),
        default=OBJDUMP_PATH,
        help="Path to the objdump executable, may include flags. "
            "Defaults to %r." % OBJDUMP_PATH)
    sys.exit(main(**{k: v
        for k, v in vars(parser.parse_intermixed_args()).items()
        if v is not None}))
