1#!/usr/bin/env python3
2#
3# Copyright (c) 2017 Intel Corporation
4#
5# SPDX-License-Identifier: Apache-2.0
6
7"""
8Script to generate system call invocation macros
9
10This script parses the system call metadata JSON file emitted by
11parse_syscalls.py to create several files:
12
13- A file containing weak aliases of any potentially unimplemented system calls,
14  as well as the system call dispatch table, which maps system call type IDs
15  to their handler functions.
16
17- A header file defining the system call type IDs, as well as function
18  prototypes for all system call handler functions.
19
20- A directory containing header files. Each header corresponds to a header
21  that was identified as containing system call declarations. These
22  generated headers contain the inline invocation functions for each system
23  call in that header.
24"""
25
26import sys
27import re
28import argparse
29import os
30import json
31
# Headers excluded from automatic syscall tracing. Wrapping the syscalls of
# some kernel headers in the generated tracing macros would cause unintended
# recursion or other serious problems; those headers already carry
# hand-written tracing hooks for everything relevant.
notracing = [
    "kernel.h",
    "zephyr/kernel.h",
    "errno_private.h",
    "zephyr/errno_private.h",
]

# C types that need two uintptr_t words on targets without 64-bit registers
# and are therefore split into lo/hi halves when marshalled. main() appends
# any extra types given with --split-type.
types64 = [
    "int64_t",
    "uint64_t",
]

# Marshalling handlers that must NOT get a weak fallback declaration.
#
# Kernel linkage is delicate here: these functions come from
# userspace_handlers.c, which sits in the kernel .a library after
# userspace.c, and userspace.c contains the weak fallbacks defined by this
# script. The linker resolves the weak symbol first, stops searching, and
# never sees the real implementation that was supposed to override it.
# Reordering the sources is not an option because CMakeLists.txt notes that
# the order is critical. These core syscalls can never be configured out, so
# the fallback is simply disabled for them.
noweak = [
    "z_mrsh_k_object_release",
    "z_mrsh_k_object_access_grant",
    "z_mrsh_k_object_alloc",
]
52
# Dispatch table source (written to --syscall-dispatch by main()).
# First %s: the weak handler declarations; second %s: the
# ",\n\t"-joined "[K_SYSCALL_x] = handler" table entries.
table_template = """/* auto-generated by gen_syscalls.py, don't edit */

/* Weak handler functions that get replaced by the real ones unless a system
 * call is not implemented due to kernel configuration.
 */
%s

const _k_syscall_handler_t _k_syscall_table[K_SYSCALL_LIMIT] = {
\t%s
};
"""

# Syscall ID list header (written to --syscall-list by main()).
# %s: one "#define K_SYSCALL_<NAME> <id>" line per syscall.
list_template = """/* auto-generated by gen_syscalls.py, don't edit */

#ifndef ZEPHYR_SYSCALL_LIST_H
#define ZEPHYR_SYSCALL_LIST_H

%s

#ifndef _ASMLANGUAGE

#include <stdint.h>

#endif /* _ASMLANGUAGE */

#endif /* ZEPHYR_SYSCALL_LIST_H */
"""

# Per-header invocation file, filled with str.format() (literal C braces are
# therefore doubled). Placeholders:
#   include_guard   - "#ifndef X\n#define X\n"; closed by the final #endif
#   tracing_include - tracing header #include, or "" for notracing headers
#   invocations     - the inline wrappers produced by wrapper_defs()
syscall_template = """/* auto-generated by gen_syscalls.py, don't edit */

{include_guard}

{tracing_include}

#ifndef _ASMLANGUAGE

#include <syscall_list.h>
#include <zephyr/syscall.h>

#include <zephyr/linker/sections.h>


#ifdef __cplusplus
extern "C" {{
#endif

{invocations}

#ifdef __cplusplus
}}
#endif

#endif
#endif /* include guard */
"""

# NOTE(review): not referenced anywhere in this script; appears to be unused
# (possibly kept for external users).
handler_template = """
extern uintptr_t z_hdlr_%s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
                uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""

# %s: a marshalling handler name. Declares it as a weak alias of
# handler_no_syscall, so the table still links when the real handler is not
# built; main() skips the names listed in noweak.
weak_template = """
__weak ALIAS_OF(handler_no_syscall)
uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3,
         uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);
"""
119
# Defines a macro wrapper that supersedes the syscall when used and provides
# tracing enter/exit hooks, while allowing per compilation unit
# enable/disable of syscall tracing (define DISABLE_SYSCALL_TRACING).
# Used for functions that return a value; the last argument to the exit
# hook is the return value.
#
# Filled with str.format(), so literal C braces are doubled. The trailing
# backslashes are Python line continuations inside the string literal, so
# the whole #define is emitted on a single line. The macro body's own use of
# {func_name} is not re-expanded by the C preprocessor, so it calls the real
# inline function. trace_diagnostic is "" or a "#warning" line.
syscall_tracer_with_return_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) ({{ \
	{func_type} retval; \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	retval = {func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}, retval); \
	retval; \
}})
#endif
#endif
"""

# Same as above, for functions returning void: a do { } while (false)
# statement wrapper with no return value passed to the exit hook.
syscall_tracer_void_template = """
#if defined(CONFIG_TRACING_SYSCALL)
#ifndef DISABLE_SYSCALL_TRACING
{trace_diagnostic}
#define {func_name}({argnames}) do {{ \
	sys_port_trace_syscall_enter({syscall_id}, {func_name}{trace_argnames}); \
	{func_name}({argnames}); \
	sys_port_trace_syscall_exit({syscall_id}, {func_name}{trace_argnames}); \
}} while(false)
#endif
#endif
"""
154
# Splits "<type text><identifier>" into the (unstripped) type text and the
# trailing C identifier, e.g. "const char *name" -> ("const char *", "name").
typename_regex = re.compile(r"(.*?)([A-Za-z0-9_]+)$")
156
157
class SyscallParseException(Exception):
    """Raised when a system call declaration cannot be turned into C code."""
160
161
def typename_split(item):
    """Split a C declaration such as ``"const char *name"`` into (type, name).

    The type part is whitespace-stripped; the name is the trailing identifier.

    Raises:
        SyscallParseException: if the declaration is an array, contains a
            parenthesis (e.g. a raw function pointer), or does not end in an
            identifier.
    """
    # Arrays are rejected before parentheses, so the array message wins when
    # both are present.
    if "[" in item:
        raise SyscallParseException(
            "Please pass arrays to syscalls as pointers, unable to process '%s'" %
            item)
    if "(" in item:
        raise SyscallParseException(
            "Please use typedefs for function pointers")

    match = typename_regex.match(item)
    if match is None:
        raise SyscallParseException("Malformed system call invocation")

    type_part, name = match.group(1, 2)
    return type_part.strip(), name
178
def need_split(argtype):
    """Return True if a value of C type ``argtype`` must be marshalled as two words.

    Splitting is only needed when registers are not 64 bits wide (no
    --long-registers) and the type is listed in ``types64``.
    """
    if args.long_registers:
        return False
    return argtype in types64
181
def union_decl(type, split):
    """Return a C union declarator used to move a ``type`` value through uintptr_t words.

    With ``split`` the union overlays the value with two words named ``lo``
    and ``hi`` (little-endian naming, which is harmless as long as the
    marshalling and unmarshalling sides are generated consistently);
    otherwise it overlays a single word ``x``.
    """
    if split:
        member = "struct { uintptr_t lo, hi; } split"
    else:
        member = "uintptr_t x"
    return f"union {{ {member}; {type} val; }}"
188
def wrapper_defs(func_name, func_type, args, fn):
    """Generate the C invocation wrapper for one system call.

    Emits an ``extern`` declaration of ``z_impl_<func_name>`` and a
    ``static inline`` function named ``func_name``. Under CONFIG_USERSPACE,
    when ``z_syscall_trap()`` is true, the wrapper packs each argument into
    uintptr_t words and calls ``arch_syscall_invoke<N>()``. Otherwise it
    calls ``z_impl_<func_name>`` directly. If the header ``fn`` is not in
    ``notracing``, a tracing macro with the same name is appended.

    Args:
        func_name: system call name.
        func_type: C return type.
        args: list of (type, name) tuples as produced by typename_split().
        fn: header file the system call was declared in.

    Returns:
        The generated C source text.
    """
    ret64 = need_split(func_type)
    mrsh_args = [] # List of rvalue expressions for the marshalled invocation

    decl_arglist = ", ".join([" ".join(argrec) for argrec in args]) or "void"
    syscall_id = "K_SYSCALL_" + func_name.upper()

    wrap = "extern %s z_impl_%s(%s);\n" % (func_type, func_name, decl_arglist)
    wrap += "\n"
    wrap += "__pinned_func\n"
    wrap += "static inline %s %s(%s)\n" % (func_type, func_name, decl_arglist)
    wrap += "{\n"
    wrap += "#ifdef CONFIG_USERSPACE\n"
    # A split 64-bit return value comes back through a pointer to this local,
    # passed as the last marshalled word (see below).
    wrap += ("\t" + "uint64_t ret64;\n") if ret64 else ""
    wrap += "\t" + "if (z_syscall_trap()) {\n"

    # Each argument goes into a union so its bits can be passed either as one
    # word (.x) or as two halves (.split.lo / .split.hi).
    valist_args = []
    for argnum, (argtype, argname) in enumerate(args):
        split = need_split(argtype)
        wrap += "\t\t%s parm%d" % (union_decl(argtype, split), argnum)
        if argtype != "va_list":
            wrap += " = { .val = %s };\n" % argname
        else:
            # va_list objects are ... peculiar. They are copied with va_copy()
            # here instead of initialized, and each copy gets va_end() after
            # the invocation below.
            wrap += ";\n" + "\t\t" + "va_copy(parm%d.val, %s);\n" % (argnum, argname)
            valist_args.append("parm%d.val" % argnum)
        if split:
            mrsh_args.append("parm%d.split.lo" % argnum)
            mrsh_args.append("parm%d.split.hi" % argnum)
        else:
            mrsh_args.append("parm%d.x" % argnum)

    if ret64:
        mrsh_args.append("(uintptr_t)&ret64")

    # Only six words are passed directly: beyond that, keep the first five
    # and pass the remainder through a pointer to a local "more" array.
    if len(mrsh_args) > 6:
        wrap += "\t\t" + "uintptr_t more[] = {\n"
        wrap += "\t\t\t" + (",\n\t\t\t".join(mrsh_args[5:])) + "\n"
        wrap += "\t\t" + "};\n"
        mrsh_args[5:] = ["(uintptr_t) &more"]

    invoke = ("arch_syscall_invoke%d(%s)"
              % (len(mrsh_args),
                 ", ".join(mrsh_args + [syscall_id])))

    # Shape of the trap path's return:
    #  - split 64-bit return: discard the invoke result, return ret64
    #  - void: discard the invoke result
    #  - va_list args present: keep the result so va_end() can run first
    #  - otherwise: return the cast invoke result directly
    if ret64:
        invoke = "\t\t" + "(void) %s;\n" % invoke
        retcode = "\t\t" + "return (%s) ret64;\n" % func_type
    elif func_type == "void":
        invoke = "\t\t" + "(void) %s;\n" % invoke
        retcode = "\t\t" + "return;\n"
    elif valist_args:
        invoke = "\t\t" + "%s retval = %s;\n" % (func_type, invoke)
        retcode = "\t\t" + "return retval;\n"
    else:
        invoke = "\t\t" + "return (%s) %s;\n" % (func_type, invoke)
        retcode = ""

    wrap += invoke
    for argname in valist_args:
        wrap += "\t\t" + "va_end(%s);\n" % argname
    wrap += retcode
    wrap += "\t" + "}\n"
    wrap += "#endif\n"

    # Otherwise fall through to direct invocation of the impl func.
    # Note the compiler barrier: that is required to prevent code from
    # the impl call from being hoisted above the check for user
    # context.
    impl_arglist = ", ".join([argrec[1] for argrec in args])
    impl_call = "z_impl_%s(%s)" % (func_name, impl_arglist)
    wrap += "\t" + "compiler_barrier();\n"
    wrap += "\t" + "%s%s;\n" % ("return " if func_type != "void" else "",
                               impl_call)

    wrap += "}\n"

    # Optional tracing macro that shadows the inline wrapper defined above.
    if fn not in notracing:
        argnames = ", ".join([f"{argname}" for _, argname in args])
        # Trace hooks take (syscall id, function, args...); the leading
        # comma is only needed when there are arguments.
        trace_argnames = ""
        if len(args) > 0:
            trace_argnames = ", " + argnames
        # Setting TRACE_DIAGNOSTICS in the environment makes each traced
        # syscall emit a compile-time #warning.
        trace_diagnostic = ""
        if os.getenv('TRACE_DIAGNOSTICS'):
            trace_diagnostic = f"#warning Tracing {func_name}"
        if func_type != "void":
            wrap += syscall_tracer_with_return_template.format(func_type=func_type, func_name=func_name,
                                                               argnames=argnames, trace_argnames=trace_argnames,
                                                               syscall_id=syscall_id, trace_diagnostic=trace_diagnostic)
        else:
            wrap += syscall_tracer_void_template.format(func_type=func_type, func_name=func_name,
                                                        argnames=argnames, trace_argnames=trace_argnames,
                                                        syscall_id=syscall_id, trace_diagnostic=trace_diagnostic)

    return wrap
284
def mrsh_rval(mrsh_num, total):
    """Return the C rvalue for zero-based marshalled word ``mrsh_num``.

    If the call uses at most six words, every word is a register argument
    ``arg0``..``arg5``. Otherwise words 0-4 stay in registers and word 5
    onward is read from the ``more`` array that the handler receives in
    place of ``arg5``.
    """
    spilled = total > 6 and mrsh_num >= 5
    if not spilled:
        return "arg%d" % mrsh_num
    return "(((uintptr_t *)more)[%d])" % (mrsh_num - 5)
292
def marshall_defs(func_name, func_type, args):
    """Generate the C marshalling handler for one system call.

    The generated ``z_mrsh_<func_name>`` works as follows:
    - It receives the raw words in ``arg0``..``arg5``. If more than six
      words are needed, it gets ``arg0``..``arg4`` plus a ``more`` pointer,
      whose memory is checked with Z_SYSCALL_MEMORY_READ.
    - It rebuilds each parameter through a union, joining lo/hi halves for
      split types.
    - It calls ``z_vrfy_<func_name>``.
    - A split 64-bit return value is written back through the trailing
      pointer word after a Z_SYSCALL_MEMORY_WRITE check.

    Args:
        func_name: system call name.
        func_type: C return type.
        args: list of (type, name) tuples as produced by typename_split().

    Returns:
        A tuple (C source text, name of the generated handler).
    """
    mrsh_name = "z_mrsh_" + func_name

    nmrsh = 0        # number of marshalled uintptr_t parameters
    vrfy_parms = []  # list of (argtype, bool_is_split)
    for (argtype, _) in args:
        split = need_split(argtype)
        vrfy_parms.append((argtype, split))
        nmrsh += 2 if split else 1

    # Final argument for a 64 bit return value?
    if need_split(func_type):
        nmrsh += 1

    # Emit "(void)" for an empty parameter list, as wrapper_defs() does for
    # z_impl_: in C before C23 an empty "()" declares a function with
    # unspecified parameters (not a prototype) and trips -Wstrict-prototypes.
    decl_arglist = ", ".join([" ".join(argrec) for argrec in args]) or "void"
    mrsh = "extern %s z_vrfy_%s(%s);\n" % (func_type, func_name, decl_arglist)

    mrsh += "uintptr_t %s(uintptr_t arg0, uintptr_t arg1, uintptr_t arg2,\n" % mrsh_name
    if nmrsh <= 6:
        mrsh += "\t\t" + "uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, void *ssf)\n"
    else:
        mrsh += "\t\t" + "uintptr_t arg3, uintptr_t arg4, void *more, void *ssf)\n"
    mrsh += "{\n"
    mrsh += "\t" + "_current->syscall_frame = ssf;\n"

    # Silence unused-parameter warnings for register slots this call ignores.
    for unused_arg in range(nmrsh, 6):
        mrsh += "\t(void) arg%d;\t/* unused */\n" % unused_arg

    # Words 5.. live in caller memory; validate the whole array before use.
    if nmrsh > 6:
        mrsh += ("\tZ_OOPS(Z_SYSCALL_MEMORY_READ(more, "
                 + str(nmrsh - 5) + " * sizeof(uintptr_t)));\n")

    argnum = 0
    for i, (argtype, split) in enumerate(vrfy_parms):
        mrsh += "\t%s parm%d;\n" % (union_decl(argtype, split), i)
        if split:
            mrsh += "\t" + "parm%d.split.lo = %s;\n" % (i, mrsh_rval(argnum, nmrsh))
            argnum += 1
            mrsh += "\t" + "parm%d.split.hi = %s;\n" % (i, mrsh_rval(argnum, nmrsh))
        else:
            mrsh += "\t" + "parm%d.x = %s;\n" % (i, mrsh_rval(argnum, nmrsh))
        argnum += 1

    # Finally, invoke the verify function
    out_args = ", ".join(["parm%d.val" % i for i in range(len(args))])
    vrfy_call = "z_vrfy_%s(%s)" % (func_name, out_args)

    if func_type == "void":
        mrsh += "\t" + "%s;\n" % vrfy_call
        mrsh += "\t" + "_current->syscall_frame = NULL;\n"
        mrsh += "\t" + "return 0;\n"
    else:
        mrsh += "\t" + "%s ret = %s;\n" % (func_type, vrfy_call)

        if need_split(func_type):
            # The last marshalled word is the caller's pointer for the result.
            ptr = "((uint64_t *)%s)" % mrsh_rval(nmrsh - 1, nmrsh)
            mrsh += "\t" + "Z_OOPS(Z_SYSCALL_MEMORY_WRITE(%s, 8));\n" % ptr
            mrsh += "\t" + "*%s = ret;\n" % ptr
            mrsh += "\t" + "_current->syscall_frame = NULL;\n"
            mrsh += "\t" + "return 0;\n"
        else:
            mrsh += "\t" + "_current->syscall_frame = NULL;\n"
            mrsh += "\t" + "return (uintptr_t) ret;\n"

    mrsh += "}\n"

    return mrsh, mrsh_name
360
def analyze_fn(match_group, fn):
    """Generate every code fragment for one parsed system call declaration.

    Args:
        match_group: pair (declaration, argument list). Judging by the parsing
            below, e.g. ("int k_foo", "int a, void *b"), or "void" as the
            argument list for no arguments.
        fn: header file that declares the system call.

    Returns:
        Tuple (handler name, invocation wrapper source, marshaller source,
        K_SYSCALL_* macro name, dispatch table entry).

    Raises:
        SyscallParseException: if the declaration cannot be parsed; the
            offending declaration is reported on stderr first.
    """
    func, args = match_group

    try:
        # Tolerate whitespace around a bare "void", consistent with the
        # per-argument .strip() below.
        if args.strip() == "void":
            args = []
        else:
            args = [typename_split(a.strip()) for a in args.split(",")]

        func_type, func_name = typename_split(func)
    except SyscallParseException:
        sys.stderr.write("In declaration of %s\n" % func)
        raise

    sys_id = "K_SYSCALL_" + func_name.upper()

    marshaller, handler = marshall_defs(func_name, func_type, args)
    invocation = wrapper_defs(func_name, func_type, args, fn)

    # Entry in _k_syscall_table
    table_entry = "[%s] = %s" % (sys_id, handler)

    return (handler, invocation, marshaller, sys_id, table_entry)
385
def parse_args():
    """Parse the command line into the module-level ``args`` namespace.

    need_split() and main() read ``args`` directly, so this must run before
    any code generation.
    """
    global args

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        allow_abbrev=False,
    )

    # Mandatory input/output locations, registered in their original order.
    required_paths = (
        ("-i", "--json-file", "Read syscall information from json file"),
        ("-d", "--syscall-dispatch", "output C system call dispatch table file"),
        ("-l", "--syscall-list", "output C system call list header"),
        ("-o", "--base-output", "Base output directory for syscall macro headers"),
    )
    for short_flag, long_flag, help_text in required_paths:
        parser.add_argument(short_flag, long_flag, required=True, help=help_text)

    # Optional knobs affecting the generated code.
    parser.add_argument(
        "-s", "--split-type", action="append",
        help="A long type that must be split/marshalled on 32-bit systems")
    parser.add_argument(
        "-x", "--long-registers", action="store_true",
        help="Indicates we are on system with 64-bit registers")
    parser.add_argument(
        "--gen-mrsh-files", action="store_true",
        help="Generate marshalling files (*_mrsh.c)")

    args = parser.parse_args()
407
408
def main():
    """Generate every output file from the parse_syscalls.py JSON.

    Writes:
      * --syscall-dispatch: weak fallback declarations for the marshalling
        handlers plus the ``_k_syscall_table`` dispatch table;
      * --syscall-list: one ``#define K_SYSCALL_<NAME> <id>`` per syscall,
        followed by K_SYSCALL_BAD and K_SYSCALL_LIMIT;
      * <base-output>/<header>: inline invocation wrappers, one file per
        header that declared system calls;
      * <base-output>/<syscall>_mrsh.c: marshalling handlers, only with
        --gen-mrsh-files.
    """
    parse_args()

    # Extra types that must be split on 32-bit targets (see need_split()).
    if args.split_type is not None:
        types64.extend(args.split_type)

    with open(args.json_file, 'r') as fd:
        syscalls = json.load(fd)

    invocations = {}      # header name -> list of wrapper sources
    mrsh_defs = {}        # syscall name -> marshalling handler source
    mrsh_includes = {}    # syscall name -> #include of its generated header
    ids = []              # K_SYSCALL_* macro names
    table_entries = []    # "[K_SYSCALL_x] = handler" initializers
    handlers = []         # marshalling handler names

    for match_group, fn in syscalls:
        handler, inv, mrsh, sys_id, entry = analyze_fn(match_group, fn)

        invocations.setdefault(fn, []).append(inv)
        ids.append(sys_id)
        table_entries.append(entry)
        handlers.append(handler)

        if mrsh:
            syscall = typename_split(match_group[0])[1]
            mrsh_defs[syscall] = mrsh
            mrsh_includes[syscall] = "#include <syscalls/%s>" % fn

    with open(args.syscall_dispatch, "w") as fp:
        table_entries.append("[K_SYSCALL_BAD] = handler_bad_syscall")

        weak_defines = "".join([weak_template % name
                                for name in handlers
                                if name not in noweak])

        # The "noweak" ones just get a regular declaration
        weak_defines += "\n".join(["extern uintptr_t %s(uintptr_t arg1, uintptr_t arg2, uintptr_t arg3, uintptr_t arg4, uintptr_t arg5, uintptr_t arg6, void *ssf);"
                                   % s for s in noweak])

        fp.write(table_template % (weak_defines,
                                   ",\n\t".join(table_entries)))

    # Syscall ID list header, written to --syscall-list. IDs are assigned in
    # sorted-name order; K_SYSCALL_BAD and K_SYSCALL_LIMIT (the dispatch
    # table size) come last.
    ids.sort()
    ids.extend(["K_SYSCALL_BAD", "K_SYSCALL_LIMIT"])

    ids_as_defines = "".join("#define {} {}\n".format(item, i)
                             for i, item in enumerate(ids))

    with open(args.syscall_list, "w") as fp:
        fp.write(list_template % ids_as_defines)

    os.makedirs(args.base_output, exist_ok=True)
    for fn, invo_list in invocations.items():
        out_fn = os.path.join(args.base_output, fn)
        # Header names may carry a directory part (notracing lists
        # "zephyr/..." names); make sure that subdirectory exists.
        os.makedirs(os.path.dirname(out_fn), exist_ok=True)

        ig = re.sub("[^a-zA-Z0-9]", "_", "Z_INCLUDE_SYSCALLS_" + fn).upper()
        include_guard = "#ifndef %s\n#define %s\n" % (ig, ig)
        tracing_include = ""
        if fn not in notracing:
            tracing_include = "#include <zephyr/tracing/tracing_syscall.h>"
        header = syscall_template.format(include_guard=include_guard,
                                         tracing_include=tracing_include,
                                         invocations="\n\n".join(invo_list))

        with open(out_fn, "w") as fp:
            fp.write(header)

    # Likewise emit _mrsh.c files for syscall inclusion
    if args.gen_mrsh_files:
        for syscall, mrsh_src in mrsh_defs.items():
            mrsh_fn = os.path.join(args.base_output, syscall + "_mrsh.c")

            with open(mrsh_fn, "w") as fp:
                fp.write("/* auto-generated by gen_syscalls.py, don't edit */\n\n")
                fp.write(mrsh_includes[syscall] + "\n")
                fp.write("\n")
                fp.write(mrsh_src + "\n")
491
if __name__ == "__main__":
    # Generate output only when run as a script, not when imported.
    main()
494