1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4# SPDX-License-Identifier: Apache-2.0
5
6"""
7Checks the initialization priorities
8
9This script parses a Zephyr executable file, creates a list of known devices
10and their effective initialization priorities and compares that with the device
11dependencies inferred from the devicetree hierarchy.
12
This can be used to detect devices that are initialized in the incorrect order,
but also devices that are initialized at the same priority but depend on each
other, which can potentially break if the linking order is changed.
16
17Optionally, it can also produce a human readable list of the initialization
18calls for the various init levels.
19"""
20
21import argparse
22import logging
23import os
24import pathlib
25import pickle
26import sys
27
28from elftools.elf.elffile import ELFFile
29from elftools.elf.sections import SymbolTableSection
30
31# This is needed to load edt.pickle files.
32sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..",
33                                "dts", "python-devicetree", "src"))
34from devicetree import edtlib  # pylint: disable=unused-import
35
# Symbol name prefix of the "struct device" objects instantiated from
# devicetree nodes; the rest of the symbol name is the node's dependency
# ordinal (the key used to look the node up in the EDT's dep_ord2node).
_DEVICE_ORD_PREFIX = "__device_dts_ord_"

# Defined init levels in order of priority. The index of a level in this list
# is used as its rank when comparing priorities.
_DEVICE_INIT_LEVELS = ["EARLY", "PRE_KERNEL_1", "PRE_KERNEL_2", "POST_KERNEL",
                      "APPLICATION", "SMP"]

# Set of (device compatible, dependency compatible) pairs for which the
# initialization priority should be the opposite of the devicetree inferred
# dependency. Currently empty.
_INVERTED_PRIORITY_COMPATIBLES = frozenset()

# Compatibles of devices whose dependency priorities are not checked.
_IGNORE_COMPATIBLES = frozenset([
        # There is no direct dependency between the CDC ACM UART and the USB
        # device controller, the logical connection is established after USB
        # device support is enabled.
        "zephyr,cdc-acm-uart",
        ])
55
class Priority:
    """Parses and holds a device initialization priority.

    A priority is the pair (init level, position within that level). Objects
    compare by level first, in _DEVICE_INIT_LEVELS order, and then by
    position. They are hashable, so they can be used as dict keys or set
    members.

    Args:
        level: name of the init level, one of _DEVICE_INIT_LEVELS.
        priority: position of the init entry within that level.

    Raises:
        ValueError: if level is not a known init level name.
    """
    def __init__(self, level, priority):
        try:
            self._level = _DEVICE_INIT_LEVELS.index(level)
        except ValueError:
            raise ValueError("Unknown level in %s" % level) from None
        self._priority = priority
        # Tuples compare elementwise in order
        self._level_priority = (self._level, self._priority)

    def __repr__(self):
        return "<%s %s %d>" % (self.__class__.__name__,
                               _DEVICE_INIT_LEVELS[self._level], self._priority)

    def __str__(self):
        return "%s %d" % (_DEVICE_INIT_LEVELS[self._level], self._priority)

    def __lt__(self, other):
        if not isinstance(other, Priority):
            return NotImplemented
        return self._level_priority < other._level_priority

    def __eq__(self, other):
        if not isinstance(other, Priority):
            return NotImplemented
        return self._level_priority == other._level_priority

    def __hash__(self):
        # __hash__ must return an int; hash the comparison key so that equal
        # priorities hash equally.
        return hash(self._level_priority)
90
91
class ZephyrInitLevels:
    """Load an executable file and find the initialization calls and devices.

    Load a Zephyr executable file and scan for the list of initialization calls
    and defined devices.

    Once constructed, the devices are available in the "devices" instance
    attribute in the {ordinal: Priority} format, and the init level calls are
    in the "initlevels" instance attribute in the
    {"level name": ["call", ...]} format.

    Attributes:
        file_path: path of the file to be loaded.
    """
    def __init__(self, file_path):
        self.file_path = file_path
        # Everything callers use (devices, initlevels) is computed here, so
        # the ELF file only has to stay open while it is being parsed.
        with open(file_path, "rb") as elf_file:
            self._elf = ELFFile(elf_file)
            self._load_objects()
            self._load_level_addr()
            self._process_initlevels()

    def _iter_symbols(self):
        """Yield every symbol of every symbol table section in the ELF file."""
        for section in self._elf.iter_sections():
            if isinstance(section, SymbolTableSection):
                yield from section.iter_symbols()

    def _load_objects(self):
        """Initialize the object table.

        Maps the address of every named, non-empty STT_OBJECT or STT_FUNC
        symbol to a (name, size, section index) tuple.
        """
        self._objects = {}

        for sym in self._iter_symbols():
            if (sym.name and
                sym.entry.st_size > 0 and
                sym.entry.st_info.type in ["STT_OBJECT", "STT_FUNC"]):
                self._objects[sym.entry.st_value] = (
                        sym.name, sym.entry.st_size, sym.entry.st_shndx)

    def _load_level_addr(self):
        """Find the address associated with known init levels.

        Sets self._init_level_addr ({level name: start address}) and
        self._init_level_end (address of the "__init_end" symbol).

        Raises:
            ValueError: if a level start symbol or "__init_end" is missing.
        """
        start_symbols = {f"__init_{level}_start": level
                         for level in _DEVICE_INIT_LEVELS}
        self._init_level_addr = {}
        self._init_level_end = None

        for sym in self._iter_symbols():
            if sym.name in start_symbols:
                self._init_level_addr[start_symbols[sym.name]] = sym.entry.st_value
            elif sym.name == "__init_end":
                self._init_level_end = sym.entry.st_value

        if len(self._init_level_addr) != len(_DEVICE_INIT_LEVELS):
            raise ValueError(f"Missing init symbols, found: {self._init_level_addr}")

        if self._init_level_end is None:
            raise ValueError("Missing init section end symbol")

    def _device_ord_from_name(self, sym_name):
        """Find a device ordinal from a symbol name.

        Returns:
            The integer following _DEVICE_ORD_PREFIX, or None if sym_name is
            empty or does not start with that prefix.

        Raises:
            ValueError: if the text after the prefix is not an integer.
        """
        if not sym_name or not sym_name.startswith(_DEVICE_ORD_PREFIX):
            return None

        return int(sym_name[len(_DEVICE_ORD_PREFIX):])

    def _object_name(self, addr):
        """Return the symbol name at addr, "NULL" for 0, else "unknown"."""
        if not addr:
            return "NULL"
        elif addr in self._objects:
            return self._objects[addr][0]
        else:
            return "unknown"

    def _initlevel_pointer(self, addr, idx, shidx):
        """Read the idx-th pointer-sized field of the object at addr.

        Args:
            addr: address of the object to read from.
            idx: index of the pointer-sized field within that object.
            shidx: index of the ELF section containing addr.

        Returns:
            The field value as an unsigned integer, decoded using the ELF
            file's byte order.

        Raises:
            ValueError: if the ELF class is neither 32 nor 64 bit.
        """
        elfclass = self._elf.elfclass
        if elfclass == 32:
            ptrsize = 4
        elif elfclass == 64:
            ptrsize = 8
        else:
            raise ValueError(f"Unknown pointer size for ELF class {elfclass}")

        section = self._elf.get_section(shidx)
        data = section.data()

        # Offset of the requested field relative to the start of the section.
        offset = addr - section.header.sh_addr + ptrsize * idx
        byteorder = "little" if self._elf.little_endian else "big"

        return int.from_bytes(data[offset:offset + ptrsize], byteorder=byteorder)

    def _process_initlevels(self):
        """Process the init levels and find the init functions and devices.

        Walks the entries of each init level in address order; the position
        of an entry within its level is used as its priority.

        Raises:
            ValueError: if an address inside an init level does not match a
                known object.
        """
        self.devices = {}
        self.initlevels = {}

        # Level i spans [start of level i, start of level i + 1); the last
        # level ends at the "__init_end" symbol.
        bounds = [self._init_level_addr[level] for level in _DEVICE_INIT_LEVELS]
        bounds.append(self._init_level_end)

        for i, level in enumerate(_DEVICE_INIT_LEVELS):
            start, stop = bounds[i], bounds[i + 1]
            calls = []
            self.initlevels[level] = calls

            priority = 0
            addr = start
            while addr < stop:
                if addr not in self._objects:
                    raise ValueError(f"no symbol at addr {addr:08x}")
                obj, size, shidx = self._objects[addr]

                # The first two pointer-sized fields of the entry are reported
                # as "callee(argument)".
                arg0_name = self._object_name(self._initlevel_pointer(addr, 0, shidx))
                arg1_name = self._object_name(self._initlevel_pointer(addr, 1, shidx))

                calls.append(f"{obj}: {arg0_name}({arg1_name})")

                # Ordinal 0 is a valid devicetree ordinal, so test for None.
                ordinal = self._device_ord_from_name(arg1_name)
                if ordinal is not None:
                    self.devices[ordinal] = Priority(level, priority)

                addr += size
                priority += 1
221
class Validator():
    """Validates the initialization priorities.

    Reads the device initialization priorities from a Zephyr ELF file and the
    devicetree dependencies from a pickled EDT, then compares each device's
    priority with the priorities of the devices it depends on and logs any
    issue found.

    Args:
        elf_file_path: path of the ELF file
        edt_pickle: name of the EDT pickle file, resolved relative to the
            directory containing the ELF file
        log: a logging.Logger object

    Attributes:
        log: the logging.Logger passed to the constructor.
        warnings: number of device/dependency pairs found at the same priority.
        errors: number of devices found initialized before one of their
            dependencies.
    """
    def __init__(self, elf_file_path, edt_pickle, log):
        self.log = log

        edt_pickle_path = pathlib.Path(
                pathlib.Path(elf_file_path).parent,
                edt_pickle)
        # NOTE: unpickling can execute arbitrary code; the pickle is expected
        # to be a trusted build artifact.
        with open(edt_pickle_path, "rb") as f:
            edt = pickle.load(f)

        # Maps devicetree dependency ordinals to edtlib nodes.
        self._ord2node = edt.dep_ord2node

        self._obj = ZephyrInitLevels(elf_file_path)

        self.warnings = 0
        self.errors = 0

    def _check_dep(self, dev_ord, dep_ord):
        """Validate the priority between two devices.

        Logs an error (and increments self.errors) if the device is
        initialized before its dependency, a warning (and increments
        self.warnings) if both have the same priority, and an info message
        otherwise. Pairs where either device has no known priority are
        skipped.

        Args:
            dev_ord: dependency ordinal of the device being checked.
            dep_ord: dependency ordinal of a node the device depends on.
        """
        if dev_ord == dep_ord:
            return

        dev_node = self._ord2node[dev_ord]
        dep_node = self._ord2node[dep_ord]

        if dev_node._binding:
            dev_compat = dev_node._binding.compatible
            if dev_compat in _IGNORE_COMPATIBLES:
                self.log.info(f"Ignoring priority: {dev_compat}")
                return

        if dev_node._binding and dep_node._binding:
            dev_compat = dev_node._binding.compatible
            dep_compat = dep_node._binding.compatible
            if (dev_compat, dep_compat) in _INVERTED_PRIORITY_COMPATIBLES:
                self.log.info(f"Swapped priority: {dev_compat}, {dep_compat}")
                # Swap the nodes together with the ordinals so that the
                # messages below print each path next to its own priority.
                dev_ord, dep_ord = dep_ord, dev_ord
                dev_node, dep_node = dep_node, dev_node

        dev_prio = self._obj.devices.get(dev_ord)
        dep_prio = self._obj.devices.get(dep_ord)

        if dev_prio is None or dep_prio is None:
            return

        if dev_prio == dep_prio:
            self.warnings += 1
            self.log.warning(
                    f"{dev_node.path} {dev_prio} == {dep_node.path} {dep_prio}")
        elif dev_prio < dep_prio:
            self.errors += 1
            self.log.error(
                    f"{dev_node.path} {dev_prio} < {dep_node.path} {dep_prio}")
        else:
            self.log.info(
                    f"{dev_node.path} {dev_prio} > {dep_node.path} {dep_prio}")

    def _check_edt_r(self, dev_ord, dev):
        """Recursively check the dependencies of a device.

        Also descends into children of dev that have no "compatible" property
        and whose binding has the same path as dev's binding; their
        dependencies are checked against dev_ord.
        """
        for dep in dev.depends_on:
            self._check_dep(dev_ord, dep.dep_ordinal)
        if dev._binding and dev._binding.child_binding:
            for child in dev.children.values():
                if "compatible" in child.props:
                    continue
                if dev._binding.path != child._binding.path:
                    continue
                self._check_edt_r(dev_ord, child)

    def check_edt(self):
        """Scan through all known devices and validate the init priorities."""
        for dev_ord in self._obj.devices:
            dev = self._ord2node[dev_ord]
            self._check_edt_r(dev_ord, dev)

    def print_initlevels(self):
        """Print each init level followed by its init calls, one per line."""
        for level, calls in self._obj.initlevels.items():
            print(level)
            for call in calls:
                print(f"  {call}")
312
313def _parse_args(argv):
314    """Parse the command line arguments."""
315    parser = argparse.ArgumentParser(
316        description=__doc__,
317        formatter_class=argparse.RawDescriptionHelpFormatter,
318        allow_abbrev=False)
319
320    parser.add_argument("-f", "--elf-file", default=pathlib.Path("build", "zephyr", "zephyr.elf"),
321                        help="ELF file to use")
322    parser.add_argument("-v", "--verbose", action="count",
323                        help=("enable verbose output, can be used multiple times "
324                              "to increase verbosity level"))
325    parser.add_argument("-w", "--fail-on-warning", action="store_true",
326                        help="fail on both warnings and errors")
327    parser.add_argument("--always-succeed", action="store_true",
328                        help="always exit with a return code of 0, used for testing")
329    parser.add_argument("-o", "--output",
330                        help="write the output to a file in addition to stdout")
331    parser.add_argument("-i", "--initlevels", action="store_true",
332                        help="print the initlevel functions instead of checking the device dependencies")
333    parser.add_argument("--edt-pickle", default=pathlib.Path("edt.pickle"),
334                        help="name of the the pickled edtlib.EDT file",
335                        type=pathlib.Path)
336
337    return parser.parse_args(argv)
338
339def _init_log(verbose, output):
340    """Initialize a logger object."""
341    log = logging.getLogger(__file__)
342
343    console = logging.StreamHandler()
344    console.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
345    log.addHandler(console)
346
347    if output:
348        file = logging.FileHandler(output, mode="w")
349        file.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
350        log.addHandler(file)
351
352    if verbose and verbose > 1:
353        log.setLevel(logging.DEBUG)
354    elif verbose and verbose > 0:
355        log.setLevel(logging.INFO)
356    else:
357        log.setLevel(logging.WARNING)
358
359    return log
360
def main(argv=None):
    """Run the checker and return the process exit code.

    Args:
        argv: command line arguments, excluding the program name.

    Returns:
        0 on success, or when --always-succeed is given; 1 if any error was
        found, or if a warning was found and --fail-on-warning is set.
    """
    args = _parse_args(argv)
    log = _init_log(args.verbose, args.output)

    log.info(f"check_init_priorities: {args.elf_file}")

    validator = Validator(args.elf_file, args.edt_pickle, log)
    if not args.initlevels:
        validator.check_edt()
    else:
        validator.print_initlevels()

    if args.always_succeed:
        return 0

    failed = bool(args.fail_on_warning and validator.warnings) or bool(validator.errors)
    return 1 if failed else 0
384
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
387