1#!/usr/bin/env python
2#
3# Based on cally.py (https://github.com/chaudron/cally/), Copyright 2018, Eelco Chaudron
4# SPDX-FileCopyrightText: 2020-2023 Espressif Systems (Shanghai) CO LTD
5# SPDX-License-Identifier: Apache-2.0
6
7import argparse
8import os
9import re
10from functools import partial
11from typing import BinaryIO, Callable, Dict, Generator, List, Optional, Tuple
12
13import elftools
14from elftools.elf import elffile
15
# Header line that introduces a function in an RTL dump: ";; Function <mangle> (<function>...)".
# load_rtl_file uses the "function" group (the name right after the opening parenthesis) as the
# function name; the "mangle" group is not read anywhere in this file.
FUNCTION_REGEX = re.compile(
    r'^;; Function (?P<mangle>.*)\s+\((?P<function>\S+)(,.*)?\).*$'
)
# A line containing "(call" followed later by a double-quoted string; "target" is the quoted text.
CALL_REGEX = re.compile(r'^.*\(call.*"(?P<target>.*)".*$')
# A line containing "(symbol_ref", then any characters other than parentheses, then ("<target>").
SYMBOL_REF_REGEX = re.compile(r'^.*\(symbol_ref[^()]*\("(?P<target>.*)"\).*$')
21
22
class RtlFunction:
    """
    One function definition found in an RTL dump file.

    Instances are created by load_rtl_file, which also fills in ``calls`` and ``refs``.
    """

    def __init__(self, name: str, rtl_filename: str, tu_filename: str) -> None:
        self.name, self.rtl_filename, self.tu_filename = name, rtl_filename, tu_filename
        # Names of the functions called directly, each listed once, in order of first appearance.
        self.calls: List[str] = []
        # Names of the referenced symbols, each listed once, in order of first appearance.
        self.refs: List[str] = []
        # Placeholder; nothing in this file assigns it after construction.
        self.sym = None
31
32
class SectionAddressRange:
    """Named address range covering ``low <= addr < high`` (the upper bound is exclusive)."""

    def __init__(self, name: str, addr: int, size: int) -> None:
        self.name = name
        self.low, self.high = addr, addr + size

    def __str__(self) -> str:
        template = '{}: 0x{:08x} - 0x{:08x}'
        return template.format(self.name, self.low, self.high)

    def contains_address(self, addr: int) -> bool:
        """Return True when ``addr`` lies inside the range; ``high`` itself is excluded."""
        return addr >= self.low and addr < self.high
44
45
# Extra address ranges, keyed by chip target name. ElfInfo._load_sections appends the list for
# the target named by the IDF_TARGET environment variable to the ranges read from the ELF file.
# NOTE(review): judging by the names, these presumably describe each chip's ROM code and
# read-only data regions, which the application ELF does not contain -- confirm.
TARGET_SECTIONS: Dict[str, List[SectionAddressRange]] = {
    'esp32': [
        SectionAddressRange('.rom.text', 0x40000000, 0x70000),
        SectionAddressRange('.rom.rodata', 0x3ff96000, 0x9018)
    ],
    'esp32s2': [
        SectionAddressRange('.rom.text', 0x40000000, 0x1bed0),
        SectionAddressRange('.rom.rodata', 0x3ffac600, 0x392c)
    ],
    'esp32s3': [
        SectionAddressRange('.rom.text', 0x40000000, 0x568d0),
        SectionAddressRange('.rom.rodata', 0x3ff071c0, 0x8e30)
    ]
}
60
61
class Symbol:
    """
    An ELF symbol together with the references recorded for it.

    ``filename`` and ``section`` may be None; ``refers_to`` / ``referred_from`` start empty and
    are filled in by match_rtl_funcs_to_symbols.
    """

    def __init__(self, name: str, addr: int, local: bool, filename: Optional[str], section: Optional[str]) -> None:
        self.name = name
        self.addr = addr
        self.local = local
        self.filename = filename
        self.section = section
        # Symbols this symbol refers to / symbols referring to this one.
        self.refers_to: List[Symbol] = []
        self.referred_from: List[Symbol] = []

    def __str__(self) -> str:
        section_label = self.section if self.section else 'unknown'
        local_marker = ' (local)' if self.local else ''
        return '{} @0x{:08x} [{}]{} {}'.format(
            self.name, self.addr, section_label, local_marker, self.filename
        )
80
81
class Reference:
    """A directed reference from one symbol (``from_sym``) to another (``to_sym``)."""

    def __init__(self, from_sym: Symbol, to_sym: Symbol) -> None:
        self.from_sym = from_sym
        self.to_sym = to_sym

    def __str__(self) -> str:
        source, target = self.from_sym, self.to_sym
        template = '{} @0x{:08x} ({}) -> {} @0x{:08x} ({})'
        return template.format(
            source.name, source.addr, source.section,
            target.name, target.addr, target.section,
        )
96
97
class IgnorePair:
    """
    A ``symbol/function_call`` pair given on the command line.

    load_rtl_file drops a reference matching ``symbol`` when it meets a call to ``function_call``.
    """

    def __init__(self, pair: str) -> None:
        # Exactly one '/' is expected; any other count makes the unpacking raise ValueError.
        symbol, function_call = pair.split('/')
        self.symbol = symbol
        self.function_call = function_call
101
102
class ElfInfo(object):
    """
    Section address ranges and symbols read from an ELF file.

    Attributes:
        elf_file: the binary file object passed to the constructor.
        elf_obj: ``elffile.ELFFile`` wrapper around ``elf_file``.
        section_ranges: address ranges used to map an address to a section name.
        symbols: all NOTYPE/FUNC/OBJECT symbols found in the symbol tables. Treat this list as
            read-only after construction: name lookups use an index built once from it.
    """

    def __init__(self, elf_file: BinaryIO) -> None:
        self.elf_file = elf_file
        self.elf_obj = elffile.ELFFile(self.elf_file)
        # Order matters: _load_symbols calls section_for_addr, which needs section_ranges.
        self.section_ranges = self._load_sections()
        self.symbols = self._load_symbols()
        # symbols_by_name is called for every RTL function and every reference, so index the
        # symbols once instead of scanning the whole list on each lookup.
        self._symbols_by_name = self._index_symbols_by_name(self.symbols)

    @staticmethod
    def _index_symbols_by_name(symbols: List['Symbol']) -> Dict[str, List['Symbol']]:
        """Group ``symbols`` by name, keeping the original relative order inside each group."""
        index: Dict[str, List['Symbol']] = {}
        for sym in symbols:
            index.setdefault(sym.name, []).append(sym)
        return index

    def _load_symbols(self) -> List['Symbol']:
        """Build a Symbol for every NOTYPE, FUNC or OBJECT entry of every symbol table."""
        symbols = []
        for s in self.elf_obj.iter_sections():
            if not isinstance(s, elftools.elf.sections.SymbolTableSection):
                continue
            # Name of the most recent STT_FILE entry; recorded as the file of later local symbols.
            filename = None
            for sym in s.iter_symbols():
                sym_type = sym.entry['st_info']['type']
                if sym_type == 'STT_FILE':
                    filename = sym.name
                if sym_type in ['STT_NOTYPE', 'STT_FUNC', 'STT_OBJECT']:
                    local = sym.entry['st_info']['bind'] == 'STB_LOCAL'
                    addr = sym.entry['st_value']
                    symbols.append(
                        Symbol(
                            sym.name,
                            addr,
                            local,
                            # Only local symbols get a file name; it is None for all others.
                            filename if local else None,
                            self.section_for_addr(addr),
                        )
                    )
        return symbols

    def _load_sections(self) -> List['SectionAddressRange']:
        """
        Collect the address range of every section contained in a PT_LOAD segment, then append
        the TARGET_SECTIONS entries for the target named by the IDF_TARGET environment variable.
        """
        result = []
        for segment in self.elf_obj.iter_segments():
            if segment['p_type'] == 'PT_LOAD':
                for section in self.elf_obj.iter_sections():
                    if not segment.section_in_segment(section):
                        continue
                    result.append(
                        SectionAddressRange(
                            section.name, section['sh_addr'], section['sh_size']
                        )
                    )

        target = os.environ.get('IDF_TARGET')
        if target in TARGET_SECTIONS:
            result += TARGET_SECTIONS[target]

        return result

    def symbols_by_name(self, name: str) -> List['Symbol']:
        """Return a new list of the symbols called ``name``, in symbol-table order (may be empty)."""
        # Copy so that callers modifying the result cannot corrupt the index.
        return list(self._symbols_by_name.get(name, ()))

    def section_for_addr(self, sym_addr: int) -> Optional[str]:
        """Return the name of the first range containing ``sym_addr``, or None if there is none."""
        for sar in self.section_ranges:
            if sar.contains_address(sym_addr):
                return sar.name
        return None
165
166
def load_rtl_file(rtl_filename: str, tu_filename: str, functions: List[RtlFunction], ignore_pairs: List[IgnorePair]) -> None:
    """
    Parse one RTL dump file and append an RtlFunction to ``functions`` for every function in it.

    Each line is checked, in this order, for a function header (FUNCTION_REGEX), a call
    (CALL_REGEX) and a symbol reference (SYMBOL_REF_REGEX). Calls and references are recorded
    once each on the most recently started function; lines before the first header are ignored.

    :param rtl_filename: path of the RTL file to read; also stored on each RtlFunction
    :param tu_filename: stored on each RtlFunction, not otherwise used here
    :param functions: output list, extended in place
    :param ignore_pairs: when a call to ``pair.function_call`` is found, the most recently
        recorded reference whose name contains ``pair.symbol`` is removed from the function
    """
    last_function: Optional[RtlFunction] = None
    # Use a context manager so the file is closed as soon as it has been parsed.
    with open(rtl_filename) as rtl_file:
        for line in rtl_file:
            # Find function definition
            match = FUNCTION_REGEX.match(line)
            if match:
                last_function = RtlFunction(match.group('function'), rtl_filename, tu_filename)
                functions.append(last_function)
                continue

            if last_function is None:
                # Nothing to attach calls or references to yet.
                continue

            # Find direct function calls
            match = CALL_REGEX.match(line)
            if match:
                target = match.group('target')

                # If the target is the function_call of an ignore pair, drop the reference that
                # was recorded last among those whose name contains the pair's symbol.
                for pair in ignore_pairs:
                    if pair.function_call != target:
                        continue
                    ignored_symbols = [ref for ref in last_function.refs if pair.symbol in ref]
                    if ignored_symbols:
                        last_ref = ignored_symbols[-1]
                        last_function.refs = [ref for ref in last_function.refs if ref != last_ref]

                if target not in last_function.calls:
                    last_function.calls.append(target)
                continue

            # Find symbol references
            match = SYMBOL_REF_REGEX.match(line)
            if match:
                target = match.group('target')
                if target not in last_function.refs:
                    last_function.refs.append(target)
205
206
def rtl_filename_matches_sym_filename(rtl_filename: str, symbol_filename: str) -> bool:
    """
    Tell whether an RTL file was produced from the source file an ELF symbol belongs to.

    Symbol file names (from the ELF symbol table) are short source file names without a path,
    e.g. "cpu_start.c". RTL file names are paths relative to the build directory, e.g.
    "build/esp-idf/esp_system/CMakeFiles/__idf_esp_system.dir/port/cpu_start.c.234r.expand",
    so the check is simply whether the last path component starts with the symbol file name.

    This can give a false positive when two files with the same name live in different
    directories. That does not seem to happen in IDF now; if it does, the asserts in
    find_symbol_by_name should catch it. If this becomes an issue, consider also loading the
    .map file to find out which object file each symbol came from: object file names and RTL
    file names should be much easier to match.
    """
    _directory, rtl_basename = os.path.split(rtl_filename)
    return rtl_basename.startswith(symbol_filename)
220
221
class SymbolNotFound(RuntimeError):
    """Error type for a symbol that cannot be found; not raised anywhere in this file."""
224
225
def find_symbol_by_name(name: str, elfinfo: ElfInfo, local_func_matcher: Callable[[Symbol], bool]) -> Optional[Symbol]:
    """
    Find the ELF symbol for the given name.

    Returns None when no symbol has that name and the only symbol when there is exactly one
    (local_func_matcher is not consulted in that case). With several same-named symbols,
    local_func_matcher is called for each local one to decide whether it is a suitable
    candidate; an accepted local symbol wins over the global one.
    """
    candidates = elfinfo.symbols_by_name(name)
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    # Several symbols share the name: pick the best fit.
    best_local = None
    best_global = None
    for candidate in candidates:
        if candidate.local:
            if local_func_matcher(candidate):
                assert not best_local  # can't have two symbols with the same name in a single file
                best_local = candidate
        else:
            assert not best_global  # can't have two global symbols with the same name
            best_global = candidate

    # When both a local and a global symbol qualify, the local one is the reference target.
    return best_local if best_local else best_global
251
252
def match_local_source_func(rtl_filename: str, sym: Symbol) -> bool:
    """
    Matcher used by match_rtl_funcs_to_symbols for the reference source (the caller): accepts
    local symbol ``sym`` when its file name fits the RTL file the function was found in.
    """
    sym_filename = sym.filename
    assert sym_filename  # should be set for local functions
    return rtl_filename_matches_sym_filename(rtl_filename, sym_filename)
260
261
def match_local_target_func(rtl_filename: str, sym_from: Symbol, sym: Symbol) -> bool:
    """
    Matcher used by match_rtl_funcs_to_symbols for the reference target (a callee or referenced
    data): decides whether local symbol ``sym`` is a plausible target for a reference coming
    from ``sym_from``, whose function was found in ``rtl_filename``.
    """
    assert sym.filename  # should be set for local functions
    if not sym_from.local:
        # Global symbol referencing a local one: the source file name is not known,
        # so fall back to the RTL file name as a hint.
        return rtl_filename_matches_sym_filename(rtl_filename, sym.filename)
    # Local symbol referencing another local symbol: both must come from the same file.
    return sym_from.filename == sym.filename
276
277
def match_rtl_funcs_to_symbols(rtl_functions: List[RtlFunction], elfinfo: ElfInfo) -> Tuple[List[Symbol], List[Reference]]:
    """
    Resolve RTL functions and the names they use to ELF symbols.

    For every RTL function the corresponding symbol is looked up; then every function it calls
    and every symbol it references is looked up too, and each resolved (source, target) pair is
    recorded as a Reference and on the two Symbol objects (``refers_to`` / ``referred_from``).

    Returns a tuple ``(symbols, refs)``: the symbols involved, each listed once in the order
    they were first met, and all references found (the same pair may appear more than once).
    """
    found_symbols: List[Symbol] = []
    found_refs: List[Reference] = []

    for rtl_func in rtl_functions:
        source_matcher = partial(match_local_source_func, rtl_func.rtl_filename)
        sym_from = find_symbol_by_name(rtl_func.name, elfinfo, source_matcher)
        if sym_from is None:
            # The RTL file defines the function, but the ELF file has no such symbol: it was
            # likely removed (or not included) at link time, so there is no section placement
            # to check.
            continue

        if sym_from not in found_symbols:
            found_symbols.append(sym_from)

        target_matcher = partial(match_local_target_func, rtl_func.rtl_filename, sym_from)
        # Local labels ("*.LC...") are skipped.
        target_names = [target for target in rtl_func.calls + rtl_func.refs if '*.LC' not in target]
        for target_name in target_names:
            sym_to = find_symbol_by_name(target_name, elfinfo, target_matcher)
            if not sym_to:
                # This may happen for an extern reference in the RTL file, if the reference was
                # later removed by one of the optimization passes and the external definition
                # got garbage-collected.
                # TODO: consider adding some sanity check that we are here not because of some
                # bug in find_symbol_by_name?..
                continue

            sym_from.refers_to.append(sym_to)
            sym_to.referred_from.append(sym_from)
            found_refs.append(Reference(sym_from, sym_to))
            if sym_to not in found_symbols:
                found_symbols.append(sym_to)

    return found_symbols, found_refs
321
322
def get_symbols_and_refs(rtl_list: List[str], elf_file: BinaryIO, ignore_pairs: List[IgnorePair]) -> Tuple[List[Symbol], List[Reference]]:
    """
    Load the ELF file and all listed RTL files, then match RTL functions to ELF symbols.

    The ELF file is read before any RTL file. Returns the ``(symbols, refs)`` tuple produced by
    match_rtl_funcs_to_symbols.
    """
    elf_info = ElfInfo(elf_file)

    parsed_functions: List[RtlFunction] = []
    for rtl_path in rtl_list:
        # The RTL path is passed for both the RTL file name and the translation unit name.
        load_rtl_file(rtl_path, rtl_path, parsed_functions, ignore_pairs)

    return match_rtl_funcs_to_symbols(parsed_functions, elf_info)
331
332
def list_refs_from_to_sections(refs: List[Reference], from_sections: List[str], to_sections: List[str]) -> int:
    """
    Print every reference going from one of ``from_sections`` to one of ``to_sections``.

    An empty section list places no restriction on that side of the reference.
    Returns the number of references printed.
    """
    printed = 0
    for ref in refs:
        if from_sections and ref.from_sym.section not in from_sections:
            continue
        if to_sections and ref.to_sym.section not in to_sections:
            continue
        print(ref)
        printed += 1
    return printed
341
342
def find_files_recursive(root_path: str, ext: str) -> Generator[str, None, None]:
    """Yield the path of every file below ``root_path`` whose name ends with ``ext``."""
    for dirpath, _dirnames, filenames in os.walk(root_path):
        yield from (
            os.path.join(dirpath, candidate)
            for candidate in filenames
            if candidate.endswith(ext)
        )
349
350
def main() -> None:
    """
    Command-line entry point.

    Collects the RTL files (from the list file given with --rtl-list, or by searching the
    --rtl-dirs directories for "*.expand" files), matches them against the ELF file and then
    runs the selected action:

    * ``find-refs``: print the references between the given sections;
    * ``all-refs``: print every reference found.

    Raises:
        RuntimeError: if neither --rtl-list nor --rtl-dirs is given, or no RTL file is found.
        SystemExit: with code 1 when ``find-refs --exit-code`` finds at least one reference
            (argparse also exits by itself on invalid arguments).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--rtl-list',
        help='File with the list of RTL files',
        type=argparse.FileType('r'),
    )
    parser.add_argument(
        '--rtl-dirs', help='comma-separated list of directories where to look for RTL files, recursively'
    )
    parser.add_argument(
        '--elf-file',
        required=True,
        help='Program ELF file',
        type=argparse.FileType('rb'),
    )
    action_sub = parser.add_subparsers(dest='action')
    find_refs_parser = action_sub.add_parser(
        'find-refs',
        help='List the references coming from a given list of source sections '
             'to a given list of target sections.',
    )
    find_refs_parser.add_argument(
        '--from-sections', help='comma-separated list of source sections'
    )
    find_refs_parser.add_argument(
        '--to-sections', help='comma-separated list of target sections'
    )
    find_refs_parser.add_argument(
        '--ignore-symbols',
        help='comma-separated list of symbol/function_name pairs. '
             'This will force the parser to ignore the symbol preceding the call to function_name',
    )
    find_refs_parser.add_argument(
        '--exit-code',
        action='store_true',
        help='If set, exits with non-zero code when any references found',
    )
    action_sub.add_parser(
        'all-refs',
        help='Print the list of all references',
    )

    args = parser.parse_args()
    if args.rtl_list is not None:
        # argparse.FileType has already opened the file: args.rtl_list is a file object, not a path.
        with args.rtl_list as rtl_list_file:
            stripped_lines = (line.strip() for line in rtl_list_file)
            # Skip blank lines so that they are not treated as (empty) file names.
            rtl_list = [name for name in stripped_lines if name]
    else:
        if not args.rtl_dirs:
            raise RuntimeError('Either --rtl-list or --rtl-dirs must be specified')
        rtl_list = []
        for rtl_dir in args.rtl_dirs.split(','):
            rtl_list.extend(find_files_recursive(rtl_dir, '.expand'))

    if not rtl_list:
        raise RuntimeError('No RTL files specified')

    # --ignore-symbols belongs to the find-refs sub-parser only, so the attribute does not exist
    # for the other actions.
    ignore_symbols = getattr(args, 'ignore_symbols', None)
    ignore_pairs = [IgnorePair(pair) for pair in ignore_symbols.split(',')] if ignore_symbols else []

    # The ELF file is only needed while the symbols are being loaded.
    with args.elf_file as elf_file:
        _, refs = get_symbols_and_refs(rtl_list, elf_file, ignore_pairs)

    if args.action == 'find-refs':
        from_sections = args.from_sections.split(',') if args.from_sections else []
        to_sections = args.to_sections.split(',') if args.to_sections else []
        found = list_refs_from_to_sections(
            refs, from_sections, to_sections
        )
        if args.exit_code and found:
            raise SystemExit(1)
    elif args.action == 'all-refs':
        for r in refs:
            print(str(r))
427
428
# Run the command-line tool only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
431