1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# Copyright (c) 2021 Intel Corporation
4
import argparse
import os
import re
import sys

import sh
from unidiff import PatchSet
10
# main() locates the Coccinelle scripts under $ZEPHYR_BASE and uses it as the
# default repository to diff, so refuse to run without it. sys.exit is used
# instead of the site-provided exit(), which is meant for interactive use only.
if "ZEPHYR_BASE" not in os.environ:
    sys.exit("$ZEPHYR_BASE environment variable undefined.")
13
# Reserved-names check. main() skips this script for any changed file whose
# path matches one of coccinelle_reserved_names_exclude_regex.
RESERVED_NAMES_SCRIPT = "/scripts/coccinelle/reserved_names.cocci"

# Scripts run by main() on each changed .c/.h file; main() prefixes each entry
# with ZEPHYR_BASE. Note: /scripts/coccinelle/identifier_length.cocci is
# currently disabled and intentionally not listed here.
coccinelle_scripts = [
    RESERVED_NAMES_SCRIPT,
    "/scripts/coccinelle/same_identifier.cocci",
]

# Regexes applied with re.match (anchored at the start) to the diff path of a
# changed file; a match exempts that file from RESERVED_NAMES_SCRIPT only.
coccinelle_reserved_names_exclude_regex = [
    r"lib/libc/.*",
    r"lib/posix/.*",
    r"include/zephyr/posix/.*",
]
26
def parse_coccinelle(contents: str, violations: dict):
    """Collect Coccinelle report lines from *contents* into *violations*.

    Every line shaped like ``<path>.<c|h>:<line>:<cols>: <message>`` appends
    ``<message>`` to ``violations["<path>:<line>"]`` (creating the list on
    first use). Lines that do not match are ignored. *violations* is updated
    in place; nothing is returned.

    Args:
        contents: coccicheck ``--mode=report`` output. main() passes the
            ``sh`` command result, so it is converted with ``str()`` first.
        violations: mapping of ``"path:line"`` to a list of messages.
    """
    # '-' and '.' are allowed in the path so files under directories such as
    # "dt-bindings" are not silently dropped; the trailing `\.[ch]` still
    # anchors the extension via backtracking. The path must not contain ':'.
    pattern = re.compile(r"([a-zA-Z0-9_/.\-]*\.[ch]:[0-9]*)(:[0-9\-]*: )(.*)")
    # splitlines() also strips a trailing '\r' that split("\n") would keep.
    for line in str(contents).splitlines():
        match = pattern.match(line)
        if match:
            violations.setdefault(match.group(1), []).append(match.group(3))
37
38
def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``repository`` (git checkout to diff; ``None``
        makes main() fall back to ``$ZEPHYR_BASE``), ``commits`` (``a..b``
        range, ``None`` when omitted) and ``output`` (file that violations
        are appended to; ``None`` means they are printed).
    """
    parser = argparse.ArgumentParser(
        description="Check commits against Coccinelle rules",
        allow_abbrev=False,
    )
    parser.add_argument("-r", "--repository", required=False,
                        help="Path to repository")
    parser.add_argument("-c", "--commits", default=None,
                        help="Commit range in the form: a..b")
    parser.add_argument("-o", "--output", required=False,
                        help="Append violations to this file instead of printing them")
    return parser.parse_args(argv)
49
50
def _reserved_names_excluded(path):
    """Return True if *path* matches a coccinelle_reserved_names_exclude_regex entry."""
    return any(re.match(pattern, path)
               for pattern in coccinelle_reserved_names_exclude_regex)


def _run_coccinelle(path, zephyr_base, sh_special_args, violations):
    """Run each applicable script in coccinelle_scripts on *path* into *violations*."""
    for script in coccinelle_scripts:
        if script == RESERVED_NAMES_SCRIPT and _reserved_names_excluded(path):
            continue

        script_path = zephyr_base + "/" + script
        print(f"Running {script} on {path}")
        try:
            # pylint does not like the 'sh' library
            # pylint: disable=too-many-function-args,unexpected-keyword-arg
            cocci = sh.coccicheck(
                "--mode=report",
                "--cocci=" + script_path,
                path,
                _timeout=10,
                **sh_special_args)
        except sh.TimeoutException:
            # Best effort: a slow script is skipped rather than failing the run.
            print("we timed out waiting, skipping...")
        else:
            parse_coccinelle(cocci, violations)


def _emit(report, output):
    """Append *report* as one line to the *output* file, or print it if unset."""
    if output:
        with open(output, "a", encoding="utf-8") as fp:
            fp.write(report + "\n")
    else:
        print(report)


def main():
    """Report Coccinelle violations on lines added by a commit range.

    Runs ``git diff <commits>`` in the target repository (``--repository``,
    defaulting to ``$ZEPHYR_BASE``). For every changed ``.c``/``.h`` file that
    exists in that repository, runs the Coccinelle scripts and reports
    each added line that has at least one violation, either to stdout or
    appended to ``--output``.

    Returns:
        The number of added lines that carry at least one violation.
    """
    args = parse_args()
    if not args.commits:
        sys.exit("missing commit range")

    zephyr_base = os.environ["ZEPHYR_BASE"]
    repository_path = zephyr_base if args.repository is None else args.repository

    sh_special_args = {
        '_tty_out': False,
        '_cwd': repository_path,
    }

    # pylint does not like the 'sh' library
    # pylint: disable=too-many-function-args,unexpected-keyword-arg
    commit = sh.git("diff", args.commits, **sh_special_args)
    patch_set = PatchSet(commit)
    violations = {}
    num_violations = 0

    for f in patch_set:
        # Diff paths are relative to the diffed repository, which is also
        # coccicheck's working directory, so check existence there rather
        # than under ZEPHYR_BASE (they differ when --repository is given).
        if (not f.path.endswith((".c", ".h"))
                or not os.path.exists(os.path.join(repository_path, f.path))):
            continue

        _run_coccinelle(f.path, zephyr_base, sh_special_args, violations)

        for hunk in f:
            for line in hunk:
                if not line.is_added:
                    continue
                location = "{}:{}".format(f.path, line.target_line_no)
                messages = violations.get(location)
                if messages:
                    num_violations += 1
                    _emit("{}:{}".format(location, "\t\n".join(messages)),
                          args.output)

    return num_violations
122
123
if __name__ == "__main__":
    # main() returns the violation count. Exit statuses outside 0-127 are not
    # portable and wrap at 256 on POSIX (256 violations would exit 0 and look
    # like success), so clamp the count before exiting.
    sys.exit(min(main(), 127))
127