1#!/usr/bin/env python3
2#
3# Copyright (c) 2017 Intel Corporation
4#
5# SPDX-License-Identifier: Apache-2.0
6
7"""
8Script to generate system call invocation macros
9
10This script parses the system call metadata JSON file emitted by
11parse_syscalls.py to create several files:
12
13- A file containing weak aliases of any potentially unimplemented system calls,
14  as well as the system call dispatch table, which maps system call type IDs
15  to their handler functions.
16
17- A header file defining the system call type IDs, as well as function
18  prototypes for all system call handler functions.
19
20- A directory containing header files. Each header corresponds to a header
21  that was identified as containing system call declarations. These
22  generated headers contain the inline invocation functions for each system
23  call in that header.
24"""
25
26import sys
27import re
28import argparse
29import os
30import json
31
# Some kernel headers cannot include automated tracing without causing unintended recursion or
# other serious issues.
# These headers typically already have very specific tracing hooks for all relevant things
# written by hand so are excluded.
# Entries are compared verbatim against the header file name ("fn") of each
# syscall: a match suppresses both the tracing wrapper macros and the
# tracing_syscall.h include in the generated header.
notracing = ["kernel.h", "zephyr/kernel.h", "errno_private.h",
             "zephyr/errno_private.h"]
38
# C type names that need_split() treats as needing two uintptr_t slots when
# registers are not 64-bit.  main() appends any --split-type values to it.
types64 = ["int64_t", "uint64_t"]
40
# The kernel linkage is complicated.  These functions from
# userspace_handlers.c are present in the kernel .a library after
# userspace.c, which contains the weak fallbacks defined here.  So the
# linker finds the weak one first and stops searching, and thus won't
# see the real implementation which should override.  Yet changing the
# order runs afoul of a comment in CMakeLists.txt that the order is
# critical.  These are core syscalls that won't ever be unconfigured,
# just disable the fallback mechanism as a simple workaround.
#
# In main() these handlers are skipped when the weak aliases are generated
# and receive a plain "extern" declaration in the dispatch file instead.
noweak = ["z_mrsh_k_object_release",
          "z_mrsh_k_object_access_grant",
          "z_mrsh_k_object_alloc"]
52
# Layout of the generated dispatch-table C file (written by main() to the
# --syscall-dispatch path).  Filled with the "%" operator: the first slot
# receives the weak-alias / extern handler declarations, the second the
# "[K_SYSCALL_x] = handler" table entries joined with ",\n\t".
table_template = """/* auto-generated by gen_syscalls.py, don't edit */

#include <zephyr/llext/symbol.h>

/* Weak handler functions that get replaced by the real ones unless a system
 * call is not implemented due to kernel configuration.
 */
%s

const _k_syscall_handler_t _k_syscall_table[K_SYSCALL_LIMIT] = {
\t%s
};
"""
66
# Layout of the generated syscall ID list header (written by main() to the
# --syscall-list path).  The single "%s" slot receives the block of
# "#define K_SYSCALL_x <n>" lines.
list_template = """/* auto-generated by gen_syscalls.py, don't edit */

#ifndef ZEPHYR_SYSCALL_LIST_H
#define ZEPHYR_SYSCALL_LIST_H

%s

#ifndef _ASMLANGUAGE

#include <stdarg.h>
#include <stdint.h>

#endif /* _ASMLANGUAGE */

#endif /* ZEPHYR_SYSCALL_LIST_H */
"""
83
# Layout of each generated per-header invocation file.  Filled with
# str.format() in main(), so literal C braces are doubled ("{{" / "}}").
# Fields: include_guard (the #ifndef/#define pair), tracing_include (an
# #include line or empty) and invocations (the wrapper functions for every
# syscall declared in that header).
syscall_template = """/* auto-generated by gen_syscalls.py, don't edit */

{include_guard}

{tracing_include}

#ifndef _ASMLANGUAGE

#include <stdarg.h>

#include <zephyr/syscall_list.h>
#include <zephyr/syscall.h>

#include <zephyr/linker/sections.h>


#ifdef __cplusplus
extern "C" {{
#endif

{invocations}

#ifdef __cplusplus
}}
#endif

#endif
#endif /* include guard */
"""
113
# "%"-style template declaring a z_hdlr_<name> handler prototype; the single
# slot is the name suffix.
# NOTE(review): nothing in this file references handler_template - confirm
# whether it is still needed.
handler_template = """
extern uintptr_t z_hdlr_%s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
                uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""
118
# "%"-style template declaring one handler as a weak alias of
# handler_no_syscall; the single slot is the full handler name.  main()
# instantiates it for every emitted handler that is not listed in noweak.
weak_template = """
__weak ALIAS_OF(handler_no_syscall)
uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
         uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""
124
# Defines a macro wrapper which supersedes the syscall when used
# and provides tracing enter/exit hooks while allowing per compilation unit
# enable/disable of syscall tracing. Used for functions that return a value.
# Note that the last argument to the exit macro is the return value.
#
# Filled with str.format(), hence the doubled braces.  This is a non-raw
# Python string, so every backslash immediately followed by a newline is
# consumed by Python: the generated #define ends up on a single line.
syscall_tracer_with_return_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) ({{ \
	{func_type} syscall__retval; \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	syscall__retval = {func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}, syscall__retval); \
	syscall__retval; \
}})
#endif
#endif
"""
143
# Defines a macro wrapper which supersedes the syscall when used
# and provides tracing enter/exit hooks while allowing per compilation unit
# enable/disable of syscall tracing. Used for non-returning (void) functions.
#
# Filled with str.format(), hence the doubled braces.  This is a non-raw
# Python string, so every backslash immediately followed by a newline is
# consumed by Python: the generated #define ends up on a single line.
syscall_tracer_void_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) do {{ \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	{func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}); \
}} while(false)
#endif
#endif
"""
159
160
# Layout of the optional llext weak-definitions C file (written by main()
# when --syscall-weakdefs-llext is given).  The single "%s" slot receives one
# weak "extern ... ALIAS_OF(no_syscall_impl)" line per exported z_impl_ name.
llext_weakdefs_template = """/* auto-generated by gen_syscalls.py, don't edit */

#include <zephyr/toolchain.h>
#include <zephyr/llext/symbol.h>

/*
 * This symbol is placed at address 0 by llext-sections.ld. Its value and
 * type is not important, we are only interested in its location.
 */
static void * const no_syscall_impl Z_GENERIC_SECTION(llext_no_syscall_impl);

/*
 * Weak references to all syscall implementations. Those not found by the
 * linker outside this file will be exported as NULL and simply fail when
 * an extension requiring them is loaded.
 */
%s
"""
179
180
# Layout of the optional llext exports C file (written by main() when
# --syscall-exports-llext is given).  Two "%s" slots: first the "extern"
# declarations, then the EXPORT_SYMBOL() lines, both covering the same
# sorted list of z_impl_ names.
llext_exports_template = """/* auto-generated by gen_syscalls.py, don't edit */

/*
 * Export the implementation functions of all emitted syscalls.
 * Only the symbol names are relevant in this file, they will be
 * resolved to the actual implementation functions by the linker.
 */

/* Symbol declarations */
%s

/* Exported symbols */
%s
"""
195
# Splits a declaration fragment in two: group 2 is the final run of
# identifier characters ([A-Za-z0-9_]) anchored at the end of the string,
# group 1 is whatever (possibly empty) text precedes it.
typename_regex = re.compile(r'(.*?)([A-Za-z0-9_]+)$')
197
198
class SyscallParseException(Exception):
    """Raised when a system call declaration cannot be split into type and name."""
201
202
def typename_split(item):
    """Split a C declaration fragment into its type text and identifier.

    Surrounding whitespace is dropped and embedded newlines become spaces
    before the fragment is examined.

    Returns:
        A ``(type, name)`` tuple: ``name`` is the trailing identifier and
        ``type`` is the stripped text in front of it (possibly empty).

    Raises:
        SyscallParseException: if the fragment contains ``[`` (array),
            contains ``(`` (function pointer), or does not end in an
            identifier.
    """
    flattened = item.strip().replace("\n", " ")

    if "[" in flattened:
        message = ("Please pass arrays to syscalls as pointers, unable to process '%s'"
                   % flattened)
        raise SyscallParseException(message)

    if "(" in flattened:
        raise SyscallParseException("Please use typedefs for function pointers")

    found = typename_regex.match(flattened)
    if not found:
        raise SyscallParseException("Malformed system call invocation")

    type_text, identifier = found.groups()
    return (type_text.strip(), identifier)
220
def need_split(argtype):
    """Tell whether *argtype* has to travel as two uintptr_t values.

    This is only the case when --long-registers was not given and the type
    is listed in ``types64``.
    """
    if args.long_registers:
        return False
    return argtype in types64
223
def union_decl(type, split):
    """Return the C union type text used to marshal one value of *type*.

    The union overlays a ``val`` member of the real type with either a
    single ``uintptr_t x`` or, when *split* is true, a ``split`` struct
    holding two ``uintptr_t`` halves.

    Note: "lo" and "hi" are named in little endian conventions, but it
    doesn't matter as long as they are consistently generated.
    """
    if split:
        raw_member = "struct { uintptr_t lo, hi; } split"
    else:
        raw_member = "uintptr_t x"
    return f"union {{ {raw_member}; {type} val; }}"
230
def wrapper_defs(func_name, func_type, args, fn, userspace_only):
    """Build the C source of the inline invocation wrapper for one syscall.

    Args:
        func_name: name of the system call.
        func_type: its C return type as text.
        args: list of ``(type, name)`` pairs, one per parameter.
        fn: header file name the syscall was declared in; headers listed in
            ``notracing`` get no tracing macro.
        userspace_only: when true, the ``z_impl_`` declaration, the
            CONFIG_USERSPACE / z_syscall_trap() guards and the direct-call
            fallback are all left out.

    Returns:
        The generated C code as a single string.
    """
    syscall_id = "K_SYSCALL_" + func_name.upper()
    returns_64 = need_split(func_type)
    with_kernel_path = not userspace_only
    params_decl = ", ".join(" ".join(rec) for rec in args) or "void"

    out = []  # C source fragments, concatenated at the very end

    if with_kernel_path:
        out.append("extern %s z_impl_%s(%s);\n\n" % (func_type, func_name, params_decl))

    out.append("__pinned_func\n")
    out.append("static inline %s %s(%s)\n" % (func_type, func_name, params_decl))
    out.append("{\n")
    if with_kernel_path:
        out.append("#ifdef CONFIG_USERSPACE\n")
    if returns_64:
        out.append("\tuint64_t ret64;\n")
    if with_kernel_path:
        out.append("\tif (z_syscall_trap()) {\n")

    marshalled = []  # rvalue expressions handed to arch_syscall_invokeN()
    va_copies = []   # va_list copies that need a matching va_end()
    for index, (argtype, argname) in enumerate(args):
        is_split = need_split(argtype)
        decl = "\t\t%s parm%d" % (union_decl(argtype, is_split), index)
        if argtype == "va_list":
            # va_list objects cannot be used as initializers; copy instead.
            copy_name = "parm%d.val" % index
            decl += ";\n\t\tva_copy(%s, %s);\n" % (copy_name, argname)
            va_copies.append(copy_name)
        else:
            decl += " = { .val = %s };\n" % argname
        out.append(decl)

        if is_split:
            marshalled.extend(["parm%d.split.lo" % index,
                               "parm%d.split.hi" % index])
        else:
            marshalled.append("parm%d.x" % index)

    if returns_64:
        # A split return value comes back through a pointer argument.
        marshalled.append("(uintptr_t)&ret64")

    if len(marshalled) > 6:
        # Everything from the sixth value on is packed into an array whose
        # address takes the sixth slot.
        overflow = marshalled[5:]
        out.append("\t\tuintptr_t more[] = {\n")
        out.append("\t\t\t" + ",\n\t\t\t".join(overflow) + "\n")
        out.append("\t\t};\n")
        marshalled = marshalled[:5] + ["(uintptr_t) &more"]

    call = "arch_syscall_invoke%d(%s)" % (len(marshalled),
                                         ", ".join(marshalled + [syscall_id]))

    if returns_64:
        call_stmt = "\t\t(void) %s;\n" % call
        return_stmt = "\t\treturn (%s) ret64;\n" % func_type
    elif func_type == "void":
        call_stmt = "\t\t(void) %s;\n" % call
        return_stmt = "\t\treturn;\n"
    elif va_copies:
        # The result is parked in a local so va_end() can run before return.
        call_stmt = "\t\t%s invoke__retval = %s;\n" % (func_type, call)
        return_stmt = "\t\treturn invoke__retval;\n"
    else:
        call_stmt = "\t\treturn (%s) %s;\n" % (func_type, call)
        return_stmt = ""

    out.append(call_stmt)
    out.extend("\t\tva_end(%s);\n" % copy_name for copy_name in va_copies)
    out.append(return_stmt)

    if with_kernel_path:
        out.append("\t}\n")
        out.append("#endif\n")

        # Otherwise fall through to direct invocation of the impl func.
        # Note the compiler barrier: that is required to prevent code from
        # the impl call from being hoisted above the check for user
        # context.
        impl_args = ", ".join(argname for _, argname in args)
        return_kw = "" if func_type == "void" else "return "
        out.append("\tcompiler_barrier();\n")
        out.append("\t%sz_impl_%s(%s);\n" % (return_kw, func_name, impl_args))

    out.append("}\n")

    if fn not in notracing:
        argnames = ", ".join(argname for _, argname in args)
        trace_argnames = (", " + argnames) if args else ""
        trace_diagnostic = ""
        if os.getenv('TRACE_DIAGNOSTICS'):
            trace_diagnostic = f"#warning Tracing {func_name}"

        if func_type == "void":
            tracer = syscall_tracer_void_template
        else:
            tracer = syscall_tracer_with_return_template
        out.append(tracer.format(func_type=func_type, func_name=func_name,
                                 argnames=argnames, trace_argnames=trace_argnames,
                                 syscall_id=syscall_id,
                                 trace_diagnostic=trace_diagnostic))

    return "".join(out)
333
def mrsh_rval(mrsh_num, total):
    """Return the C expression for marshalled parameter *mrsh_num*.

    *mrsh_num* is zero-indexed and *total* is the overall number of
    marshalled values.  With more than six values, those from index 5 on
    are read from the final "more" array; otherwise the parameter is the
    plain ``argN`` handler argument.
    """
    lives_in_more = total > 6 and mrsh_num >= 5
    if lives_in_more:
        return "(((uintptr_t *)more)[%d])" % (mrsh_num - 5)
    return "arg%d" % mrsh_num
341
def marshall_defs(func_name, func_type, args):
    """Build the C source of the ``z_mrsh_<name>`` unmarshalling handler.

    The handler rebuilds every typed parameter from the raw ``uintptr_t``
    values it receives, calls ``z_vrfy_<name>`` and converts the result
    back into a ``uintptr_t``.

    Args:
        func_name: name of the system call.
        func_type: its C return type as text.
        args: list of ``(type, name)`` pairs, one per parameter.

    Returns:
        A ``(code, handler_name)`` tuple: the generated C code and the name
        of the handler function it defines.
    """
    mrsh_name = "z_mrsh_" + func_name
    ret64 = need_split(func_type)

    # One (argtype, is_split) record per parameter of the verify function.
    vrfy_parms = [(argtype, need_split(argtype)) for (argtype, _) in args]

    # Number of marshalled uintptr_t values: two per split parameter, one
    # per ordinary parameter, plus a pointer slot for a split return value.
    nmrsh = sum(2 if split else 1 for (_, split) in vrfy_parms)
    if ret64:
        nmrsh += 1

    # An empty C parameter list must be spelled "(void)" to be a prototype;
    # this matches the declarations produced by wrapper_defs().
    decl_arglist = ", ".join(" ".join(argrec) for argrec in args) or "void"

    out = ["extern %s z_vrfy_%s(%s);\n" % (func_type, func_name, decl_arglist)]

    out.append("uintptr_t %s(uintptr_t arg0, uintptr_t arg1, uintptr_t arg2,\n" % mrsh_name)
    if nmrsh <= 6:
        out.append("\t\t" + "uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, void *ssf)\n")
    else:
        # The sixth slot carries a pointer to the overflow array instead.
        out.append("\t\t" + "uintptr_t arg3, uintptr_t arg4, void *more, void *ssf)\n")
    out.append("{\n")
    out.append("\t" + "arch_current_thread()->syscall_frame = ssf;\n")

    for unused_arg in range(nmrsh, 6):
        out.append("\t(void) arg%d;\t/* unused */\n" % unused_arg)

    if nmrsh > 6:
        # Values 5 .. nmrsh-1 are stored in "more": check it is readable.
        out.append("\tK_OOPS(K_SYSCALL_MEMORY_READ(more, "
                   + str(nmrsh - 5) + " * sizeof(uintptr_t)));\n")

    argnum = 0  # index of the next marshalled value to consume
    for i, (argtype, split) in enumerate(vrfy_parms):
        out.append("\t%s parm%d;\n" % (union_decl(argtype, split), i))
        if split:
            out.append("\t" + "parm%d.split.lo = %s;\n" % (i, mrsh_rval(argnum, nmrsh)))
            out.append("\t" + "parm%d.split.hi = %s;\n" % (i, mrsh_rval(argnum + 1, nmrsh)))
            argnum += 2
        else:
            out.append("\t" + "parm%d.x = %s;\n" % (i, mrsh_rval(argnum, nmrsh)))
            argnum += 1

    # Finally, invoke the verify function
    out_args = ", ".join("parm%d.val" % i for i in range(len(args)))
    vrfy_call = "z_vrfy_%s(%s)" % (func_name, out_args)

    if func_type == "void":
        out.append("\t" + "%s;\n" % vrfy_call)
        out.append("\t" + "arch_current_thread()->syscall_frame = NULL;\n")
        out.append("\t" + "return 0;\n")
    else:
        out.append("\t" + "%s ret = %s;\n" % (func_type, vrfy_call))

        if ret64:
            # The last marshalled value is the address to store the result at.
            ptr = "((uint64_t *)%s)" % mrsh_rval(nmrsh - 1, nmrsh)
            out.append("\t" + "K_OOPS(K_SYSCALL_MEMORY_WRITE(%s, 8));\n" % ptr)
            out.append("\t" + "*%s = ret;\n" % ptr)
            out.append("\t" + "arch_current_thread()->syscall_frame = NULL;\n")
            out.append("\t" + "return 0;\n")
        else:
            out.append("\t" + "arch_current_thread()->syscall_frame = NULL;\n")
            out.append("\t" + "return (uintptr_t) ret;\n")

    out.append("}\n")

    return "".join(out), mrsh_name
409
def analyze_fn(match_group, fn, userspace_only):
    """Turn one parsed syscall declaration into its generated pieces.

    Args:
        match_group: pair of the text before the parameter list (return
            type and name) and the raw parameter list text.
        fn: header file name the declaration came from.
        userspace_only: forwarded to wrapper_defs().

    Returns:
        A tuple ``(handler, invocation, marshaller, sys_id, table_entry)``:
        handler function name, wrapper C code, marshalling C code, the
        ``K_SYSCALL_`` identifier and the dispatch table entry.

    Raises:
        SyscallParseException: re-raised from typename_split() after the
            offending declaration has been reported on stderr.
    """
    decl, raw_params = match_group

    try:
        if raw_params == "void":
            params = []
        else:
            params = [typename_split(piece) for piece in raw_params.split(",")]

        ret_type, name = typename_split(decl)
    except SyscallParseException:
        sys.stderr.write("In declaration of %s\n" % decl)
        raise

    syscall_id = "K_SYSCALL_" + name.upper()

    mrsh_code, mrsh_handler = marshall_defs(name, ret_type, params)
    wrapper_code = wrapper_defs(name, ret_type, params, fn, userspace_only)

    # Entry in _k_syscall_table
    dispatch_entry = "[%s] = %s" % (syscall_id, mrsh_handler)

    return (mrsh_handler, wrapper_code, mrsh_code, syscall_id, dispatch_entry)
434
def parse_args():
    """Parse the command line into the module-level ``args`` namespace.

    The result is stored in the global ``args`` rather than returned, since
    other functions in this file read it from there.  Option abbreviations
    are disabled, and the module docstring is used as the help description.
    """
    global args
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter, allow_abbrev=False)

    parser.add_argument("-i", "--json-file", required=True,
                        help="Read syscall information from json file")
    parser.add_argument("-d", "--syscall-dispatch", required=True,
                        help="output C system call dispatch table file")
    parser.add_argument("-l", "--syscall-list", required=True,
                        help="output C system call list header")
    parser.add_argument("-o", "--base-output", required=True,
                        help="Base output directory for syscall macro headers")
    parser.add_argument("-s", "--split-type", action="append",
                        help="A long type that must be split/marshalled on 32-bit systems")
    parser.add_argument("-x", "--long-registers", action="store_true",
                        help="Indicates we are on system with 64-bit registers")
    parser.add_argument("--gen-mrsh-files", action="store_true",
                        help="Generate marshalling files (*_mrsh.c)")
    parser.add_argument("-e", "--syscall-exports-llext",
                        help="output C system call export for extensions")
    parser.add_argument("-w", "--syscall-weakdefs-llext",
                        help="output C system call weak definitions")
    parser.add_argument("-u", "--userspace-only", action="store_true",
                        help="Only generate the userspace path of wrappers")
    args = parser.parse_args()
462
463
def main():
    """Generate every syscall output file requested on the command line.

    Reads the syscall metadata JSON file, analyzes each entry and writes:

    - the dispatch table C file (weak handler aliases plus the table),
    - optionally the llext weak-definition and export C files,
    - the syscall ID list header,
    - one invocation header per source header under the base output
      directory,
    - optionally one ``<syscall>_mrsh.c`` file per emitted syscall.
    """
    parse_args()

    if args.split_type is not None:
        types64.extend(args.split_type)

    with open(args.json_file, 'r') as fd:
        syscalls = json.load(fd)

    invocations = {}    # header file name -> list of wrapper code strings
    mrsh_defs = {}      # syscall name -> marshalling C code
    mrsh_includes = {}  # syscall name -> #include line for its header
    ids_emit = []
    ids_not_emit = []
    table_entries = []
    handlers = []
    emitted = set()     # handler names of emitted syscalls, for O(1) lookup
    exported = []

    for match_group, fn, to_emit in syscalls:
        handler, inv, mrsh, sys_id, entry = analyze_fn(match_group, fn, args.userspace_only)

        invocations.setdefault(fn, []).append(inv)
        handlers.append(handler)

        if to_emit:
            ids_emit.append(sys_id)
            table_entries.append(entry)
            emitted.add(handler)
            # Swap only the leading "z_mrsh_" prefix; the syscall name
            # itself must be left untouched.
            exported.append(handler.replace("z_mrsh_", "z_impl_", 1))
        else:
            ids_not_emit.append(sys_id)

        if mrsh and to_emit:
            syscall = typename_split(match_group[0])[1]
            mrsh_defs[syscall] = mrsh
            mrsh_includes[syscall] = "#include <zephyr/syscalls/%s>" % fn

    # Dispatch table file
    table_entries.append("[K_SYSCALL_BAD] = handler_bad_syscall")

    weak_defines = "".join(weak_template % name
                           for name in handlers
                           if name not in noweak and name in emitted)

    # The "noweak" ones just get a regular declaration
    weak_defines += "\n".join(
        "extern uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);"
        % s for s in noweak)

    with open(args.syscall_dispatch, "w") as fp:
        fp.write(table_template % (weak_defines,
                                   ",\n\t".join(table_entries)))

    exported.sort()

    if args.syscall_weakdefs_llext:
        with open(args.syscall_weakdefs_llext, "w") as fp:
            # Provide weak definitions for all emitted syscalls
            weak_refs = "\n".join("extern __weak ALIAS_OF(no_syscall_impl) void * const %s;"
                                  % e for e in exported)
            fp.write(llext_weakdefs_template % weak_refs)

    if args.syscall_exports_llext:
        with open(args.syscall_exports_llext, "w") as fp:
            # Export symbols for emitted syscalls
            extern_refs = "\n".join("extern void * const %s;"
                                    % e for e in exported)
            exported_symbols = "\n".join("EXPORT_SYMBOL(%s);"
                                         % e for e in exported)
            fp.write(llext_exports_template % (extern_refs, exported_symbols))

    # Listing header, written to the --syscall-list file.  Emitted IDs are
    # numbered from 0 in sorted order, followed by the two sentinel IDs.
    ids_emit.sort()
    ids_emit.extend(["K_SYSCALL_BAD", "K_SYSCALL_LIMIT"])

    ids_as_defines = "".join("#define {} {}\n".format(item, i)
                             for i, item in enumerate(ids_emit))

    if ids_not_emit:
        # There are syscalls that are not used in the image but
        # their IDs are used in the generated stubs. So need to
        # make them usable but outside the syscall ID range.
        ids_as_defines += "\n\n/* Following syscalls are not used in image */\n"
        ids_not_emit.sort()
        ids_as_defines += "".join(
            "#define {} {}\n".format(item, i)
            for i, item in enumerate(ids_not_emit, start=len(ids_emit)))

    with open(args.syscall_list, "w") as fp:
        fp.write(list_template % ids_as_defines)

    # One invocation header per header file that declared syscalls
    os.makedirs(args.base_output, exist_ok=True)
    for fn, invo_list in invocations.items():
        out_fn = os.path.join(args.base_output, fn)

        ig = re.sub("[^a-zA-Z0-9]", "_", "Z_INCLUDE_SYSCALLS_" + fn).upper()
        include_guard = "#ifndef %s\n#define %s\n" % (ig, ig)
        tracing_include = ""
        if fn not in notracing:
            tracing_include = "#include <zephyr/tracing/tracing_syscall.h>"
        header = syscall_template.format(include_guard=include_guard,
                                         tracing_include=tracing_include,
                                         invocations="\n\n".join(invo_list))

        with open(out_fn, "w") as fp:
            fp.write(header)

    # Likewise emit _mrsh.c files for syscall inclusion
    if args.gen_mrsh_files:
        for fn in mrsh_defs:
            mrsh_fn = os.path.join(args.base_output, fn + "_mrsh.c")

            with open(mrsh_fn, "w") as fp:
                fp.write("/* auto-generated by gen_syscalls.py, don't edit */\n\n")
                fp.write(mrsh_includes[fn] + "\n")
                fp.write("\n")
                fp.write(mrsh_defs[fn] + "\n")
583
584if __name__ == "__main__":
585    main()
586