1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# Copyright (c) 2021 Intel Corporation
4
import argparse
import os
import re
import sys

import sh
from unidiff import PatchSet
11
# The .cocci scripts are located relative to $ZEPHYR_BASE, so refuse to run
# without it. sys.exit() is used rather than the exit() builtin, which the
# site module only provides for the interactive interpreter (it does not
# exist under `python -S`).
if "ZEPHYR_BASE" not in os.environ:
    sys.exit("$ZEPHYR_BASE environment variable undefined.")

# Script paths are relative to $ZEPHYR_BASE; main() joins them as
# ZEPHYR_BASE + "/" + script, so the leading "/" only yields a harmless "//".
RESERVED_NAMES_SCRIPT = "/scripts/coccinelle/reserved_names.cocci"

# Scripts run against every changed .c/.h file of the commit range.
coccinelle_scripts = [
    RESERVED_NAMES_SCRIPT,
    "/scripts/coccinelle/same_identifier.cocci",
    # "/scripts/coccinelle/identifier_length.cocci",
]

# Repository-relative paths on which RESERVED_NAMES_SCRIPT is NOT run.
# Matched with re.match, i.e. anchored at the start of the path. Presumably
# libc/POSIX code legitimately defines reserved identifiers — confirm.
coccinelle_reserved_names_exclude_regex = [
    r"lib/libc/.*",
    r"lib/posix/.*",
    r"include/zephyr/posix/.*",
]
28
29
def parse_coccinelle(contents: str, violations: dict):
    """Collect coccicheck report lines from ``contents`` into ``violations``.

    A report line looks like ``path/file.c:LINE:<digits/dashes>: message``.
    Any other line is ignored. ``violations`` is updated in place. The key is
    ``"path/file.c:LINE"`` and the value is the list of messages reported for
    that location, in output order. Existing entries are extended, not replaced.

    :param contents: coccicheck output text.
    :param violations: dict mapping ``"path:line"`` to a list of messages.
    """
    # '.' and '-' are allowed in the path so that files in directories such
    # as "riscv-privileged/" or "v1.2/" are not silently skipped. re.match
    # anchors at line start, so prefixed noise lines never match.
    report_re = re.compile(r"([a-zA-Z0-9_/.-]*\.[ch]:[0-9]*)(:[0-9-]*: )(.*)")
    for line in contents.split("\n"):
        match = report_re.match(line)
        if match:
            violations.setdefault(match.group(1), []).append(match.group(3))
40
41
def parse_args(argv=None):
    """Parse the command line.

    :param argv: arguments to parse, without the program name. ``None``
        (the default) parses ``sys.argv[1:]``, as before.
    :return: ``argparse.Namespace`` with ``repository``, ``commits`` and
        ``output`` attributes, each ``None`` when the option is not given.
    """
    parser = argparse.ArgumentParser(
        description="Check commits against Coccinelle rules", allow_abbrev=False
    )
    parser.add_argument("-r", "--repository", required=False, help="Path to repository")
    parser.add_argument("-c", "--commits", default=None, help="Commit range in the form: a..b")
    parser.add_argument(
        "-o", "--output", required=False,
        help="Append violations to this file instead of printing them",
    )
    return parser.parse_args(argv)
50
51
def _skip_script(script: str, path: str) -> bool:
    """Return True if ``script`` must not be run on repository file ``path``."""
    # Only the reserved-names check has per-path exclusions; the regexes are
    # applied with re.match, i.e. anchored at the start of the path.
    if script != RESERVED_NAMES_SCRIPT:
        return False
    return any(re.match(regex, path) for regex in coccinelle_reserved_names_exclude_regex)


def _added_line_reports(patched_file, violations: dict) -> list:
    """Return one report string per added line of ``patched_file`` with violations."""
    reports = []
    for hunk in patched_file:
        for line in hunk:
            if not line.is_added:
                continue
            location = f"{patched_file.path}:{line.target_line_no}"
            if location in violations:
                messages = "\t\n".join(violations[location])
                reports.append(f"{location}:{messages}")
    return reports


def _emit_reports(reports: list, output) -> None:
    """Append ``reports`` to file ``output``, or print them if ``output`` is falsy."""
    if not reports:
        return
    if output:
        # Append mode: an existing output file is extended, never truncated.
        # Opened once per batch rather than once per violation.
        with open(output, "a", encoding="utf-8") as fp:
            fp.writelines(f"{report}\n" for report in reports)
    else:
        for report in reports:
            print(report)


def main():
    """Run the Coccinelle scripts on the C sources/headers changed by ``--commits``.

    Only violations on lines *added* by the diff are reported. They are printed,
    or appended to ``--output`` when given. Exits with an error message when
    no commit range is provided.

    :return: number of added lines carrying at least one violation.
    """
    args = parse_args()
    if not args.commits:
        sys.exit("missing commit range")

    if args.repository is None:
        repository_path = os.environ['ZEPHYR_BASE']
    else:
        repository_path = args.repository

    sh_special_args = {'_tty_out': False, '_cwd': repository_path}

    # pylint does not like the 'sh' library
    # pylint: disable=too-many-function-args,unexpected-keyword-arg
    diff = sh.git("diff", args.commits, **sh_special_args)
    patch_set = PatchSet(diff)
    # The .cocci scripts always come from the Zephyr tree, even when another
    # repository is being checked.
    zephyr_base = os.getenv("ZEPHYR_BASE")
    violations = {}
    num_violations = 0

    for patched_file in patch_set:
        path = patched_file.path
        # Diff paths are relative to the checked repository (which is also
        # coccicheck's working directory), so look for the file there, not
        # in $ZEPHYR_BASE; files missing from the working tree are skipped.
        if not path.endswith((".c", ".h")):
            continue
        if not os.path.exists(os.path.join(repository_path, path)):
            continue

        for script in coccinelle_scripts:
            if _skip_script(script, path):
                continue
            script_path = zephyr_base + "/" + script
            print(f"Running {script} on {path}")
            try:
                cocci = sh.coccicheck(
                    "--mode=report",
                    "--cocci=" + script_path,
                    path,
                    _timeout=10,
                    **sh_special_args,
                )
            except sh.TimeoutException:
                print("we timed out waiting, skipping...")
                continue
            parse_coccinelle(cocci, violations)

        reports = _added_line_reports(patched_file, violations)
        num_violations += len(reports)
        _emit_reports(reports, args.output)

    return num_violations
119
120
if __name__ == "__main__":
    # main() returns the violation count, used as the exit status. POSIX
    # keeps only the low 8 bits of it, so e.g. 256 violations would read as
    # success (0); clamp to 255 so any violation stays a failure.
    sys.exit(min(main(), 255))
124