1#!/usr/bin/env python3
2#
3# Copyright (c) 2017 Intel Corporation
4#
5# SPDX-License-Identifier: Apache-2.0
6
7"""
8Script to generate system call invocation macros
9
10This script parses the system call metadata JSON file emitted by
11parse_syscalls.py to create several files:
12
13- A file containing weak aliases of any potentially unimplemented system calls,
14  as well as the system call dispatch table, which maps system call type IDs
15  to their handler functions.
16
17- A header file defining the system call type IDs, as well as function
18  prototypes for all system call handler functions.
19
20- A directory containing header files. Each header corresponds to a header
21  that was identified as containing system call declarations. These
22  generated headers contain the inline invocation functions for each system
23  call in that header.
24"""
25
26import sys
27import re
28import argparse
29import os
30import json
31
# Some kernel headers cannot include automated tracing without causing unintended recursion or
# other serious issues.
# These headers typically already have very specific tracing hooks for all relevant things
# written by hand so are excluded.
# The header name of each syscall is compared against this list both when the
# tracing macro wrappers are appended to an invocation and when the tracing
# include line is added to a generated header.
notracing = ["kernel.h", "zephyr/kernel.h", "errno_private.h",
             "zephyr/errno_private.h"]

# C type names that need_split() marshals as two register-sized halves unless
# --long-registers is given. main() appends every --split-type value to this
# list, so it must stay a mutable list.
types64 = ["int64_t", "uint64_t"]

# The kernel linkage is complicated.  These functions from
# userspace_handlers.c are present in the kernel .a library after
# userspace.c, which contains the weak fallbacks defined here.  So the
# linker finds the weak one first and stops searching, and thus won't
# see the real implementation which should override.  Yet changing the
# order runs afoul of a comment in CMakeLists.txt that the order is
# critical.  These are core syscalls that won't ever be unconfigured,
# just disable the fallback mechanism as a simple workaround.
# Handlers named here are skipped when weak aliases are generated and get a
# plain "extern" declaration in the dispatch table file instead.
noweak = ["z_mrsh_k_object_release",
          "z_mrsh_k_object_access_grant",
          "z_mrsh_k_object_alloc"]
52
# Skeleton of the dispatch table C file ("%"-formatted by main()).
# First %s: the weak alias declarations followed by the plain extern
# declarations for the "noweak" handlers.
# Second %s: the "[K_SYSCALL_xxx] = handler" table entries, already joined
# with ",\n\t".
table_template = """/* auto-generated by gen_syscalls.py, don't edit */

#include <zephyr/llext/symbol.h>

/* Weak handler functions that get replaced by the real ones unless a system
 * call is not implemented due to kernel configuration.
 */
%s

const _k_syscall_handler_t _k_syscall_table[K_SYSCALL_LIMIT] = {
\t%s
};
"""

# Skeleton of the syscall list header ("%"-formatted by main()).
# The single %s receives the "#define K_SYSCALL_xxx <n>" lines.
list_template = """/* auto-generated by gen_syscalls.py, don't edit */

#ifndef ZEPHYR_SYSCALL_LIST_H
#define ZEPHYR_SYSCALL_LIST_H

%s

#ifndef _ASMLANGUAGE

#include <stdarg.h>
#include <stdint.h>

#endif /* _ASMLANGUAGE */

#endif /* ZEPHYR_SYSCALL_LIST_H */
"""

# Skeleton of each generated per-header invocation file. This one is a
# str.format() template, so literal C braces are written as "{{" and "}}".
# Fields: include_guard (the #ifndef/#define pair), tracing_include (empty
# for headers listed in notracing) and invocations (the wrapper functions).
syscall_template = """/* auto-generated by gen_syscalls.py, don't edit */

{include_guard}

{tracing_include}

#ifndef _ASMLANGUAGE

#include <stdarg.h>

#include <zephyr/syscall_list.h>
#include <zephyr/syscall.h>

#include <zephyr/linker/sections.h>


#ifdef __cplusplus
extern "C" {{
#endif

{invocations}

#ifdef __cplusplus
}}
#endif

#endif
#endif /* include guard */
"""

# Prototype for a z_hdlr_<name> handler; %s is the name suffix.
# NOTE(review): nothing in the visible part of this script references this
# template - confirm whether it is still needed.
handler_template = """
extern uintptr_t z_hdlr_%s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
                uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""

# Weak declaration aliasing a handler to handler_no_syscall; %s is the full
# handler name. main() emits one per emitted handler that is not in noweak.
weak_template = """
__weak ALIAS_OF(handler_no_syscall)
uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
         uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""
124
# defines a macro wrapper which supersedes the syscall when used
# and provides tracing enter/exit hooks while allowing per compilation unit
# enable/disable of syscall tracing. Used for returning functions
# Note that the last argument to the exit macro is the return value.
#
# This is a str.format() template: "{{" / "}}" produce literal C braces.
# The string literal is not raw, so Python drops every backslash-newline
# pair below as a line continuation; the generated #define therefore ends up
# on a single line of C, with no continuation backslashes in the output.
# Fields: trace_diagnostic, func_name, argnames, func_type, syscall_id and
# trace_argnames (either empty or ", " followed by the argument names).
syscall_tracer_with_return_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) ({{ \
	{func_type} syscall__retval; \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	syscall__retval = {func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}, syscall__retval); \
	syscall__retval; \
}})
#endif
#endif
"""

# defines a macro wrapper which supersedes the syscall when used
# and provides tracing enter/exit hooks while allowing per compilation unit
# enable/disable of syscall tracing. Used for non-returning (void) functions
#
# Same str.format() fields and backslash-newline behaviour as the returning
# variant above in this file; func_type is passed by the caller but this
# template has no {func_type} field, so it is ignored here.
syscall_tracer_void_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) do {{ \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	{func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}); \
}} while(false)
#endif
#endif
"""
159
160
# Skeleton of the optional extension export file ("%"-formatted by main()).
# First %s: one weak "extern ... void * const <symbol>;" reference per
# exported symbol. Second %s: one EXPORT_SYMBOL(<symbol>); line per symbol.
exported_template = """
/* Export syscalls for extensions */
static void * const no_handler = NULL;

/* Weak references, if something is not found by the linker, it will be NULL
 * and simply fail during extension load
 */
%s

/* Exported symbols */
%s
"""

# Splits a C declaration fragment in two: group 2 is the trailing run of
# identifier characters (the name), group 1 is everything before it (the
# type, possibly with trailing whitespace). The first group is lazy so that
# the whole trailing identifier ends up in group 2.
typename_regex = re.compile(r'(.*?)([A-Za-z0-9_]+)$')
175
176
class SyscallParseException(Exception):
    """Raised when a system call declaration cannot be processed."""
179
180
def typename_split(item):
    """Split a C declaration fragment into its type and its identifier.

    Args:
        item: text such as "struct k_sem *sem" or "int k_sem_take".

    Returns:
        A (type, name) tuple; the type part has surrounding whitespace
        stripped.

    Raises:
        SyscallParseException: if the text contains "[" (array syntax) or
            "(" (function pointer syntax), or does not end in an identifier.
    """
    if "[" in item:
        raise SyscallParseException(
            "Please pass arrays to syscalls as pointers, unable to process '%s'" %
            item)
    if "(" in item:
        raise SyscallParseException(
            "Please use typedefs for function pointers")

    matched = typename_regex.match(item)
    if matched is None:
        raise SyscallParseException("Malformed system call invocation")

    # Group 1 is the type text, group 2 the trailing identifier.
    return (matched.group(1).strip(), matched.group(2))
197
def need_split(argtype):
    """Tell whether a C type has to be marshalled as two register-sized halves.

    That is the case only when --long-registers was not given and the type
    name is listed in types64.
    """
    if args.long_registers:
        return False
    return argtype in types64
200
201# Note: "lo" and "hi" are named in little endian conventions,
202# but it doesn't matter as long as they are consistently
203# generated.
def union_decl(type, split):
    """Return the text of a C union overlaying a value with raw word(s).

    Args:
        type: C type name of the value, exposed as member "val".
        split: when true the raw view is a struct "split" with members
            "lo" and "hi"; otherwise it is a single uintptr_t "x".
    """
    if split:
        raw_view = "struct { uintptr_t lo, hi; } split"
    else:
        raw_view = "uintptr_t x"
    return "union { %s; %s val; }" % (raw_view, type)
207
def wrapper_defs(func_name, func_type, args, fn, userspace_only):
    """Generate the C inline invocation wrapper for one system call.

    Args:
        func_name: name of the system call.
        func_type: C return type of the system call, as text.
        args: list of (type, name) pairs describing the parameters.
        fn: name of the header declaring the system call; headers listed in
            notracing do not get the tracing macro wrapper appended.
        userspace_only: when true, the z_impl_ declaration, the
            CONFIG_USERSPACE / z_syscall_trap() guard and the direct call of
            the implementation function are all left out.

    Returns:
        The generated C source text.
    """
    returns_64 = need_split(func_type)
    returns_void = func_type == "void"
    syscall_id = "K_SYSCALL_" + func_name.upper()
    decl_arglist = ", ".join([" ".join(argrec) for argrec in args]) or "void"

    pieces = []
    if not userspace_only:
        pieces.append("extern %s z_impl_%s(%s);\n" % (func_type, func_name, decl_arglist))
        pieces.append("\n")

    pieces.append("__pinned_func\n")
    pieces.append("static inline %s %s(%s)\n" % (func_type, func_name, decl_arglist))
    pieces.append("{\n")
    if not userspace_only:
        pieces.append("#ifdef CONFIG_USERSPACE\n")
    if returns_64:
        pieces.append("\tuint64_t ret64;\n")
    if not userspace_only:
        pieces.append("\tif (z_syscall_trap()) {\n")

    mrsh_args = []    # rvalue expressions for the marshalled invocation
    valist_args = []  # va_list copies that need a matching va_end()
    for argnum, (argtype, argname) in enumerate(args):
        split = need_split(argtype)
        parm = "parm%d" % argnum
        decl = "\t\t%s %s" % (union_decl(argtype, split), parm)
        if argtype == "va_list":
            # va_list objects are ... peculiar.
            decl += ";\n\t\tva_copy(%s.val, %s);\n" % (parm, argname)
            valist_args.append(parm + ".val")
        else:
            decl += " = { .val = %s };\n" % argname
        pieces.append(decl)

        if split:
            mrsh_args.extend([parm + ".split.lo", parm + ".split.hi"])
        else:
            mrsh_args.append(parm + ".x")

    if returns_64:
        # A split return value comes back through a pointer argument.
        mrsh_args.append("(uintptr_t)&ret64")

    if len(mrsh_args) > 6:
        # Everything from the sixth value on travels in an array whose
        # address takes the sixth argument slot.
        overflow = mrsh_args[5:]
        pieces.append("\t\tuintptr_t more[] = {\n")
        pieces.append("\t\t\t" + ",\n\t\t\t".join(overflow) + "\n")
        pieces.append("\t\t};\n")
        mrsh_args = mrsh_args[:5] + ["(uintptr_t) &more"]

    invoke = "arch_syscall_invoke%d(%s)" % (
        len(mrsh_args), ", ".join(mrsh_args + [syscall_id]))

    if returns_64:
        invoke_stmt = "\t\t(void) %s;\n" % invoke
        return_stmt = "\t\treturn (%s) ret64;\n" % func_type
    elif returns_void:
        invoke_stmt = "\t\t(void) %s;\n" % invoke
        return_stmt = "\t\treturn;\n"
    elif valist_args:
        # The result is held in a local so va_end() can run before returning.
        invoke_stmt = "\t\t%s invoke__retval = %s;\n" % (func_type, invoke)
        return_stmt = "\t\treturn invoke__retval;\n"
    else:
        invoke_stmt = "\t\treturn (%s) %s;\n" % (func_type, invoke)
        return_stmt = ""

    pieces.append(invoke_stmt)
    pieces.extend("\t\tva_end(%s);\n" % copied for copied in valist_args)
    pieces.append(return_stmt)

    if not userspace_only:
        pieces.append("\t}\n")
        pieces.append("#endif\n")

        # Otherwise fall through to direct invocation of the impl func.
        # Note the compiler barrier: that is required to prevent code from
        # the impl call from being hoisted above the check for user
        # context.
        impl_arglist = ", ".join([argrec[1] for argrec in args])
        impl_call = "z_impl_%s(%s)" % (func_name, impl_arglist)
        pieces.append("\tcompiler_barrier();\n")
        pieces.append("\t%s%s;\n" % ("" if returns_void else "return ", impl_call))

    pieces.append("}\n")

    if fn not in notracing:
        argnames = ", ".join(f"{name}" for _, name in args)
        trace_argnames = (", " + argnames) if args else ""
        trace_diagnostic = ""
        if os.getenv('TRACE_DIAGNOSTICS'):
            trace_diagnostic = f"#warning Tracing {func_name}"

        if returns_void:
            tracer = syscall_tracer_void_template
        else:
            tracer = syscall_tracer_with_return_template
        pieces.append(tracer.format(func_type=func_type, func_name=func_name,
                                    argnames=argnames, trace_argnames=trace_argnames,
                                    syscall_id=syscall_id, trace_diagnostic=trace_diagnostic))

    return "".join(pieces)
310
311# Returns an expression for the specified (zero-indexed!) marshalled
312# parameter to a syscall, with handling for a final "more" parameter.
def mrsh_rval(mrsh_num, total):
    """Return the C expression that reads marshalled parameter mrsh_num.

    Args:
        mrsh_num: zero-based index of the marshalled parameter.
        total: total number of marshalled parameters of the syscall.

    With more than six parameters in total, those from index 5 on are read
    from the "more" array; all others are the plain argN arguments.
    """
    lives_in_more = total > 6 and mrsh_num >= 5
    if lives_in_more:
        return "(((uintptr_t *)more)[%d])" % (mrsh_num - 5)
    return "arg%d" % mrsh_num
318
def marshall_defs(func_name, func_type, args):
    """Generate the C unmarshalling function (z_mrsh_<name>) for a syscall.

    The generated function rebuilds each parameter from the register-sized
    arguments it receives, calls z_vrfy_<name>() and hands back the result.

    Args:
        func_name: name of the system call.
        func_type: C return type of the system call, as text.
        args: list of (type, name) pairs describing the parameters.

    Returns:
        A (source, name) tuple: the generated C text and the name of the
        generated function.
    """
    mrsh_name = "z_mrsh_" + func_name

    # list of (argtype, bool_is_split)
    vrfy_parms = [(argtype, need_split(argtype)) for (argtype, _) in args]
    # number of marshalled uintptr_t parameters
    nmrsh = sum(2 if split else 1 for (_, split) in vrfy_parms)

    # Final argument for a 64 bit return value?
    ret64 = need_split(func_type)
    if ret64:
        nmrsh += 1

    # Fall back to "void" for a syscall without parameters, like the
    # invocation wrapper does, so the declaration is a real C prototype
    # rather than one with an empty parameter list.
    decl_arglist = ", ".join([" ".join(argrec) for argrec in args]) or "void"

    out = ["extern %s z_vrfy_%s(%s);\n" % (func_type, func_name, decl_arglist)]
    out.append("uintptr_t %s(uintptr_t arg0, uintptr_t arg1, uintptr_t arg2,\n" % mrsh_name)
    if nmrsh <= 6:
        out.append("\t\t" + "uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, void *ssf)\n")
    else:
        # With more than six values the sixth slot is a pointer to the rest.
        out.append("\t\t" + "uintptr_t arg3, uintptr_t arg4, void *more, void *ssf)\n")
    out.append("{\n")
    out.append("\t" + "_current->syscall_frame = ssf;\n")

    for unused_arg in range(nmrsh, 6):
        out.append("\t(void) arg%d;\t/* unused */\n" % unused_arg)

    if nmrsh > 6:
        # The array is checked for readability before anything is read.
        out.append("\tK_OOPS(K_SYSCALL_MEMORY_READ(more, "
                   + str(nmrsh - 5) + " * sizeof(uintptr_t)));\n")

    argnum = 0  # index of the next marshalled value to consume
    for i, (argtype, split) in enumerate(vrfy_parms):
        out.append("\t%s parm%d;\n" % (union_decl(argtype, split), i))
        if split:
            out.append("\t" + "parm%d.split.lo = %s;\n" % (i, mrsh_rval(argnum, nmrsh)))
            argnum += 1
            out.append("\t" + "parm%d.split.hi = %s;\n" % (i, mrsh_rval(argnum, nmrsh)))
        else:
            out.append("\t" + "parm%d.x = %s;\n" % (i, mrsh_rval(argnum, nmrsh)))
        argnum += 1

    # Finally, invoke the verify function
    out_args = ", ".join(["parm%d.val" % i for i in range(len(args))])
    vrfy_call = "z_vrfy_%s(%s)" % (func_name, out_args)

    if func_type == "void":
        out.append("\t" + "%s;\n" % vrfy_call)
        out.append("\t" + "_current->syscall_frame = NULL;\n")
        out.append("\t" + "return 0;\n")
    else:
        out.append("\t" + "%s ret = %s;\n" % (func_type, vrfy_call))

        if ret64:
            # The last marshalled value is where the result is written.
            ptr = "((uint64_t *)%s)" % mrsh_rval(nmrsh - 1, nmrsh)
            out.append("\t" + "K_OOPS(K_SYSCALL_MEMORY_WRITE(%s, 8));\n" % ptr)
            out.append("\t" + "*%s = ret;\n" % ptr)
            out.append("\t" + "_current->syscall_frame = NULL;\n")
            out.append("\t" + "return 0;\n")
        else:
            out.append("\t" + "_current->syscall_frame = NULL;\n")
            out.append("\t" + "return (uintptr_t) ret;\n")

    out.append("}\n")

    return "".join(out), mrsh_name
386
def analyze_fn(match_group, fn, userspace_only):
    """Produce every generated artifact for one system call declaration.

    Args:
        match_group: pair of (declaration text up to and including the
            function name, parameter list text).
        fn: name of the header declaring the system call.
        userspace_only: forwarded to wrapper_defs().

    Returns:
        A tuple (handler name, invocation wrapper source, marshaller source,
        syscall ID macro name, dispatch table entry).

    Raises:
        SyscallParseException: re-raised from typename_split() after the
            offending declaration has been written to stderr.
    """
    func, arg_text = match_group

    try:
        if arg_text == "void":
            parsed_args = []
        else:
            parsed_args = [typename_split(piece.strip())
                           for piece in arg_text.split(",")]

        func_type, func_name = typename_split(func)
    except SyscallParseException:
        sys.stderr.write("In declaration of %s\n" % func)
        raise

    sys_id = "K_SYSCALL_" + func_name.upper()

    marshaller, handler = marshall_defs(func_name, func_type, parsed_args)
    invocation = wrapper_defs(func_name, func_type, parsed_args, fn, userspace_only)

    # Entry in _k_syscall_table
    table_entry = "[%s] = %s" % (sys_id, handler)

    return (handler, invocation, marshaller, sys_id, table_entry)
411
def parse_args():
    """Parse the command line and store the result in the global ``args``.

    The parsed namespace is read by need_split() and main(). The module
    docstring is used as the description in the help output.
    """
    global args
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter, allow_abbrev=False)

    parser.add_argument("-i", "--json-file", required=True,
                        help="Read syscall information from json file")
    parser.add_argument("-d", "--syscall-dispatch", required=True,
                        help="output C system call dispatch table file")
    parser.add_argument("-l", "--syscall-list", required=True,
                        help="output C system call list header")
    parser.add_argument("-o", "--base-output", required=True,
                        help="Base output directory for syscall macro headers")
    # May be given several times; each value is added to types64 by main().
    parser.add_argument("-s", "--split-type", action="append",
                        help="A long type that must be split/marshalled on 32-bit systems")
    parser.add_argument("-x", "--long-registers", action="store_true",
                        help="Indicates we are on system with 64-bit registers")
    parser.add_argument("--gen-mrsh-files", action="store_true",
                        help="Generate marshalling files (*_mrsh.c)")
    parser.add_argument("-e", "--syscall-export-llext",
                        help="output C system call export for extensions")
    parser.add_argument("-u", "--userspace-only", action="store_true",
                        help="Only generate the userspace path of wrappers")
    args = parser.parse_args()
437
438
def main():
    """Read the syscall metadata and write every generated file.

    The JSON file is a list of (match_group, header name, to_emit) entries.
    From it this writes:

    - the dispatch table file (--syscall-dispatch),
    - optionally the extension export file (--syscall-export-llext),
    - the syscall list header with the ID macros (--syscall-list),
    - one invocation header per declaring header under --base-output,
    - with --gen-mrsh-files, one <syscall>_mrsh.c per emitted syscall under
      --base-output.
    """
    parse_args()

    # Extra types given with --split-type are treated by need_split() like
    # the built-in 64-bit types.
    if args.split_type is not None:
        types64.extend(args.split_type)

    with open(args.json_file, 'r') as fd:
        syscalls = json.load(fd)

    invocations = {}
    mrsh_defs = {}
    mrsh_includes = {}
    ids_emit = []
    ids_not_emit = []
    table_entries = []
    handlers = []
    # Only used for membership tests, so a set keeps those lookups O(1).
    emitted_handlers = set()
    exported = []

    for match_group, fn, to_emit in syscalls:
        handler, inv, mrsh, sys_id, entry = analyze_fn(match_group, fn, args.userspace_only)

        invocations.setdefault(fn, []).append(inv)
        handlers.append(handler)

        if to_emit:
            ids_emit.append(sys_id)
            table_entries.append(entry)
            emitted_handlers.add(handler)
            exported.append(handler.replace("z_mrsh_", "z_impl_"))
        else:
            ids_not_emit.append(sys_id)

        if mrsh and to_emit:
            syscall = typename_split(match_group[0])[1]
            mrsh_defs[syscall] = mrsh
            mrsh_includes[syscall] = "#include <zephyr/syscalls/%s>" % fn

    with open(args.syscall_dispatch, "w") as fp:
        table_entries.append("[K_SYSCALL_BAD] = handler_bad_syscall")

        weak_defines = "".join([weak_template % name
                                for name in handlers
                                if name not in noweak and name in emitted_handlers])

        # The "noweak" ones just get a regular declaration
        weak_defines += "\n".join(["extern uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);"
                                   % s for s in noweak])

        fp.write(table_template % (weak_defines,
                                   ",\n\t".join(table_entries)))

    if args.syscall_export_llext:
        with open(args.syscall_export_llext, "w") as fp:
            # Export symbols for emitted syscalls
            weak_refs = "\n".join("extern __weak ALIAS_OF(no_handler) void * const %s;"
                                  % e for e in exported)
            exported_symbols = "\n".join("EXPORT_SYMBOL(%s);"
                                         % e for e in exported)
            fp.write(exported_template % (weak_refs, exported_symbols))

    # Listing header, written to the --syscall-list file. Emitted IDs are
    # numbered from 0 in sorted order, followed by the two sentinel IDs.
    ids_emit.sort()
    ids_emit.extend(["K_SYSCALL_BAD", "K_SYSCALL_LIMIT"])

    ids_as_defines = ""
    for i, item in enumerate(ids_emit):
        ids_as_defines += "#define {} {}\n".format(item, i)

    if ids_not_emit:
        # There are syscalls that are not used in the image but
        # their IDs are used in the generated stubs. So need to
        # make them usable but outside the syscall ID range.
        ids_as_defines += "\n\n/* Following syscalls are not used in image */\n"
        ids_not_emit.sort()
        num_emitted_ids = len(ids_emit)
        for i, item in enumerate(ids_not_emit):
            ids_as_defines += "#define {} {}\n".format(item, i + num_emitted_ids)

    with open(args.syscall_list, "w") as fp:
        fp.write(list_template % ids_as_defines)

    os.makedirs(args.base_output, exist_ok=True)
    for fn, invo_list in invocations.items():
        out_fn = os.path.join(args.base_output, fn)

        # Include guard macro derived from the header name: every character
        # that is not alphanumeric becomes "_" and the result is upper-cased.
        ig = re.sub("[^a-zA-Z0-9]", "_", "Z_INCLUDE_SYSCALLS_" + fn).upper()
        include_guard = "#ifndef %s\n#define %s\n" % (ig, ig)
        tracing_include = ""
        if fn not in notracing:
            tracing_include = "#include <zephyr/tracing/tracing_syscall.h>"
        header = syscall_template.format(include_guard=include_guard, tracing_include=tracing_include, invocations="\n\n".join(invo_list))

        with open(out_fn, "w") as fp:
            fp.write(header)

    # Likewise emit _mrsh.c files for syscall inclusion
    if args.gen_mrsh_files:
        # Keys of mrsh_defs are syscall names, not header names.
        for fn in mrsh_defs:
            mrsh_fn = os.path.join(args.base_output, fn + "_mrsh.c")

            with open(mrsh_fn, "w") as fp:
                fp.write("/* auto-generated by gen_syscalls.py, don't edit */\n\n")
                fp.write(mrsh_includes[fn] + "\n")
                fp.write("\n")
                fp.write(mrsh_defs[fn] + "\n")

if __name__ == "__main__":
    main()
552