1#!/usr/bin/env python3
2#
3# Copyright (c) 2024 STMicroelectronics
4# SPDX-License-Identifier: Apache-2.0
5
"""Injects SLIDs into LLEXT ELFs' symbol tables.

When Kconfig option CONFIG_LLEXT_EXPORT_BUILTINS_BY_SLID is enabled,
all imports from the Zephyr kernel & application are resolved using
SLIDs instead of symbol names. This script stores the SLID of each
imported symbol in that symbol's entry in the ELF symbol table, so
that the LLEXT subsystem can resolve these imports at runtime.
13
14Note that this script is idempotent in theory. However, to prevent
15any catastrophic problem, the script will abort if the 'st_value'
16field of the `ElfX_Sym` structure is found to be non-zero, which is
17the case after one invocation. For this reason, in practice, the script
18cannot actually be executed twice on the same ELF file.
19"""
20
21import argparse
22import logging
23import shutil
24import sys
25
26from elftools.elf.elffile import ELFFile
27from elftools.elf.sections import SymbolTableSection
28
29import llext_slidlib
30
class LLEXTSymtabPreparator():
    """Patch the imported symbols of an LLEXT ELF file with their SLIDs.

    The ELF file is opened for in-place modification ("rb+") when the
    object is created. prepare_llext() performs the patching and always
    closes the file afterwards, even if processing raises.
    """

    def __init__(self, elf_path, log):
        """Open *elf_path* for in-place update and parse it as an ELF file.

        Args:
            elf_path: path of the LLEXT ELF file to patch.
            log: logger used for progress and error reporting.
        """
        self.log = log
        self.elf_path = elf_path
        self.elf_fd = open(elf_path, "rb+")
        try:
            self.elf = ELFFile(self.elf_fd)
        except Exception:
            # Don't leak the file handle when the file cannot be parsed.
            self.elf_fd.close()
            raise

    def _find_symtab(self):
        """Return the symbol table section to patch, or None if there is none.

        '.symtab' is looked up first, then '.dynsym'. A section is only
        accepted if it is actually a symbol table; a section that merely
        carries one of these names is skipped.
        """
        for section_name in (".symtab", ".dynsym"):
            symtab = self.elf.get_section_by_name(section_name)
            if isinstance(symtab, SymbolTableSection):
                self.log.info(f"processing '{section_name}' symbol table...")
                self.log.debug(f"(symbol table is at file offset 0x{symtab['sh_offset']:X})")
                return symtab
            self.log.debug(f"section {section_name} not found.")

        # Explicitly None: never hand back a non-symbol-table section.
        return None

    def _find_imports_in_symtab(self, symtab):
        """Return the imported symbols of *symtab* as (index, symbol) pairs.

        A symbol is treated as an import when it is untyped (STT_NOTYPE),
        global (STB_GLOBAL) and undefined (SHN_UNDEF). The index is the
        symbol's position in the table; it locates the entry in the file.
        """
        imports = []
        for index, sym in enumerate(symtab.iter_symbols()):
            entry = sym.entry
            is_import = (entry['st_info']['type'] == 'STT_NOTYPE'
                         and entry['st_info']['bind'] == 'STB_GLOBAL'
                         and entry['st_shndx'] == 'SHN_UNDEF')
            if is_import:
                self.log.debug(f"found imported symbol '{sym.name}' at index {index}")
                imports.append((index, sym))
        return imports

    def _make_stvalue_reader_writer(self, symtab):
        """Return (reader, writer) callables for the 'st_value' field of *symtab*.

        reader(index) returns the st_value of symbol *index* as an integer.
        writer(index, value) overwrites it in the file. Both callables
        operate directly on self.elf_fd. They use the ELF class (32/64-bit)
        and the byte order of the file.
        """
        byteorder = "little" if self.elf.little_endian else "big"
        if self.elf.elfclass == 32:
            sizeof_Elf_Sym = 0x10    #sizeof(Elf32_Sym)
            offsetof_st_value = 0x4  #offsetof(Elf32_Sym, st_value)
            sizeof_st_value = 0x4    #sizeof(Elf32_Sym.st_value)
        else:
            sizeof_Elf_Sym = 0x18    #sizeof(Elf64_Sym)
            offsetof_st_value = 0x8  #offsetof(Elf64_Sym, st_value)
            sizeof_st_value = 0x8    #sizeof(Elf64_Sym.st_value)

        base = symtab['sh_offset'] + offsetof_st_value

        def seek(symidx):
            self.elf_fd.seek(base + symidx * sizeof_Elf_Sym)

        def reader(symbol_index):
            seek(symbol_index)
            return int.from_bytes(self.elf_fd.read(sizeof_st_value), byteorder)

        def writer(symbol_index, st_value):
            seek(symbol_index)
            self.elf_fd.write(st_value.to_bytes(sizeof_st_value, byteorder))

        return reader, writer

    def _prepare_inner(self):
        """Write the SLID of every imported symbol into its st_value field.

        Returns:
            0 on success, 1 if no symbol table exists or an import already
            has a non-zero st_value. In the error case the file is left
            unmodified.
        """
        #1) Locate the symbol table
        symtab = self._find_symtab()
        if symtab is None:
            self.log.error("no symbol table found in file")
            return 1

        #2) Find imported symbols in symbol table
        imports = self._find_imports_in_symtab(symtab)
        self.log.info(f"LLEXT has {len(imports)} import(s)")

        rd_st_val, wr_st_val = self._make_stvalue_reader_writer(symtab)

        #3) Make sure we're not overwriting something actually important.
        # All imports are checked before anything is written, so a failure
        # never leaves the ELF file partially patched.
        for (index, symbol) in imports:
            if rd_st_val(index) != 0:
                self.log.error(f"unexpected non-zero st_value for symbol {symbol.name}")
                return 1

        #4) Write SLIDs in each symbol's 'st_value' field
        slid_size = self.elf.elfclass // 8
        for (index, symbol) in imports:
            slid = llext_slidlib.generate_slid(symbol.name, slid_size)
            slid_as_str = llext_slidlib.format_slid(slid, slid_size)
            self.log.info(f"{symbol.name} -> {slid_as_str}")
            wr_st_val(index, slid)

        return 0

    def prepare_llext(self):
        """Patch the ELF file and close it.

        Returns:
            The status of the patching step (0 on success, 1 on error).
            The file is closed even if patching raises.
        """
        try:
            return self._prepare_inner()
        finally:
            self.elf_fd.close()
133
134# Disable duplicate code warning for the code that follows,
135# as it is expected for these functions to be similar.
136# pylint: disable=duplicate-code
137def _parse_args(argv):
138    """Parse the command line arguments."""
139    parser = argparse.ArgumentParser(
140        description=__doc__,
141        formatter_class=argparse.RawDescriptionHelpFormatter,
142        allow_abbrev=False)
143
144    parser.add_argument("-f", "--elf-file", required=True,
145                        help="LLEXT ELF file to process")
146    parser.add_argument("-o", "--output-file",
147                        help=("Additional output file where processed ELF "
148                        "will be copied"))
149    parser.add_argument("-sl", "--slid-listing",
150                        help="write the SLID listing to a file")
151    parser.add_argument("-v", "--verbose", action="count",
152                        help=("enable verbose output, can be used multiple times "
153                              "to increase verbosity level"))
154    parser.add_argument("--always-succeed", action="store_true",
155                        help="always exit with a return code of 0, used for testing")
156
157    return parser.parse_args(argv)
158
159def _init_log(verbose):
160    """Initialize a logger object."""
161    log = logging.getLogger(__file__)
162
163    console = logging.StreamHandler()
164    console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
165    log.addHandler(console)
166
167    if verbose and verbose > 1:
168        log.setLevel(logging.DEBUG)
169    elif verbose and verbose > 0:
170        log.setLevel(logging.INFO)
171    else:
172        log.setLevel(logging.WARNING)
173
174    return log
175
def main(argv=None):
    """Inject SLIDs into the ELF file named on the command line.

    Args:
        argv: argument strings (without the program name), or None to
            let argparse read sys.argv[1:].

    Returns:
        The process exit code. This is 0 on success, or whenever
        --always-succeed is given, and non-zero otherwise.
    """
    args = _parse_args(argv)

    log = _init_log(args.verbose)

    log.info(f"inject_slids_in_llext: {args.elf_file}")

    preparator = LLEXTSymtabPreparator(args.elf_file, log)

    res = preparator.prepare_llext()

    # Copy the processed ELF only when processing succeeded. The copy is
    # done before --always-succeed is honoured, so that a successful run
    # still produces the requested output file.
    if res == 0 and args.output_file:
        shutil.copy(args.elf_file, args.output_file)

    if args.always_succeed:
        return 0

    return res
194
if __name__ == "__main__":
    # With no explicit argv, main() hands None to argparse, which then
    # parses sys.argv[1:]; the script's exit code is main()'s return value.
    sys.exit(main())
197