1#!/usr/bin/env python3
2#
3# Copyright (c) 2025 Intel Corporation
4#
5# SPDX-License-Identifier: Apache-2.0
6
7import argparse
8import struct
9
10from elftools.elf.elffile import ELFFile
11from elftools.elf.sections import SymbolTableSection
12
13
def parse_args():
    """Parse the command line and store the result in the module-level ``args``.

    Every option is optional. ``--arch`` accepts only ``x86`` or ``x86_64``
    and falls back to ``x86``; abbreviated long options are rejected.
    """
    global args

    cli = argparse.ArgumentParser(
        allow_abbrev=False,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    # Plain optional string options: (short flag, long flag, help text).
    simple_options = (
        ("-k", "--kernel", "Zephyr kernel image"),
        ("-o", "--header-output", "Header output file"),
        ("-c", "--config", "Configuration file (.config)"),
    )
    for short_flag, long_flag, help_text in simple_options:
        cli.add_argument(short_flag, long_flag, required=False, help=help_text)

    cli.add_argument(
        "-a",
        "--arch",
        default="x86",
        choices=["x86", "x86_64"],
        required=False,
        help="Architecture to generate shadow stack for",
    )

    args = cli.parse_args()
34
35
def get_symbols(obj):
    """Return a ``{name: symbol}`` mapping built from the first symbol table of *obj*.

    *obj* must provide ``iter_sections()``; the first section that is a
    ``SymbolTableSection`` supplies the symbols.

    Raises:
        LookupError: if *obj* has no symbol table section.
    """
    symtab = next(
        (sec for sec in obj.iter_sections() if isinstance(sec, SymbolTableSection)),
        None,
    )
    if symtab is None:
        raise LookupError("Could not find symbol table")

    return {symbol.name: symbol for symbol in symtab.iter_symbols()}
42
43
# A single little-endian 64-bit word: the token stored in the last slot of
# every interrupt shadow sub-stack.
shstk_irq_top_fmt = "<Q"


def generate_initialized_irq_array64(sym, section_data, section_addr, isr_depth):
    """Build the initial contents of the interrupt shadow stack array.

    Args:
        sym: symbol describing the array; ``sym["st_value"]`` is its address
            and ``sym["st_size"]`` its size in bytes.
        section_data: raw bytes of the section that contains the array.
        section_addr: address at which ``section_data`` starts.
        isr_depth: number of sub-stacks each array member is divided into.

    Returns:
        A ``bytearray`` of ``sym["st_size"]`` bytes that is all zero except
        for the last 8 bytes of every sub-stack, which hold that slot's own
        address. If the array reports zero sub-stacks the buffer is returned
        fully zeroed, matching the other array generators.
    """
    offset = sym["st_value"] - section_addr
    # First four bytes have the number of members of the array; each member
    # is further divided into isr_depth sub-stacks.
    nmemb = int.from_bytes(section_data[offset : offset + 4], "little") * isr_depth

    output = bytearray(sym["st_size"])
    if nmemb == 0:
        # Nothing to initialize; avoid dividing by zero below.
        return output

    # Sizes are integers: floor division keeps the result exact, whereas
    # true division would go through a float.
    stack_size = sym["st_size"] // nmemb

    # Top of shadow stack is on the form:
    # [-1] = shadow stack supervisor token
    for i in range(nmemb):
        top = stack_size * (i + 1) - 8
        # The token value is the address of the slot it is stored in.
        struct.pack_into(shstk_irq_top_fmt, output, top, sym["st_value"] + top)

    return output
63
64
# Four little-endian 64-bit words: token, previous SSP, thread entry point
# and code segment selector.
shstk_top_fmt = "<QQQQ"


def generate_initialized_array64(sym, section_data, section_addr, thread_entry):
    """Build the initial contents of a thread shadow stack array (x86_64).

    Args:
        sym: symbol describing the array; ``sym["st_value"]`` is its address
            and ``sym["st_size"]`` its size in bytes.
        section_data: raw bytes of the section that contains the array.
        section_addr: address at which ``section_data`` starts.
        thread_entry: symbol whose ``st_value`` is written as the entry point.

    Returns:
        A ``bytearray`` of ``sym["st_size"]`` bytes, zero everywhere except
        for the four 8-byte words written near the top of every stack. When
        the array reports zero members the buffer is returned fully zeroed.
    """
    offset = sym["st_value"] - section_addr
    # First four bytes have the number of members of the array
    nmemb = int.from_bytes(section_data[offset : offset + 4], "little")
    if nmemb == 0:
        # CONFIG_DYNAMIC_THREAD_POOL_SIZE can be zero, in which case the pool
        # is not used. To allow static initialization the corresponding
        # shadow stack array still has one item, and is recognized here by
        # its member count of 0. The space is wasted, but nothing needs to
        # be initialized.
        return bytearray(sym["st_size"])

    # Sizes are integers: floor division keeps the result exact, whereas
    # true division would go through a float.
    stack_size = sym["st_size"] // nmemb

    # Top of shadow stack is on the form:
    # [-5] = shadow stack token, pointing to [-4]
    # [-4] = previous SSP, pointing to [-1]
    # [-3] = z_thread_entry
    # [-2] = X86_KERNEL_CS

    cs = 0x18  # X86_KERNEL_CS
    entry_addr = thread_entry["st_value"]

    output = bytearray(sym["st_size"])
    for i in range(nmemb):
        top = stack_size * (i + 1)
        end = sym["st_value"] + top
        # Address of slot [-4] with the low bit set.
        # NOTE(review): the low bit is presumably the 64-bit mode flag of a
        # CET restore token - confirm against the Intel SDM.
        token = end - 8 * 4 + 1
        prev_ssp = end - 8

        struct.pack_into(shstk_top_fmt, output, top - 8 * 5, token, prev_ssp, entry_addr, cs)

    return output
104
105
# One little-endian 64-bit word (the token, whose upper half is zero)
# followed by one 32-bit word (the thread entry point).
shstk_top_fmt32 = "<QI"


def generate_initialized_array32(sym, section_data, section_addr, thread_entry):
    """Build the initial contents of a thread shadow stack array (32-bit x86).

    Args:
        sym: symbol describing the array; ``sym["st_value"]`` is its address
            and ``sym["st_size"]`` its size in bytes.
        section_data: raw bytes of the section that contains the array.
        section_addr: address at which ``section_data`` starts.
        thread_entry: symbol whose ``st_value`` is written as the entry point.

    Returns:
        A ``bytearray`` of ``sym["st_size"]`` bytes, zero everywhere except
        for the 12 bytes written near the top of every stack. When the array
        reports zero members the buffer is returned fully zeroed.
    """
    offset = sym["st_value"] - section_addr
    # First four bytes have the number of members of the array
    nmemb = int.from_bytes(section_data[offset : offset + 4], "little")
    if nmemb == 0:
        # A member count of 0 marks a placeholder array that is never used
        # (same convention as generate_initialized_array64); leave it zeroed.
        return bytearray(sym["st_size"])

    # Sizes are integers: floor division keeps the result exact, whereas
    # true division would go through a float.
    stack_size = sym["st_size"] // nmemb

    # Top of shadow stack is on the form (4-byte slots):
    # [-4] = shadow stack token, pointing to [-2]
    # [-3] = 0 - high order bits of token
    # [-2] = z_thread_entry/z_x86_thread_entry_wrapper

    entry_addr = thread_entry["st_value"]

    output = bytearray(sym["st_size"])
    for i in range(nmemb):
        top = stack_size * (i + 1)
        token = sym["st_value"] + top - 8

        struct.pack_into(shstk_top_fmt32, output, top - 4 * 4, token, entry_addr)

    return output
137
138
def generate_initialized_array(sym, section_data, section_addr, thread_entry):
    """Dispatch to the 32-bit or 64-bit array generator based on ``args.arch``.

    ``x86`` selects the 32-bit layout; any other value selects the 64-bit one.
    All arguments are forwarded unchanged and the generator's result is returned.
    """
    builder = (
        generate_initialized_array32 if args.arch == "x86" else generate_initialized_array64
    )
    return builder(sym, section_data, section_addr, thread_entry)
144
145
def patch_elf():
    """Patch the shadow stack arrays of the kernel image in place.

    Opens ``args.kernel`` for update, locates the ``.x86shadowstack.arr``
    section and every symbol named ``__*_shstk_arr``, generates the initial
    contents of each array and writes the result back into the file at the
    section's file offset.

    Each array is written at its own offset inside the section (derived from
    the symbol address), so bytes that belong to no array - for example
    padding between arrays - keep their original contents.

    Raises:
        LookupError: if the section or the symbol table cannot be found.
        KeyError: if a required symbol is missing from the symbol table.
        ValueError: if a shadow stack array does not lie inside the section.
    """
    section_name = ".x86shadowstack.arr"

    with open(args.kernel, "r+b") as elf_fp:
        kernel = ELFFile(elf_fp)
        section = kernel.get_section_by_name(section_name)
        if section is None:
            raise LookupError(f"Could not find section {section_name}")
        syms = get_symbols(kernel)

        thread_entry = syms["z_thread_entry"]
        # NOTE(review): a disabled Kconfig option is assumed to have no
        # symbol in the ELF, so a missing symbol is treated as "disabled"
        # rather than as an error - confirm against the build system.
        debug_info = syms.get("CONFIG_X86_DEBUG_INFO")
        if args.arch == "x86" and debug_info is not None and debug_info.entry.st_value:
            thread_entry = syms["z_x86_thread_entry_wrapper"]

        # Process the arrays in address order.
        shstk_arr_syms = sorted(
            (
                sym
                for name, sym in syms.items()
                if name.startswith("__") and name.endswith("_shstk_arr")
            ),
            key=lambda sym: sym["st_value"],
        )

        section_addr = section["sh_addr"]
        section_data = section.data()
        # Start from the current section contents so that anything outside
        # the arrays is written back unchanged.
        updated_section = bytearray(section_data)

        for sym in shstk_arr_syms:
            if sym.name == "__z_interrupt_stacks_shstk_arr" and args.arch == "x86_64":
                isr_depth = syms["CONFIG_ISR_DEPTH"].entry.st_value
                out = generate_initialized_irq_array64(
                    sym, section_data, section_addr, isr_depth
                )
            else:
                out = generate_initialized_array(sym, section_data, section_addr, thread_entry)

            start = sym["st_value"] - section_addr
            end = start + len(out)
            # A slice assignment outside the buffer would silently resize or
            # shift it, so reject arrays that are not inside the section.
            if start < 0 or end > len(updated_section):
                raise ValueError(f"Symbol {sym.name} lies outside section {section_name}")
            updated_section[start:end] = out

        elf_fp.seek(section["sh_offset"])
        elf_fp.write(updated_section)
179
180
def generate_header():
    """Write a C header defining ``X86_CET_IRQ_SHADOW_SUBSTACK_SIZE``.

    The following options are read from the file named by ``args.config``
    (lines of the form ``NAME=value``; the last occurrence wins):
    ``CONFIG_ISR_DEPTH``, ``CONFIG_ISR_STACK_SIZE``,
    ``CONFIG_X86_CET_SHADOW_STACK_ALIGNMENT``,
    ``CONFIG_HW_SHADOW_STACK_PERCENTAGE_SIZE`` and
    ``CONFIG_HW_SHADOW_STACK_MIN_SIZE``.

    The value written is the ISR stack size scaled by the percentage,
    rounded up to the alignment, raised to at least the minimum size and
    finally divided by the ISR depth. The header is written to
    ``args.header_output``.

    Raises:
        ValueError: if no configuration file was given, an option is missing
            or not an integer, or the alignment / ISR depth is not positive.
    """
    if args.config is None:
        raise ValueError("Configuration file is required to generate header")

    # Order matters: it is the order in which missing options are reported.
    required = (
        "CONFIG_ISR_DEPTH",
        "CONFIG_ISR_STACK_SIZE",
        "CONFIG_X86_CET_SHADOW_STACK_ALIGNMENT",
        "CONFIG_HW_SHADOW_STACK_PERCENTAGE_SIZE",
        "CONFIG_HW_SHADOW_STACK_MIN_SIZE",
    )

    values = {}
    with open(args.config) as config_fp:
        for line in config_fp:
            fields = line.split("=")
            if len(fields) > 1 and fields[0] in required:
                values[fields[0]] = int(fields[1].strip())

    for name in required:
        if name not in values:
            raise ValueError(f"Missing {name} in configuration file")

    isr_depth = values["CONFIG_ISR_DEPTH"]
    alignment = values["CONFIG_X86_CET_SHADOW_STACK_ALIGNMENT"]
    hw_stack_percentage = values["CONFIG_HW_SHADOW_STACK_PERCENTAGE_SIZE"]
    hw_stack_min_size = values["CONFIG_HW_SHADOW_STACK_MIN_SIZE"]

    if alignment <= 0:
        raise ValueError("CONFIG_X86_CET_SHADOW_STACK_ALIGNMENT must be greater than zero")
    if isr_depth <= 0:
        raise ValueError("CONFIG_ISR_DEPTH must be greater than zero")

    # Integer arithmetic only: going through floats can truncate one short,
    # e.g. int(100 * (29 / 100)) == 28.
    stack_size = values["CONFIG_ISR_STACK_SIZE"] * hw_stack_percentage // 100
    # Round up to the next multiple of the alignment.
    stack_size = (stack_size + alignment - 1) // alignment * alignment
    stack_size = max(stack_size, hw_stack_min_size)
    stack_size //= isr_depth

    with open(args.header_output, "w") as header_fp:
        header_fp.write(f"#define X86_CET_IRQ_SHADOW_SUBSTACK_SIZE {stack_size}\n")
220
221
def main():
    """Parse the command line, then run every action whose option was given.

    The kernel image is patched when ``--kernel`` is supplied and the header
    is generated when ``--header-output`` is supplied; both may run in one
    invocation, in that order.
    """
    parse_args()

    actions = (
        (args.kernel, patch_elf),
        (args.header_output, generate_header),
    )
    for option_value, action in actions:
        if option_value is not None:
            action()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
234